Example #1
 def __init__(self, optimize=True):
     # must be before Module.init since the field is used in __getattr__
     Module.__init__(self)
     self._set_optimized(optimize)
     self._parameters = OrderedParameterDict(self)
     self._buffers = OrderedBufferDict(self)
     self._modules = OrderedModuleDict(self)
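A minimal sketch of the same initialization order using plain torch.nn (the class and fields below are illustrative, not part of the original source):

import torch
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical example class
    def __init__(self, optimize=True):
        # nn.Module.__init__ must run first: it creates the internal
        # _parameters/_buffers/_modules containers that __setattr__ and
        # __getattr__ rely on for everything assigned afterwards.
        super().__init__()
        self.optimize = optimize
        self.linear = nn.Linear(4, 2)                   # lands in _modules
        self.register_buffer("steps", torch.zeros(1))   # lands in _buffers

TinyNet()  # works; assigning self.linear before super().__init__() would raise AttributeError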
Example #2
 def __getattr__(self, attr):
     if self._has_method(attr):
         if attr in self.__class__._original_methods:
             original_method = self.__class__._original_methods[attr]
             script_method = self._get_method(attr)
             return functools.wraps(original_method)(script_method)
         else:
             return self._get_method(attr)
     return Module.__getattr__(self, attr)
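A simplified, self-contained sketch of the same lookup-then-fallback pattern (the class and its method table are made up for illustration):

import functools
import torch.nn as nn

class WrappedModule(nn.Module):  # hypothetical example class
    # stand-in for the compiled-method table used above
    _original_methods = {"double": lambda self, x: x * 2}

    def __getattr__(self, attr):
        if attr in type(self)._original_methods:
            original = type(self)._original_methods[attr]
            # functools.wraps copies __name__/__doc__ onto the returned callable
            return functools.wraps(original)(lambda *a, **kw: original(self, *a, **kw))
        # anything else falls through to nn.Module.__getattr__, which searches
        # the registered _parameters/_buffers/_modules
        return super().__getattr__(attr)

m = WrappedModule()
m.fc = nn.Linear(2, 2)
print(m.double(3))  # 6, resolved via the method table
print(m.fc)         # resolved via nn.Module.__getattr__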
Example #3
def set_params_with_array(
    module: Module, x: np.ndarray, property_dict: Dict[str, TorchAttr]
) -> Module:
    r"""Set module parameters with values from numpy array.

    Args:
        module: Module with parameters to be set
        x: Numpy array with parameter values
        property_dict: Dictionary of parameter names and torch attributes as
            returned by module_to_array.

    Returns:
        Module: module with parameters updated in-place.

    Example:
        >>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
        >>> parameter_array, property_dict, bounds_out = module_to_array(mll)
        >>> parameter_array += 0.1  # perturb parameters (for example only)
    >>> mll = set_params_with_array(mll, parameter_array, property_dict)
    """
    param_dict = OrderedDict(module.named_parameters())
    start_idx = 0
    for p_name, attrs in property_dict.items():
        # Construct the new tensor
        if len(attrs.shape) == 0:  # deal with scalar tensors
            end_idx = start_idx + 1
            new_data = torch.tensor(
                x[start_idx], dtype=attrs.dtype, device=attrs.device
            )
        else:
            end_idx = start_idx + np.prod(attrs.shape)
            new_data = torch.tensor(
                x[start_idx:end_idx], dtype=attrs.dtype, device=attrs.device
            ).view(*attrs.shape)
        start_idx = end_idx
        # Update corresponding parameter in-place. Disable autograd to update.
        param_dict[p_name].requires_grad_(False)
        param_dict[p_name].copy_(new_data)
        param_dict[p_name].requires_grad_(True)
    return module
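The in-place update idiom above, shown in isolation on a plain layer (no botorch helpers; the layer and values are illustrative):

import numpy as np
import torch
import torch.nn as nn

layer = nn.Linear(3, 1)                     # toy module
new_bias = np.array([0.5])                  # values coming from a numpy array

param = dict(layer.named_parameters())["bias"]
param.requires_grad_(False)                 # disable autograd for the in-place copy
param.copy_(torch.tensor(new_bias, dtype=param.dtype, device=param.device))
param.requires_grad_(True)                  # re-enable gradients afterwards
print(layer.bias)                           # tensor([0.5000], requires_grad=True)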
Example #4
def count_params(model: nn.Module) -> int:
    """
    Count the number of parameters in a model.
    """
    assert isinstance(model, nn.Module)
    return sum((parameter.nelement() for parameter in model.parameters()))
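A quick usage check (assuming count_params from this example is in scope): nn.Linear(10, 5) has 10*5 weights plus 5 biases.

import torch.nn as nn

assert count_params(nn.Linear(10, 5)) == 55  # 50 weights + 5 biases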
Example #5
def init_weight(m: nn.Module):
    for name, param in m.named_parameters():
        if 'bias' in name:
            continue
        nn.init.kaiming_normal_(param.data)
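A usage sketch: biases are skipped above because nn.init.kaiming_normal_ requires tensors with at least two dimensions, so applying it to 1-D bias vectors would raise.

import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))  # toy network
init_weight(net)  # re-initializes both weight matrices, leaves biases untouched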
Example #6
 def _update_target_func(_target_func: nn.Module, _func: nn.Module):
     if _target_func is not None:
         assert _func is not None
         _target_func.load_state_dict(_func.state_dict())
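A minimal sketch of the hard target-network sync this helper performs (the two networks below are illustrative stand-ins):

import torch.nn as nn

policy_net = nn.Linear(4, 2)   # online network being trained
target_net = nn.Linear(4, 2)   # lagging copy used for stable targets

# copy every weight and buffer from the online network into the target network
target_net.load_state_dict(policy_net.state_dict())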
Example #7
def val_sanity_fit(model: nn.Module,
                   val_loader,
                   criterion,
                   device,
                   num_batches: int = None,
                   log_interval: int = 100):
    """
    Performs Sanity fit over valid loader.
    Use this to dummy check your val_step function. It does not calculate metrics, timing, or does checkpointing.
    It iterates over both train_loader and val_loader for given batches.
    Note: - It does not to loss.backward().
    Args:
        model : A PyTorch Detr Model.
        val_loader : Validation loader.
        criterion : Loss function to be optimized.
        device : "cuda" or "cpu"
        num_batches : (optional) Integer To limit sanity fit over certain batches.
                                 Useful is data is too big even for sanity check.
        log_interval : (optional) Defualt 100. Integer to Log after specified batch ids in every batch.
    """

    model = model.to(device)
    criterion = criterion.to(device)
    val_sanity_start = time.time()
    model.eval()

    last_idx = len(val_loader) - 1
    criterion.eval()
    cnt = 0

    for batch_idx, (inputs, targets) in enumerate(val_loader):
        last_batch = batch_idx == last_idx
        images = list(image.to(device) for image in inputs)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        outputs = model(images)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                   if k in weight_dict)

        cnt += 1
        if last_batch or batch_idx % log_interval == 0:
            print(
                f"Val sanity check passed till batch {batch_idx}"
            )

        if num_batches is not None:
            if cnt >= num_batches:
                print(f"Done till {num_batches} train batches")
                print("All specified batches done")
                train_sanity_end = time.time()
                print(
                    f"Train sanity fit check passed in time {train_sanity_end-train_sanity_start}"
                )
                return True

    val_sanity_end = time.time()

    print("All specified batches done")
    print(
        f"Val sanity fit check passed in time {val_sanity_end - val_sanity_start}"
    )

    return True
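The weighted-loss aggregation used above, in isolation: only entries that appear in criterion.weight_dict contribute to the total (the dictionaries below are dummies for illustration).

import torch

loss_dict = {"loss_ce": torch.tensor(0.7), "loss_bbox": torch.tensor(0.3),
             "loss_unweighted": torch.tensor(9.9)}   # ignored: no weight entry
weight_dict = {"loss_ce": 1.0, "loss_bbox": 5.0}

loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
print(loss)  # tensor(2.2000) = 0.7 * 1.0 + 0.3 * 5.0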
Example #8
def train_inner(train_data: List[Tuple[List[int], int]],
                valid_data: List[Tuple[List[int], int]],
                model: Module,
                num_classes: int,
                epochs: int,
                evaluation_period: int,
                only_epoch_eval: bool,
                model_log_directory: str,
                learning_rate: float,
                batch_size: int,
                disable_scheduler: bool = False,
                scheduler_patience: int = 10,
                scheduler_factor: float = 0.1,
                gpu_device: Optional[torch.device] = None,
                clip_threshold: Optional[float] = None,
                max_doc_len: Optional[int] = None,
                word_dropout: float = 0,
                patience: int = 30,
                resume_training: bool = False,
                disable_tqdm: bool = False,
                tqdm_update_period: int = 1) -> None:
    # create signal handlers in case script receives termination signals
    # adapted from: https://stackoverflow.com/a/31709094
    for specific_signal in [
            signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT
    ]:
        signal.signal(
            specific_signal,
            partial(signal_handler,
                    os.path.join(model_log_directory, "exit_code")))

    # initialize general local variables
    updates_per_epoch = ceil(len(train_data) / batch_size)
    patience_reached = False

    # load model checkpoint if training is being resumed
    if resume_training and len(
            glob(os.path.join(model_log_directory, "*last*.pt"))) > 0:
        model_checkpoint = torch.load(glob(
            os.path.join(model_log_directory, "*last*.pt"))[0],
                                      map_location=torch.device("cpu"))
        model.load_state_dict(
            model_checkpoint["model_state_dict"])  # type: ignore
        if (model_checkpoint["update"] +  # type: ignore
                1) == updates_per_epoch:  # type: ignore
            current_epoch: int = model_checkpoint["epoch"] + 1  # type: ignore
            current_update: int = 0
        else:
            current_epoch: int = model_checkpoint["epoch"]  # type: ignore
            current_update: int = model_checkpoint["update"] + 1  # type: ignore
        best_valid_loss: float = model_checkpoint[  # type: ignore
            "best_valid_loss"]  # type: ignore
        best_valid_loss_index: int = model_checkpoint[  # type: ignore
            "best_valid_loss_index"]  # type: ignore
        best_valid_acc: float = model_checkpoint[  # type: ignore
            "best_valid_acc"]  # type: ignore

        # check for edge-case failures
        if current_epoch >= epochs:
            # log information at the end of training
            LOGGER.info("%s training epoch(s) previously completed, exiting" %
                        epochs)
            # save exit-code and final processes
            save_exit_code(os.path.join(model_log_directory, "exit_code"),
                           FINISHED_EPOCHS)
            return None
        elif best_valid_loss_index >= patience:
            LOGGER.info("Patience threshold previously reached, exiting")
            # save exit-code and final processes
            save_exit_code(os.path.join(model_log_directory, "exit_code"),
                           PATIENCE_REACHED)
            return None
    else:
        resume_training = False
        current_epoch = 0
        current_update = 0
        best_valid_loss_index = 0
        best_valid_loss = float("inf")
        best_valid_acc = float("-inf")

    # send model to correct device
    if gpu_device is not None:
        LOGGER.info("Transferring model to GPU device: %s" % gpu_device)
        model.to(gpu_device)

    # instantiate Adam optimizer
    LOGGER.info("Initializing Adam optimizer with LR: %s" % learning_rate)
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # load optimizer state dictionary
    if resume_training:
        optimizer.load_state_dict(
            model_checkpoint["optimizer_state_dict"])  # type: ignore

    # instantiate negative log-likelihood loss which is summed over batch
    LOGGER.info("Using NLLLoss with sum reduction")
    loss_function = NLLLoss(weight=None, reduction="sum")

    # enable gradient clipping in-place if provided
    if clip_threshold is not None and clip_threshold > 0:
        LOGGER.info("Enabling gradient clipping with threshold: %s" %
                    clip_threshold)
        enable_gradient_clipping(model, clip_threshold)

    # initialize learning rate scheduler if relevant
    if not disable_scheduler:
        LOGGER.info(("Initializing learning rate scheduler with "
                     "factor=%s and patience=%s") %
                    (scheduler_factor, scheduler_patience))
        scheduler: Optional[ReduceLROnPlateau]
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='min',
                                      factor=scheduler_factor,
                                      patience=scheduler_patience,
                                      verbose=True)
        if resume_training:
            scheduler.load_state_dict(
                model_checkpoint["scheduler_state_dict"])  # type: ignore
    else:
        scheduler = None

    # initialize tensorboard writer if provided
    LOGGER.info("Initializing tensorboard writer in directory: %s" %
                os.path.join(model_log_directory, "events"))
    writer = SummaryWriter(os.path.join(model_log_directory, "events"))

    # set numpy and torch RNG back to previous states before training
    if resume_training:
        if current_update == 0:
            np.random.set_state(
                model_checkpoint["numpy_last_random_state"])  # type: ignore
        else:
            np.random.set_state(
                model_checkpoint["numpy_epoch_random_state"])  # type: ignore
        torch.random.set_rng_state(
            model_checkpoint["torch_last_random_state"])  # type: ignore

    # loop over epochs
    for epoch in range(current_epoch, epochs):
        # set model on train mode and enable autograd
        model.train()
        torch.autograd.set_grad_enabled(True)

        # initialize loop variables
        if resume_training and epoch == current_epoch and current_update != 0:
            train_loss: Union[float,
                              torch.Tensor] = model_checkpoint[  # type: ignore
                                  "train_loss"]  # type: ignore
            samples_seen: int = model_checkpoint[  # type: ignore
                "samples_seen"]  # type: ignore
        else:
            train_loss = 0.
            samples_seen = 0

        # cache numpy random state for model checkpoint
        numpy_epoch_random_state = np.random.get_state()

        # main training loop
        LOGGER.info("Training SoPa++ model")
        with tqdm(shuffled_chunked_sorted(train_data, batch_size),
                  position=0,
                  mininterval=0.05,
                  disable=disable_tqdm,
                  unit="batch",
                  desc="Training [Epoch %s/%s]" %
                  (epoch + 1, epochs)) as train_tqdm_batches:
            # loop over train batches
            for update, batch in enumerate(train_tqdm_batches):
                # return to previous update and random state, if relevant
                if (resume_training and epoch == current_epoch
                        and current_update != 0):
                    if update < current_update:
                        continue
                    elif update == current_update:
                        np.random.set_state(model_checkpoint[  # type: ignore
                            "numpy_last_random_state"])  # type: ignore

                # create batch object and parse out gold labels
                batch, gold = Batch(
                    [x[0] for x in batch],
                    model.embeddings,  # type: ignore
                    to_cuda(gpu_device),
                    word_dropout,
                    max_doc_len), [x[1] for x in batch]

                # find aggregate loss across samples in batch
                train_batch_loss = train_batch(model, batch, num_classes, gold,
                                               optimizer, loss_function,
                                               gpu_device)

                # add batch loss to train_loss
                train_loss += train_batch_loss  # type: ignore

                # increment samples seen
                samples_seen += batch.size()

                # update tqdm progress bar
                if (update + 1) % tqdm_update_period == 0 or (
                        update + 1) == len(train_tqdm_batches):
                    train_tqdm_batches.set_postfix(
                        batch_loss=train_batch_loss.item() / batch.size())

                # start evaluation routine
                if (not only_epoch_eval and (update + 1) % evaluation_period
                        == 0) or (update + 1) == len(train_tqdm_batches):
                    # update tqdm batches counter
                    train_tqdm_batches.update()

                    # set valid loss to zero
                    update_number = (epoch * updates_per_epoch) + (update + 1)
                    valid_loss: Union[float, torch.Tensor] = 0.

                    # set model on eval mode and disable autograd
                    model.eval()
                    torch.autograd.set_grad_enabled(False)

                    # compute mean train loss over updates and accuracy
                    # NOTE: mean_train_loss contains stochastic noise
                    LOGGER.info("Evaluating SoPa++ on training set")
                    train_loss = cast(torch.Tensor, train_loss)
                    mean_train_loss = train_loss.item() / samples_seen
                    train_acc = evaluate_metric(model, train_data, batch_size,
                                                gpu_device, accuracy_score,
                                                max_doc_len)

                    # add training loss data
                    writer.add_scalar("loss/train_loss", mean_train_loss,
                                      update_number)
                    writer.add_scalar("accuracy/train_accuracy", train_acc,
                                      update_number)

                    # add named parameter data
                    for name, param in model.named_parameters():
                        writer.add_scalar("parameter_mean/" + name,
                                          param.detach().mean(), update_number)
                        writer.add_scalar("parameter_std/" + name,
                                          param.detach().std(), update_number)
                        if param.grad is not None:
                            writer.add_scalar("gradient_mean/" + name,
                                              param.grad.detach().mean(),
                                              update_number)
                            writer.add_scalar("gradient_std/" + name,
                                              param.grad.detach().std(),
                                              update_number)

                    # loop over static valid set
                    LOGGER.info("Evaluating SoPa++ on validation set")
                    with tqdm(chunked_sorted(valid_data, batch_size),
                              position=0,
                              mininterval=0.05,
                              disable=disable_tqdm,
                              unit="batch",
                              desc="Validating [Epoch %s/%s] [Batch %s/%s]" %
                              (epoch + 1, epochs, update + 1,
                               updates_per_epoch)) as valid_tqdm_batches:
                        for valid_update, batch in enumerate(
                                valid_tqdm_batches):
                            # create batch object and parse out gold labels
                            batch, gold = Batch(
                                [x[0] for x in batch],
                                model.embeddings,  # type: ignore
                                to_cuda(gpu_device),
                                0.,
                                max_doc_len), [x[1] for x in batch]

                            # find aggregate loss across valid samples in batch
                            valid_batch_loss = compute_loss(
                                model, batch, num_classes, gold, loss_function,
                                gpu_device)

                            # add batch loss to valid_loss
                            valid_loss += valid_batch_loss  # type: ignore

                            if (valid_update +
                                    1) % tqdm_update_period == 0 or (
                                        valid_update +
                                        1) == len(valid_tqdm_batches):
                                valid_tqdm_batches.set_postfix(
                                    batch_loss=valid_batch_loss.item() /
                                    batch.size())

                    # compute mean valid loss and accuracy
                    valid_loss = cast(torch.Tensor, valid_loss)
                    mean_valid_loss = valid_loss.item() / len(valid_data)
                    valid_acc = evaluate_metric(model, valid_data, batch_size,
                                                gpu_device, accuracy_score,
                                                max_doc_len)

                    # set model on train mode and enable autograd
                    model.train()
                    torch.autograd.set_grad_enabled(True)

                    # add valid loss data to tensorboard
                    writer.add_scalar("loss/valid_loss", mean_valid_loss,
                                      update_number)
                    writer.add_scalar("accuracy/valid_accuracy", valid_acc,
                                      update_number)

                    # log out report of current evaluation state
                    LOGGER.info("Epoch: {}/{}, Batch: {}/{}".format(
                        epoch + 1, epochs, (update + 1), updates_per_epoch))
                    LOGGER.info("Mean training loss: {:.3f}, "
                                "Training accuracy: {:.3f}%".format(
                                    mean_train_loss, train_acc * 100))
                    LOGGER.info("Mean validation loss: {:.3f}, "
                                "Validation accuracy: {:.3f}%".format(
                                    mean_valid_loss, valid_acc * 100))

                    # apply learning rate scheduler after evaluation
                    if scheduler is not None:
                        scheduler.step(valid_loss)

                    # check for loss improvement and save model if necessary
                    # optionally increment patience counter or stop training
                    # NOTE: loss values are summed over all data (not mean)
                    if valid_loss.item() < best_valid_loss:
                        # log information and update records
                        LOGGER.info("New best validation loss")
                        if valid_acc > best_valid_acc:
                            best_valid_acc = valid_acc
                            LOGGER.info("New best validation accuracy")

                        # update patience related diagnostics
                        best_valid_loss = valid_loss.item()
                        best_valid_loss_index = 0
                        LOGGER.info("Patience counter: %s/%s" %
                                    (best_valid_loss_index, patience))

                        # find previous best checkpoint(s)
                        legacy_checkpoints = glob(
                            os.path.join(model_log_directory, "*_best_*.pt"))

                        # save new best checkpoint
                        model_save_file = os.path.join(
                            model_log_directory,
                            "spp_checkpoint_best_{}_{}.pt".format(
                                epoch, (update + 1)))
                        LOGGER.info("Saving best checkpoint: %s" %
                                    model_save_file)
                        save_checkpoint(epoch, update, samples_seen, model,
                                        optimizer, scheduler,
                                        numpy_epoch_random_state,
                                        train_loss.item(), best_valid_loss,
                                        best_valid_loss_index, best_valid_acc,
                                        model_save_file)

                        # delete previous best checkpoint(s)
                        for legacy_checkpoint in legacy_checkpoints:
                            os.remove(legacy_checkpoint)
                    else:
                        # update patience related diagnostics
                        best_valid_loss_index += 1
                        LOGGER.info("Patience counter: %s/%s" %
                                    (best_valid_loss_index, patience))

                        # create hook to exit training if patience reached
                        if best_valid_loss_index == patience:
                            patience_reached = True

                    # find previous last checkpoint(s)
                    legacy_checkpoints = glob(
                        os.path.join(model_log_directory, "*_last_*.pt"))

                    # save latest checkpoint
                    model_save_file = os.path.join(
                        model_log_directory,
                        "spp_checkpoint_last_{}_{}.pt".format(
                            epoch, (update + 1)))

                    LOGGER.info("Saving last checkpoint: %s" % model_save_file)
                    save_checkpoint(epoch, update, samples_seen, model,
                                    optimizer,
                                    scheduler, numpy_epoch_random_state,
                                    train_loss.item(), best_valid_loss,
                                    best_valid_loss_index, best_valid_acc,
                                    model_save_file)

                    # delete previous last checkpoint(s)
                    for legacy_checkpoint in legacy_checkpoints:
                        os.remove(legacy_checkpoint)

                    # hook to stop training in case patience was reached
                    # if it was reached strictly before last epoch and update
                    if patience_reached:
                        if not (epoch == max(range(epochs)) and
                                (update + 1) == len(train_tqdm_batches)):
                            LOGGER.info("Patience threshold reached, "
                                        "stopping training")
                            # save exit-code and final processes
                            save_exit_code(
                                os.path.join(model_log_directory, "exit_code"),
                                PATIENCE_REACHED)
                            return None

    # log information at the end of training
    LOGGER.info("%s training epoch(s) completed, stopping training" % epochs)

    # save exit-code and final processes
    save_exit_code(os.path.join(model_log_directory, "exit_code"),
                   FINISHED_EPOCHS)
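A condensed sketch covering a subset of the checkpoint keys that train_inner saves and restores (the model and optimizer are placeholders, and the torch.save/torch.load round trip is omitted):

import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

checkpoint = {
    "epoch": 3,
    "update": 0,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "numpy_last_random_state": np.random.get_state(),
    "torch_last_random_state": torch.random.get_rng_state(),
}
# in the real code this dict is written with torch.save(...) and read back
# with torch.load(..., map_location="cpu") before the restores below
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
np.random.set_state(checkpoint["numpy_last_random_state"])
torch.random.set_rng_state(checkpoint["torch_last_random_state"])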
Example #9
def train_on_loader(model: nn.Module,
                    train_gen: DataLoader,
                    val_gen: Optional[DataLoader],
                    loss_fn: Any,
                    optimizer: Optimizer,
                    n_epochs: int,
                    batch_first: bool = False,
                    device: Optional[torch.device] = torch.device('cpu'),
                    callbacks: Optional[List[Callback]] = None,
                    before_step=None,
                    verbosity: int = 2) -> ModelHistory:
    """Trains a model using data from a DataLoader.

    # Arguments
        model: The PyTorch model.
        train_gen: A DataLoader containing the training data.
        val_gen: A DataLoader containing the validation data.
        loss_fn: The loss function from which gradients are computed.
            Its expected signature is `loss_fn(model_output, y_true)`.
        optimizer: The optimizer used in the backpropagation step.
        n_epochs: How many passes should be performed over the train_gen.
        batch_first: For sequential data, if True the data is expected to have the layout
             `[batch_size, seq_len, *]`, otherwise `[seq_len, batch_size, *]`.
        device: Device on which training runs (defaults to CPU).
        callbacks: List of utility callbacks to help training the model.
        before_step: Optional hook called as `before_step(model, loss, optimizer)`
             right before each optimizer step.
        verbosity: 0: silent, 1: show epoch progress bar, 2: show batch progress bar.
    # Return
        A ModelHistory object representing the model training history.
    """

    callbacks_container = CallbacksContainer(callbacks or [])
    batch_index = 0 if batch_first else 1

    model_history = ModelHistory(model)

    epoch_iterator = range(1, n_epochs + 1)
    if verbosity == 1:
        epoch_iterator = tqdm.tqdm(epoch_iterator, desc='Epoch')
    elif verbosity == 2:
        callbacks_container.append(ProgressBar(len(train_gen), n_epochs))

    for epoch in epoch_iterator:
        model.train()
        callbacks_container.on_epoch_begin(epoch, model_history)

        epoch_loss = 0
        seen_samples = 0
        training_metrics = defaultdict(int)

        for batch_id, batch_data in enumerate(train_gen):
            callbacks_container.on_batch_begin(batch_id, model_history)

            # even if batch_data = [x, y], batch_features = [x] and batch_y = [y]
            batch_features: list = batch_data[:-1]
            batch_labels = batch_data[-1]

            batch_features = [
                _move_to_device(ft, device) for ft in batch_features
            ]
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            output = model(*batch_features)
            loss = loss_fn(output, batch_labels)
            loss.backward()

            if before_step:
                before_step(model, loss, optimizer)

            optimizer.step()

            # All feature matrices should have the same amount of sample entries,
            # hence we can take any of them to figure out the batch size
            n_samples = batch_features[0].size(batch_index)

            seen_samples += n_samples
            epoch_loss += loss.item()

            # Accumulating metrics and losses for the current epoch
            batch_metrics = model.metric(output, batch_labels)
            for m_name, m_value in batch_metrics.items():
                training_metrics[m_name] += m_value
            training_metrics['loss'] = epoch_loss / (batch_id + 1)

            # Normalizing metrics up to the current batch to display in the progress bar
            model_history.append_batch_data(
                _normalize_metrics(training_metrics, seen_samples))

            callbacks_container.on_batch_end(batch_id, model_history)

        model_history.append_trn_logs(
            _normalize_metrics(training_metrics, seen_samples))

        if val_gen:
            val_logs = evaluate_on_loader(model,
                                          val_gen,
                                          loss_fn,
                                          batch_first,
                                          device,
                                          verbosity=0)

            # Adding the val_ prefix and storing metrics over the entire validation data
            val_logs = {
                'val_' + m_name: m_value
                for m_name, m_value in val_logs.items()
            }
            model_history.append_dev_logs(val_logs)

        callbacks_container.on_epoch_end(epoch, model_history)
        if model_history.should_stop_training():
            break

    model_history.close(n_epochs)
    callbacks_container.on_train_end()

    return model_history
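A hypothetical call sketch: the loop above expects each batch to unpack as [*features, labels] and the model to expose a metric(output, y) dict, so TinyNet and the random data below are made up to satisfy that contract.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class TinyNet(nn.Module):  # hypothetical model with the required .metric hook
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

    def metric(self, output, y):
        # return per-batch sums; they are normalized by seen_samples later
        return {"acc": (output.argmax(dim=1) == y).sum().item()}

data = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
model = TinyNet()
history = train_on_loader(model,
                          train_gen=DataLoader(data, batch_size=16),
                          val_gen=None,
                          loss_fn=nn.CrossEntropyLoss(),
                          optimizer=torch.optim.Adam(model.parameters()),
                          n_epochs=2,
                          batch_first=True,   # batch dimension comes first here
                          verbosity=0)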
Example #10
def val_step(model: nn.Module,
             val_loader,
             criterion,
             device,
             num_batches: int = None,
             log_interval: int = 100):
    """
    Performs one step of validation. Calculates loss, forward pass and returns metrics.
    Args:
        model : PyTorch Detr Model.
        val_loader : Validation loader.
        criterion : Detr Loss function to be optimized.
        device : "cuda" or "cpu"
        num_batches : (optional) Integer To limit validation to certain number of batches.
        log_interval : (optional) Defualt 100. Integer to Log after specified batch ids in every batch.
    """

    model = model.to(device)
    start_val_step = time.time()
    last_idx = len(val_loader) - 1
    batch_time_m = utils.AverageMeter()
    cnt = 0
    model.eval()
    criterion.eval()
    batch_start = time.time()
    metrics = OrderedDict()

    total_loss = utils.AverageMeter()
    bbox_loss = utils.AverageMeter()
    giou_loss = utils.AverageMeter()
    labels_loss = utils.AverageMeter()

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            last_batch = batch_idx == last_idx
            images = list(image.to(device) for image in inputs)
            targets = [{k: v.to(device)
                        for k, v in t.items()} for t in targets]

            outputs = model(images)
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                       if k in weight_dict)

            cnt += 1
            total_loss.update(loss.item())
            bbox_loss.update(loss_dict["loss_bbox"].item())
            giou_loss.update(loss_dict["loss_giou"].item())
            labels_loss.update(loss_dict["loss_ce"].item())

            batch_time_m.update(time.time() - batch_start)
            batch_start = time.time()

            if last_batch or batch_idx % log_interval == 0:  # if we reach the log interval
                print(
                    "Batch Validation Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  "
                    .format(batch_time=batch_time_m, ))

            if num_batches is not None:
                if cnt >= num_batches:
                    end_val_step = time.time()
                    metrics["total_loss"] = total_loss.avg
                    metrics["bbox_loss"] = bbox_loss.avg
                    metrics["giou_loss"] = giou_loss.avg
                    metrics["labels_loss"] = labels_loss.avg
                    print(f"Done till {num_batches} Validation batches")
                    print(
                        f"Time taken for validation step = {end_val_step - start_val_step} sec"
                    )
                    return metrics

    end_val_step = time.time()
    metrics["total_loss"] = total_loss.avg
    metrics["bbox_loss"] = bbox_loss.avg
    metrics["giou_loss"] = giou_loss.avg
    metrics["labels_loss"] = labels_loss.avg
    print(
        f"Time taken for validation step = {end_val_step - start_val_step} sec"
    )
    return metrics
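val_step leans on utils.AverageMeter for running averages; a minimal stand-in with the same update()/val/avg surface is sketched below (this is an assumption about that utility, not its actual source).

class AverageMeter:  # hypothetical stand-in for utils.AverageMeter
    def __init__(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # running sum
        self.count = 0    # number of updates

    def update(self, value, n=1):
        self.val = value
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        # running average over everything seen so far
        return self.sum / max(self.count, 1)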
Example #11
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          chunk_names: bool = False,
          val_smiles: List[str] = None,
          test_smiles: List[str] = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :param chunk_names: Whether to train on the data in chunks. In this case,
    data must be a list of paths to the data chunks.
    :param val_smiles: Validation smiles strings without targets.
    :param test_smiles: Test smiles strings without targets, used for adversarial setting.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    if args.dataset_type == 'bert_pretraining':
        features_loss = nn.MSELoss()

    if chunk_names:
        for path, memo_path in tqdm(data, total=len(data)):
            featurization.SMILES_TO_FEATURES = dict()
            if os.path.isfile(memo_path):
                found_memo = True
                with open(memo_path, 'rb') as f:
                    featurization.SMILES_TO_FEATURES = pickle.load(f)
            else:
                found_memo = False
            with open(path, 'rb') as f:
                chunk = pickle.load(f)
            if args.moe:
                for source in chunk:
                    source.shuffle()
            else:
                chunk.shuffle()
            n_iter = train(model=model,
                           data=chunk,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer,
                           chunk_names=False,
                           val_smiles=val_smiles,
                           test_smiles=test_smiles)
            if not found_memo:
                with open(memo_path, 'wb') as f:
                    pickle.dump(featurization.SMILES_TO_GRAPH,
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
        return n_iter

    if not args.moe:
        data.shuffle()

    loss_sum, iter_count = 0, 0
    if args.adversarial:
        if args.moe:
            train_smiles = []
            for d in data:
                train_smiles += d.smiles()
        else:
            train_smiles = data.smiles()
        train_val_smiles = train_smiles + val_smiles
        d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0

    if args.moe:
        test_smiles = list(test_smiles)
        random.shuffle(test_smiles)
        train_smiles = []
        for d in data:
            d.shuffle()
            train_smiles.append(d.smiles())
        num_iters = min(len(test_smiles), min([len(d) for d in data]))
    elif args.maml:
        num_iters = args.maml_batches_per_epoch * args.maml_batch_size
        model.zero_grad()
        maml_sum_loss = 0
    else:
        num_iters = len(data) if args.last_batch else len(
            data) // args.batch_size * args.batch_size

    if args.parallel_featurization:
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph,
                                args=(batch_queue, data, args, num_iters,
                                      args.batch_size, exit_queue,
                                      args.last_batch))
        batch_process.start()
        currently_loaded_batches = []

    iter_size = 1 if args.maml else args.batch_size

    for i in trange(0, num_iters, iter_size):
        if args.moe:
            if not args.batch_domain_encs:
                model.compute_domain_encs(
                    train_smiles)  # want to recompute every batch
            mol_batch = [
                MoleculeDataset(d[i:i + args.batch_size]) for d in data
            ]
            train_batch, train_targets = [], []
            for b in mol_batch:
                tb, tt = b.smiles(), b.targets()
                train_batch.append(tb)
                train_targets.append(tt)
            test_batch = test_smiles[i:i + args.batch_size]
            loss = model.compute_loss(train_batch, train_targets, test_batch)
            model.zero_grad()

            loss_sum += loss.item()
            iter_count += len(mol_batch)
        elif args.maml:
            task_train_data, task_test_data, task_idx = data.sample_maml_task(
                args)
            mol_batch = task_test_data
            smiles_batch, features_batch, target_batch = task_train_data.smiles(
            ), task_train_data.features(), task_train_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor(target_batch).unsqueeze(1)
            if next(model.parameters()).is_cuda:
                targets = targets.cuda()
            preds = model(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            grad = torch.autograd.grad(
                loss, [p for p in model.parameters() if p.requires_grad])
            theta = [
                p for p in model.named_parameters() if p[1].requires_grad
            ]  # comes in same order as grad
            theta_prime = {
                p[0]: p[1] - args.maml_lr * grad[i]
                for i, p in enumerate(theta)
            }
            for name, nongrad_param in [
                    p for p in model.named_parameters()
                    if not p[1].requires_grad
            ]:
                theta_prime[name] = nongrad_param + torch.zeros(
                    nongrad_param.size()).to(nongrad_param)
        else:
            # Prepare batch
            if args.parallel_featurization:
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop(
                )
            else:
                if not args.last_batch and i + args.batch_size > len(data):
                    break
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch, target_batch = mol_batch.smiles(
            ), mol_batch.features(), mol_batch.targets()

            if args.dataset_type == 'bert_pretraining':
                batch = mol2graph(smiles_batch, args)
                mask = mol_batch.mask()
                batch.bert_mask(mask)
                mask = 1 - torch.FloatTensor(mask)  # num_atoms
                features_targets = torch.FloatTensor(
                    target_batch['features']
                ) if target_batch[
                    'features'] is not None else None  # num_molecules x features_size
                targets = torch.FloatTensor(target_batch['vocab'])  # num_atoms
                if args.bert_vocab_func == 'feature_vector':
                    mask = mask.reshape(-1, 1)
                else:
                    targets = targets.long()
            else:
                batch = smiles_batch
                mask = torch.Tensor([[x is not None for x in tb]
                                     for tb in target_batch])
                targets = torch.Tensor([[0 if x is None else x for x in tb]
                                        for tb in target_batch])

            if next(model.parameters()).is_cuda:
                mask, targets = mask.cuda(), targets.cuda()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    features_targets = features_targets.cuda()

            if args.class_balance:
                class_weights = []
                for task_num in range(data.num_tasks()):
                    class_weights.append(
                        args.class_weights[task_num][targets[:,
                                                             task_num].long()])
                class_weights = torch.stack(
                    class_weights).t()  # num_molecules x num_tasks
            else:
                class_weights = torch.ones(targets.shape)

            if args.cuda:
                class_weights = class_weights.cuda()

            # Run model
            model.zero_grad()
            if args.parallel_featurization:
                previous_graph_input_mode = model.encoder.graph_input
                model.encoder.graph_input = True  # force model to accept already processed input
                preds = model(featurized_mol_batch, features_batch)
                model.encoder.graph_input = previous_graph_input_mode
            else:
                preds = model(batch, features_batch)
            if args.dataset_type == 'regression_with_binning':
                preds = preds.view(targets.size(0), targets.size(1), -1)
                targets = targets.long()
                loss = 0
                for task in range(targets.size(1)):
                    loss += loss_func(
                        preds[:, task, :], targets[:, task]
                    ) * class_weights[:,
                                      task] * mask[:,
                                                   task]  # for some reason cross entropy doesn't support multi target
                loss = loss.sum() / mask.sum()
            else:
                if args.dataset_type == 'unsupervised':
                    targets = targets.long().reshape(-1)

                if args.dataset_type == 'bert_pretraining':
                    features_preds, preds = preds['features'], preds['vocab']

                if args.dataset_type == 'kernel':
                    preds = preds.view(int(preds.size(0) / 2), 2,
                                       preds.size(1))
                    preds = model.kernel_output_layer(preds)

                loss = loss_func(preds, targets) * class_weights * mask
                if args.predict_features_and_task:
                    loss = (loss.sum() + loss[:, :-args.features_size].sum() * (args.task_weight-1)) \
                                / (mask.sum() + mask[:, :-args.features_size].sum() * (args.task_weight-1))
                else:
                    loss = loss.sum() / mask.sum()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    loss += features_loss(features_preds, features_targets)

            loss_sum += loss.item()
            iter_count += len(mol_batch)

        if args.maml:
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch, features_batch, target_batch = task_test_data.smiles(
            ), task_test_data.features(), [
                t[task_idx] for t in task_test_data.targets()
            ]
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor([[t] for t in target_batch])
            if next(model_prime.parameters()).is_cuda:
                targets = targets.cuda()
            model_prime.zero_grad()
            preds = model_prime(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            loss_sum += loss.item()
            iter_count += len(
                smiles_batch
            )  # TODO check that this makes sense, but it's just for display
            maml_sum_loss += loss
            if i % args.maml_batch_size == args.maml_batch_size - 1:
                maml_sum_loss.backward()
                optimizer.step()
                model.zero_grad()
                maml_sum_loss = 0
        else:
            loss.backward()
            if args.max_grad_norm is not None:
                clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

        if args.adjust_weight_decay:
            current_pnorm = compute_pnorm(model)
            if current_pnorm < args.pnorm_target:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i]['weight_decay'] = max(
                        0, optimizer.param_groups[i]['weight_decay'] -
                        args.adjust_weight_decay_step)
            else:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i][
                        'weight_decay'] += args.adjust_weight_decay_step

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        if args.adversarial:
            for _ in range(args.gan_d_per_g):
                train_val_smiles_batch = random.sample(train_val_smiles,
                                                       args.batch_size)
                test_smiles_batch = random.sample(test_smiles, args.batch_size)
                d_loss, gp_norm = model.train_D(train_val_smiles_batch,
                                                test_smiles_batch)
            train_val_smiles_batch = random.sample(train_val_smiles,
                                                   args.batch_size)
            test_smiles_batch = random.sample(test_smiles, args.batch_size)
            g_loss = model.train_G(train_val_smiles_batch, test_smiles_batch)

            # we probably only care about the g_loss honestly
            d_loss_sum += d_loss * args.batch_size
            gp_norm_sum += gp_norm * args.batch_size
            g_loss_sum += g_loss * args.batch_size

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            if args.adversarial:
                d_loss_avg, g_loss_avg, gp_norm_avg = d_loss_sum / iter_count, g_loss_sum / iter_count, gp_norm_sum / iter_count
                d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )
            if args.adversarial:
                debug(
                    f'D Loss = {d_loss_avg:.4e}, G Loss = {g_loss_avg:.4e}, GP Norm = {gp_norm_avg:.4}'
                )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    if args.parallel_featurization:
        exit_queue.put(
            0)  # dummy var to get the subprocess to know that we're done
        batch_process.join()

    return n_iter
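The MAML inner-loop update from the branch above, in isolation: one gradient step produces adapted parameters (theta_prime) while the original model is left untouched (toy model and data below).

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)

loss = F.mse_loss(model(x), y)
grads = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
theta = [p for p in model.named_parameters() if p[1].requires_grad]  # same order as grads

maml_lr = 0.1  # illustrative inner-loop learning rate
theta_prime = {name: p - maml_lr * g for (name, p), g in zip(theta, grads)}
# theta_prime would then parameterize a clone of the model that is evaluated
# on the task's test split, and that loss drives the outer optimizer step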
Example #12
def infer_model_device(model: nn.Module):
    """ infers model device as the device where the majority of parameters and buffers are stored """
    device_stats = Counter(
        tensor.device for tensor in chain(model.parameters(), model.buffers())
        if torch.is_tensor(tensor))
    return max(device_stats, key=device_stats.get)
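A quick usage check (assuming infer_model_device from this example is in scope): for a CPU-only model the majority vote over parameters and buffers is trivially the CPU device.

import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2))  # has params and buffers
print(infer_model_device(model))  # device(type='cpu')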
Example #13
def train(args, model: nn.Module, criterion, *, params,
          train_loader, valid_loader, init_optimizer, use_cuda,
          n_epochs=None, patience=4, max_lr_changes=2) -> bool:
    lr = args.lr
    n_epochs = n_epochs or args.n_epochs
    params = list(params)
    optimizer = init_optimizer(args.optimizer, params, lr)

    run_root = Path(args.run_root)
    model_path = run_root / 'model-1.pt'
    best_model_path = run_root / 'best-model.pt'
    if model_path.exists():
        state = load_model(model, model_path)
        epoch = state['epoch']
        step = state['step']
        best_valid_loss = state.get('best_valid_loss', float('inf'))
    else:
        epoch = 1
        step = 0
        best_valid_loss = float('inf')
    lr_changes = 0

    save = lambda ep, save_name: torch.save({
        'model': model.state_dict(),
        'epoch': ep,
        'step': step,
        'best_valid_loss': best_valid_loss
    }, str(run_root / save_name))

    report_each = 10
    log = run_root.joinpath('train.log').open('at', encoding='utf8')
    valid_losses = []
    lr_reset_epoch = epoch
    for epoch in range(epoch, n_epochs + 1):
        model.train()
        tq = tqdm.tqdm(total=(args.epoch_size or
                              len(train_loader) * args.batch_size))
        tq.set_description(f'Epoch {epoch}, lr {lr}')
        losses = []
        tl = train_loader
        if args.epoch_size:
            tl = islice(tl, args.epoch_size // args.batch_size)
        try:
            mean_loss = 0
            for i, (inputs, targets) in enumerate(tl):
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                outputs = model(inputs)
                loss = _reduce_loss(criterion(outputs, targets))

                batch_size = inputs.size(0)
                (batch_size * loss).backward()

                if (i + 1) % args.step == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    step += 1
                tq.update(batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-report_each:])
                tq.set_postfix(loss=f'{mean_loss:.3f}')
                # if i and i % report_each == 0:
                #     write_event(log, step, loss=mean_loss)
            tq.close()
            save(epoch + 1, f'model-{epoch}.pt')
            valid_metrics = validation(model, criterion, valid_loader, use_cuda, args.model)
            write_event(log, step, epoch, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            valid_losses.append(valid_loss)
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                save(epoch + 1, 'best-model.pt')
            elif (patience and epoch - lr_reset_epoch > patience and
                  min(valid_losses[-patience:]) > best_valid_loss):
                # "patience" epochs without improvement
                lr_changes += 1
                if lr_changes > max_lr_changes:
                    break
                lr /= 5
                print(f'lr updated to {lr}')
                lr_reset_epoch = epoch

                optimizer = init_optimizer(args.optimizer, params, lr)
                
        except KeyboardInterrupt:
            tq.close()
            print('Ctrl+C, saving snapshot')
            save(epoch, 'model-interrupted.pt')
            print('done.')
            return False
    return True
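The gradient-accumulation idiom from the inner loop above, in isolation: the loss is scaled by the batch size and gradients from args.step batches are accumulated before a single optimizer step (model, data, and step count below are illustrative).

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accum_steps = 4  # stand-in for args.step

for i in range(8):
    inputs, targets = torch.randn(16, 4), torch.randn(16, 1)
    loss = F.mse_loss(model(inputs), targets)
    (inputs.size(0) * loss).backward()   # scale by batch size, as in the loop above
    if (i + 1) % accum_steps == 0:       # step only every accum_steps batches
        optimizer.step()
        optimizer.zero_grad()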
Example #14
def annotate_video(movie_file_path: str,
                   dataset_path: str,
                   output_path: str,
                   model: nn.Module,
                   device,
                   max_frame: int = 100000,
                   tracker_max_age: int = 10,
                   plotter: utils.plotter_utils.VisdomPlotter = None,
                   name: str = '',
                   compute_track_mean: bool = False):

    filename = os.path.join(dataset_path, 'bbx.txt')
    print('Getting annotations from {}'.format(filename))
    bbx_list = utils.read_file_to_list(filename)

    if bbx_list:
        bounding_boxes_list = bbx_list
    else:
        bounding_boxes_list = get_bounding_boxes(movie_file_path,
                                                 max_frame=max_frame,
                                                 tracker_max_age=tracker_max_age)

    print('Extracting ROI of the video.')
    cropped_image_list = get_cropped_images(movie_file_path,
                                            bounding_boxes_list,
                                            max_frame=max_frame)

    track_dict = get_track_dict(bounding_boxes_list)
    frame_dict = get_frame_dict(bounding_boxes_list)
    bbx_dict = get_bbx_dict(bounding_boxes_list)

    # Data transform
    data_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    dataset = NumpyDataset(cropped_image_list,
                           transform=data_transform)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             num_workers=2,
                                             batch_size=100)

    print('Extracting features.')
    model = model.to(device)
    features = ml_utils.extract_features(dataloader,
                                         model,
                                         device)
    cluster_techniques_list = ['kmeans', 'spectral', 'hac']

    tsne_features, tsne_chosen_samples = projection_utils.tsne_projection(features)
    pca_features, pca_chosen_samples = projection_utils.pca_projection(features)

    # Frame level clustering
    print('Performing frame level clustering.')
    for cluster_method in cluster_techniques_list:
        cluster_name = '{}_frame_level_{}'.format(name, cluster_method)
        predictions, data_dict = clustering.cluster_techniques(features,
                                                     cluster_method,
                                                     max_clusters=10)

        write_video(movie_file_path,
                    output_path,
                    predictions,
                    frame_dict,
                    name=cluster_name,
                    max_frame=max_frame)

        plotter.scatter_plot(cluster_name + '_tsne',
                             tsne_features,
                             predictions[tsne_chosen_samples])
        plotter.scatter_plot(cluster_name + '_pca',
                             pca_features,
                             predictions[pca_chosen_samples])

    # Add ground truth if it exists
    gt_file_path = os.path.join(dataset_path, 'bbx_gt.txt')
    if os.path.isfile(gt_file_path):
        print('Creating ground truth video and plots.')
        bbx_to_gt_list = utils.read_file_to_list(gt_file_path)
        bbx_to_gt_dict = utils.list_to_dict(bbx_to_gt_list)

        groundtruth = []
        gt_to_idx_dict = {}
        bbx_count = 0
        for bbx in bounding_boxes_list:
            bbx_idx = bbx[2]
            gt = bbx_to_gt_dict[bbx_idx]
            if gt not in gt_to_idx_dict.keys():
                gt_to_idx_dict[gt] = bbx_count
                bbx_count += 1
            label = gt_to_idx_dict[gt]
            groundtruth.append(label)
        groundtruth = np.array(groundtruth)

        gt_name = '{}_gt'.format(name)
        write_video(movie_file_path,
                    output_path,
                    groundtruth,
                    frame_dict,
                    name=gt_name,
                    max_frame=max_frame)

        plotter.scatter_plot(gt_name + '_tsne',
                             tsne_features,
                             groundtruth[tsne_chosen_samples])
        plotter.scatter_plot(gt_name + '_pca',
                             pca_features,
                             groundtruth[pca_chosen_samples])

    # Track level clustering
    if compute_track_mean:
        print('Performing track level clustering.')

        mean_features = []
        track_to_idx_dict = {}
        for idx, track_idx in enumerate(track_dict.keys()):
            feature_track = features[track_dict[track_idx]]
            mean_features.append(np.mean(feature_track, axis=0))
            track_to_idx_dict[track_idx] = idx
        mean_features = np.asarray(mean_features)

        for cluster_method in cluster_techniques_list:
            cluster_name = '{}_track_level_{}'.format(name, cluster_method)
            mean_predictions, data_dict = clustering.cluster_techniques(mean_features,
                                                                        cluster_method,
                                                                        max_clusters=10)
            predictions = []
            for bbx_idx in bbx_dict.keys():
                track_idx = track_to_idx_dict[bbx_dict[bbx_idx][0]]
                predictions.append(mean_predictions[track_idx])
            predictions = np.array(predictions)

            write_video(movie_file_path,
                        output_path,
                        predictions,
                        frame_dict,
                        name=cluster_name,
                        max_frame=max_frame)

            plotter.scatter_plot(cluster_name + '_tsne',
                                 tsne_features,
                                 predictions[tsne_chosen_samples])
            plotter.scatter_plot(cluster_name + '_pca',
                                 pca_features,
                                 predictions[pca_chosen_samples])
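
For reference, `ml_utils.extract_features` is a project-local helper that is not shown above. A minimal sketch of what such a feature-extraction loop could look like (the function name and behaviour here are assumptions, not the project's actual implementation):

import numpy as np
import torch

def extract_features_sketch(dataloader, model, device):
    # Run the model over every batch and collect the outputs in one numpy array.
    model.eval()
    feature_chunks = []
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            output = model(batch)
            feature_chunks.append(output.cpu().numpy())
    return np.concatenate(feature_chunks, axis=0)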
Exemplo n.º 15
0
def efficientnet_init_weights(model: nn.Module, init_fn=None):
    init_fn = init_fn or _init_weight_goog
    for n, m in model.named_modules():
        init_fn(m, n)
Exemplo n.º 16
0
def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
Exemplo n.º 17
0
def _layer_flops(layer: nn.Module, layer_args: List[Any], y: Any) -> int:
    """
    Computes the number of FLOPs required for a single layer.

    For common layers, such as Conv1d, the FLOP computation is implemented in
    this centralized place.
    For other layers, if the layer defines a method to compute FLOPs with the
    signature below, we will use it:

    class MyModule(nn.Module):
        def flops(self, x):
            ...

    """

    x = layer_args[0]
    # get layer type:
    typestr = layer.__repr__()
    layer_type = typestr[: typestr.find("(")].strip()
    batchsize_per_replica = get_batchsize_per_replica(x)

    flops = None
    # 1D convolution:
    if layer_type in ["Conv1d"]:
        # x shape is N x C x W
        out_w = int(
            (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
            / layer.stride[0]
            + 1
        )
        flops = (
            batchsize_per_replica
            * layer.in_channels
            * layer.out_channels
            * layer.kernel_size[0]
            * out_w
            / layer.groups
        )
    # 2D convolution:
    elif layer_type in ["Conv2d"]:
        out_h = int(
            (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
            / layer.stride[0]
            + 1
        )
        out_w = int(
            (x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1])
            / layer.stride[1]
            + 1
        )
        flops = (
            batchsize_per_replica
            * layer.in_channels
            * layer.out_channels
            * layer.kernel_size[0]
            * layer.kernel_size[1]
            * out_h
            * out_w
            / layer.groups
        )

    # learned group convolution:
    elif layer_type in ["LearnedGroupConv"]:
        conv = layer.conv
        out_h = int(
            (x.size()[2] + 2 * conv.padding[0] - conv.kernel_size[0]) / conv.stride[0]
            + 1
        )
        out_w = int(
            (x.size()[3] + 2 * conv.padding[1] - conv.kernel_size[1]) / conv.stride[1]
            + 1
        )
        count1 = _layer_flops(layer.relu, [x], x) + _layer_flops(layer.norm, [x], x)
        count2 = (
            batchsize_per_replica
            * conv.in_channels
            * conv.out_channels
            * conv.kernel_size[0]
            * conv.kernel_size[1]
            * out_h
            * out_w
            / layer.condense_factor
        )
        flops = count1 + count2

    # non-linearities:
    elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax", "SiLU"]:
        flops = x.numel()

    # 2D pooling layers:
    elif layer_type in ["AvgPool2d", "MaxPool2d"]:
        in_h = x.size()[2]
        in_w = x.size()[3]
        if isinstance(layer.kernel_size, int):
            layer.kernel_size = (layer.kernel_size, layer.kernel_size)
        kernel_ops = layer.kernel_size[0] * layer.kernel_size[1]
        out_h = 1 + int(
            (in_h + 2 * layer.padding - layer.kernel_size[0]) / layer.stride
        )
        out_w = 1 + int(
            (in_w + 2 * layer.padding - layer.kernel_size[1]) / layer.stride
        )
        flops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops

    # adaptive avg pool2d
    # This is approximate and works only for downsampling without padding
    # based on aten/src/ATen/native/AdaptiveAveragePooling.cpp
    elif layer_type in ["AdaptiveAvgPool2d"]:
        in_h = x.size()[2]
        in_w = x.size()[3]
        if isinstance(layer.output_size, int):
            out_h, out_w = layer.output_size, layer.output_size
        elif len(layer.output_size) == 1:
            out_h, out_w = layer.output_size[0], layer.output_size[0]
        else:
            out_h, out_w = layer.output_size
        if out_h > in_h or out_w > in_w:
            raise ClassyProfilerNotImplementedError(layer)
        batchsize_per_replica = x.size()[0]
        num_channels = x.size()[1]
        kh = in_h - out_h + 1
        kw = in_w - out_w + 1
        kernel_ops = kh * kw
        flops = batchsize_per_replica * num_channels * out_h * out_w * kernel_ops

    # linear layer:
    elif layer_type in ["Linear"]:
        weight_ops = layer.weight.numel()
        bias_ops = layer.bias.numel() if layer.bias is not None else 0
        flops = ((x.numel() / x.size(-1)) if x.ndim > 2 else x.size(0)) * (
            weight_ops + bias_ops
        )

    # batch normalization / layer normalization:
    elif layer_type in [
        "BatchNorm1d",
        "BatchNorm2d",
        "BatchNorm3d",
        "SyncBatchNorm",
        "LayerNorm",
    ]:
        flops = 2 * x.numel()

    # 3D convolution
    elif layer_type in ["Conv3d"]:
        out_t = int(
            (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
            // layer.stride[0]
            + 1
        )
        out_h = int(
            (x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1])
            // layer.stride[1]
            + 1
        )
        out_w = int(
            (x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2])
            // layer.stride[2]
            + 1
        )
        flops = (
            batchsize_per_replica
            * layer.in_channels
            * layer.out_channels
            * layer.kernel_size[0]
            * layer.kernel_size[1]
            * layer.kernel_size[2]
            * out_t
            * out_h
            * out_w
            / layer.groups
        )

    # 3D pooling layers
    elif layer_type in ["AvgPool3d", "MaxPool3d"]:
        in_t = x.size()[2]
        in_h = x.size()[3]
        in_w = x.size()[4]
        if isinstance(layer.kernel_size, int):
            layer.kernel_size = (
                layer.kernel_size,
                layer.kernel_size,
                layer.kernel_size,
            )
        if isinstance(layer.padding, int):
            layer.padding = (layer.padding, layer.padding, layer.padding)
        if isinstance(layer.stride, int):
            layer.stride = (layer.stride, layer.stride, layer.stride)
        kernel_ops = layer.kernel_size[0] * layer.kernel_size[1] * layer.kernel_size[2]
        out_t = 1 + int(
            (in_t + 2 * layer.padding[0] - layer.kernel_size[0]) / layer.stride[0]
        )
        out_h = 1 + int(
            (in_h + 2 * layer.padding[1] - layer.kernel_size[1]) / layer.stride[1]
        )
        out_w = 1 + int(
            (in_w + 2 * layer.padding[2] - layer.kernel_size[2]) / layer.stride[2]
        )
        flops = batchsize_per_replica * x.size()[1] * out_t * out_h * out_w * kernel_ops

    # adaptive avg pool3d
    # This is approximate and works only for downsampling without padding
    # based on aten/src/ATen/native/AdaptiveAveragePooling3d.cpp
    elif layer_type in ["AdaptiveAvgPool3d"]:
        in_t = x.size()[2]
        in_h = x.size()[3]
        in_w = x.size()[4]
        out_t = layer.output_size[0]
        out_h = layer.output_size[1]
        out_w = layer.output_size[2]
        if out_t > in_t or out_h > in_h or out_w > in_w:
            raise ClassyProfilerNotImplementedError(layer)
        batchsize_per_replica = x.size()[0]
        num_channels = x.size()[1]
        kt = in_t - out_t + 1
        kh = in_h - out_h + 1
        kw = in_w - out_w + 1
        kernel_ops = kt * kh * kw
        flops = (
            batchsize_per_replica * num_channels * out_t * out_w * out_h * kernel_ops
        )

    elif layer_type in ["Dropout", "Identity"]:
        flops = 0

    elif hasattr(layer, "flops"):
        # If the module already defines a method to compute flops with the signature
        # below, we use it to compute flops
        #
        #   class MyModule(nn.Module):
        #     def flops(self, x):
        #       ...
        #   or
        #
        #   class MyModule(nn.Module):
        #     def flops(self, x1, x2):
        #       ...
        flops = layer.flops(*layer_args)

    if flops is None:
        raise ClassyProfilerNotImplementedError(layer)

    message = [
        f"module type: {typestr}",
        f"input size: {get_shape(x)}",
        f"output size: {get_shape(y)}",
        f"params(M): {count_params(layer) / 1e6}",
        f"flops(M): {int(flops) / 1e6}",
    ]
    logging.debug("\t".join(message))
    return int(flops)
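
As the docstring above notes, a module can expose its own `flops` method and the profiler will call it instead of the built-in rules. A minimal sketch (the module and its cost model are illustrative, not part of the profiler):

import torch
import torch.nn as nn

class ScaledResidual(nn.Module):
    """Toy module: y = x + alpha * x, with a hand-written FLOP estimate."""

    def __init__(self, alpha: float = 0.5):
        super().__init__()
        self.alpha = alpha

    def forward(self, x):
        return x + self.alpha * x

    def flops(self, x):
        # One multiply and one add per element.
        return 2 * x.numel()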
Exemplo n.º 18
0
    def test(self,
             net: nn.Module,
             clean_data: CSVDataset,
             triggered_data: CSVDataset,
             clean_test_triggered_labels_data: CSVDataset,
             progress_bar_disable: bool = False,
             torch_dataloader_kwargs: dict = None) -> dict:
        """
        Test the trained network
        :param net: the trained module to run the test data through
        :param clean_data: the clean Dataset
        :param triggered_data: the triggered Dataset, if None, not computed
        :param clean_test_triggered_labels_data: triggered part of the training dataset but with correct labels; see
            DataManager.load_data for more information.
        :param progress_bar_disable: if True, disables the progress bar
        :param torch_dataloader_kwargs: any keyword arguments to pass directly to PyTorch's DataLoader
        :return: a dictionary of the statistics on the clean and triggered data (if applicable)
        """
        test_data_statistics = {}
        net.eval()

        pin_memory = False
        if self.device.type != 'cpu':
            pin_memory = True

        # drop_last=True is from: https://stackoverflow.com/questions/56576716
        data_loader_kwargs_in = dict(batch_size=1,
                                     pin_memory=pin_memory,
                                     drop_last=True,
                                     shuffle=False)
        if torch_dataloader_kwargs:
            data_loader_kwargs_in.update(torch_dataloader_kwargs)
        logger.info('DataLoader[Test] kwargs=' + str(data_loader_kwargs_in))
        data_loader = DataLoader(clean_data, **data_loader_kwargs_in)

        # Test the classification accuracy on clean data only, for all labels.
        test_acc, test_n_total, test_n_correct, _ = _eval_acc(
            data_loader, net, self.device, self.soft_to_hard_fn,
            self.soft_to_hard_fn_kwargs)
        test_data_statistics['clean_accuracy'] = test_acc
        test_data_statistics['clean_n_total'] = test_n_total
        logger.info("Accuracy on clean test data: %0.02f" %
                    (test_data_statistics['clean_accuracy'], ))

        if triggered_data is not None:
            # Test the classification accuracy on triggered data only, for all labels.
            # we set batch_size=1 b/c
            data_loader = DataLoader(triggered_data,
                                     batch_size=1,
                                     pin_memory=pin_memory)
            test_acc, test_n_total, test_n_correct, _ = _eval_acc(
                data_loader, net, self.device, self.soft_to_hard_fn,
                self.soft_to_hard_fn_kwargs)
            test_data_statistics['triggered_accuracy'] = test_acc
            test_data_statistics['triggered_n_total'] = test_n_total
            logger.info("Accuracy on triggered test data: %0.02f for n=%s" %
                        (test_data_statistics['triggered_accuracy'],
                         str(test_n_total)))

        if clean_test_triggered_labels_data is not None:
            # Test the classification accuracy on clean data for labels which have corresponding triggered examples.
            # For example, if an MNIST dataset was created with triggered examples only for labels 4 and 5,
            # then this dataset is the subset of data with labels 4 and 5 that don't have the triggers.
            data_loader = DataLoader(clean_test_triggered_labels_data,
                                     batch_size=1,
                                     pin_memory=pin_memory)
            test_acc, test_n_total, test_n_correct, _ = _eval_acc(
                data_loader, net, self.device, self.soft_to_hard_fn,
                self.soft_to_hard_fn_kwargs)
            test_data_statistics[
                'clean_test_triggered_label_accuracy'] = test_acc
            test_data_statistics[
                'clean_test_triggered_label_n_total'] = test_n_total
            logger.info(
                "Accuracy on clean-data-triggered-labels: %0.02f for n=%s" %
                (test_data_statistics['clean_test_triggered_label_accuracy'],
                 str(test_n_total)))

        return test_data_statistics
Exemplo n.º 19
0
def summary(model: nn.Module,
            input_data: INPUT_DATA_TYPE = None,
            *args: Any,
            batch_dim: Optional[int] = 0,
            branching: bool = True,
            col_names: Optional[Sequence[str]] = None,
            col_width: int = 25,
            depth: int = 3,
            device: Optional[torch.device] = None,
            dtypes: Optional[List[torch.dtype]] = None,
            verbose: int = 1,
            print_step: bool = True,
            print_func=print,
            **kwargs: Any) -> ModelStatistics:
    """
    Summarize the given PyTorch model. Summarized information includes:
        1) Layer names,
        2) input/output shapes,
        3) kernel shape,
        4) # of parameters,
        5) # of operations (Mult-Adds)

    Args:
        model (nn.Module):
                PyTorch model to summarize

        input_data (Sequence of Sizes or Tensors):
                Example input tensor of the model (dtypes inferred from model input).
                - OR -
                Shape of input data as a List/Tuple/torch.Size (dtypes must match model input,
                default is FloatTensors). Should NOT include batch size in the tuple.
                - OR -
                If input_data is not provided, no forward pass through the network is performed,
                and the provided model information is limited to layer names.

        batch_dim (int):
                Batch_dimension of input data. Default: 0
                If batch_dim is None, the input data is assumed to contain the batch dimension.
                WARNING: in a future version of torch-summary, the default will change to None.

        branching (bool):
                Whether to use the branching layout for the printed output. Default: True

        col_names (Sequence[str]):
                Specify which columns to show in the output. Currently supported:
                        ("input_size", "output_size", "num_params", "kernel_size", "mult_adds")
                If input_data is not provided, only "num_params" is used.
                Default: ("output_size", "num_params")

        col_width (int):
                Width of each column. Default: 25

        depth (int):
                Number of nested layers to traverse (e.g. Sequentials). Default: 3

        device (torch.Device):
                Uses this torch device for model and input_data.
                If not specified, uses result of torch.cuda.is_available(). Default: None

        dtypes (List[torch.dtype]):
                For multiple inputs, specify the size of both inputs, and
                also specify the types of each parameter here. Default: None

        verbose (int):
                0 (quiet): No output
                1 (default): Print model summary
                2 (verbose): Show weight and bias layers in full detail
                Default: 1

        print_step (bool):
                Passed through to ModelStatistics when building the summary
                results. Default: True

        print_func (Callable):
                Function used to print the summary and any error messages
                (e.g. print or a logger method). Default: print

        *args, **kwargs:
                Other arguments used in `model.forward` function.

    Return:
        ModelStatistics object
                See torchsummary/model_statistics.py for more information.
    """
    if col_names is None:
        col_names = (
            "num_params", ) if input_data is None else DEFAULT_COLUMN_NAMES

    validate_user_params(input_data, col_names, verbose)
    input_size = []  # type: CORRECTED_INPUT_SIZE_TYPE
    summary_list = []  # type: List[LayerInfo]
    hooks = None if input_data is None else [
    ]  # type: Optional[List[RemovableHandle]]
    idx = {}  # type: Dict[int, int]
    apply_hooks(model, model, batch_dim, depth, summary_list, idx, hooks)

    if input_data is not None:
        if device is None:
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")

        x, input_size = process_input_data(input_data, batch_dim, device,
                                           dtypes)
        args, kwargs = set_device(args, device), set_device(kwargs, device)
        try:
            with torch.no_grad():
                _ = model.to(device)(*x, *args, **kwargs)  # type: ignore
        except Exception:
            executed_layers = [
                layer for layer in summary_list if layer.executed
            ]
            print_func(
                "Failed to run torchsummary, executed layers up to: {}".format(
                    executed_layers))
            raise
        finally:
            if hooks is not None:
                for hook in hooks:
                    hook.remove()

    formatting = FormattingOptions(branching, depth, verbose, col_names,
                                   col_width)
    formatting.set_layer_name_width(summary_list)
    results = ModelStatistics(summary_list, input_size, formatting, print_step)
    if verbose > Verbosity.QUIET.value:
        print_func(results)
    return results
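
A sketch of how the function above might be invoked; the toy model and input shape are illustrative (the shape excludes the batch dimension, since batch_dim defaults to 0):

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
                      nn.Flatten(), nn.Linear(16, 10))
stats = summary(model, (3, 32, 32),
                col_names=("input_size", "output_size", "num_params"))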
Exemplo n.º 20
0
def evaluate_classifier(model: nn.Module,
                        test_dl: DataLoader,
                        loss_func: Callable,
                        classes: List[int] = [0, 1]) -> None:
    "Evaluate a PyTorch graph model for classification."

    y_pred = []
    y_true = []
    y_prob = []
    prob_arr = []

    test_loss = 0

    for bg, labels in test_dl:
        model.eval()

        bg.set_e_initializer(dgl.init.zero_initializer)
        bg.set_n_initializer(dgl.init.zero_initializer)

        logit = model(bg)
        probs = torch.softmax(logit, 1).detach().numpy()
        prob_arr.append(probs)
        predictions = np.argmax(probs, 1)

        y_pred += list(predictions)
        y_true += list(labels)
        y_prob += list(probs[:, 1])

        loss = loss_func(logit, labels)
        test_loss += loss.detach().item()

    print('test_loss: ', test_loss / len(test_dl))
    print('accuracy: ', accuracy_score(y_true, y_pred))
    print('classification report: \n', classification_report(y_true, y_pred))

    if len(classes) == 2:
        print('roc-auc: ', roc_auc_score(y_true, y_prob))
        print('bootstrapped roc-auc: ', bs_roc_auc_score(y_true, y_prob))

    else:
        y_test = label_binarize(y_true, classes=classes)
        n_classes = y_test.shape[1]

        prob_arr = np.concatenate(prob_arr, axis=0)

        # Compute ROC curve and ROC area for each class
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        bs_roc_auc = dict()
        for i in range(n_classes):
            fpr[i], tpr[i], _ = roc_curve(y_test[:, i], prob_arr[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])
            bs_roc_auc[i] = bs_roc_auc_score(y_test[:, i], prob_arr[:, i])

        fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(),
                                                  prob_arr.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
        bs_roc_auc['micro'] = bs_roc_auc_score(y_test.ravel(),
                                               prob_arr.ravel())

        print("micro auc score and score for each class: ")
        for key in roc_auc:
            print(key, ' : ', roc_auc[key])
        print("bootstrapped micro auc score and score for each class: ")
        for key in bs_roc_auc:
            print(key, ' : ', bs_roc_auc[key])
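
`bs_roc_auc_score` is a project-local helper that is not shown here. A hedged sketch of what a bootstrapped ROC-AUC estimate could look like (the function name and return format are assumptions):

import numpy as np
from sklearn.metrics import roc_auc_score

def bs_roc_auc_score_sketch(y_true, y_prob, n_boot: int = 1000, seed: int = 0):
    # Resample (y_true, y_prob) pairs with replacement and report the mean AUC
    # together with a 95% percentile interval.
    rng = np.random.default_rng(seed)
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    scores = []
    for _ in range(n_boot):
        idx = rng.integers(0, len(y_true), len(y_true))
        if len(np.unique(y_true[idx])) < 2:
            continue  # AUC is undefined when a resample contains a single class
        scores.append(roc_auc_score(y_true[idx], y_prob[idx]))
    return np.mean(scores), np.percentile(scores, [2.5, 97.5])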
Exemplo n.º 21
0
 def free(self, module: nn.Module):
     for p in module.parameters():
         p.requires_grad = True
Exemplo n.º 22
0
 def _build_opt(self, model: nn.Module) -> optim.Optimizer:
     if isinstance(self.opt, str):
         self.opt = self._interp_opt(
             self.opt)  # Backwards compatibility with pre-v0.3.1 saves
     return self.opt(model.parameters(), **self.opt_args)
Exemplo n.º 23
0
 def frozen(self, module: nn.Module):
     for p in module.parameters():
         p.requires_grad = False
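
Together with the `free` helper in Exemplo n.º 21, this pattern is typically used for staged fine-tuning: freeze a pretrained part of the network, train the rest, then unfreeze. A small self-contained sketch of the same pattern (the backbone below is just a stand-in):

import torch.nn as nn

def set_requires_grad(module: nn.Module, flag: bool) -> None:
    # Same loop as free/frozen above, parameterised by a flag.
    for p in module.parameters():
        p.requires_grad = flag

backbone = nn.Linear(8, 8)            # stand-in for a pretrained backbone
set_requires_grad(backbone, False)    # freeze
set_requires_grad(backbone, True)     # unfreeze for fine-tuning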
Exemplo n.º 24
0
def train_step(
    model: nn.Module,
    train_loader,
    criterion,
    device: str,
    optimizer,
    scheduler=None,
    num_batches: int = None,
    log_interval: int = 100,
    scaler=None,
):
    """
    Performs one training pass over the loader: forward pass, loss computation, gradient update, and returns metrics.
    Args:
        model : PyTorch DETR model.
        train_loader : Train loader.
        device : "cuda" or "cpu"
        criterion : DETR loss function to be optimized.
        optimizer : Torch optimizer to train.
        scheduler : Learning rate scheduler.
        num_batches : (optional) Integer to limit training to a certain number of batches.
        log_interval : (optional) Default 100. Logs progress after every `log_interval` batches.
        scaler : (optional) Pass torch.cuda.amp.GradScaler() for mixed-precision (fp16) training.
    """

    model = model.to(device)
    criterion = criterion.to(device)
    start_train_step = time.time()
    model.train()
    last_idx = len(train_loader) - 1
    batch_time_m = utils.AverageMeter()
    criterion.train()
    cnt = 0
    batch_start = time.time()
    metrics = OrderedDict()

    total_loss = utils.AverageMeter()
    bbox_loss = utils.AverageMeter()
    giou_loss = utils.AverageMeter()
    labels_loss = utils.AverageMeter()

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        last_batch = batch_idx == last_idx
        images = list(image.to(device) for image in inputs)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()

        if scaler is not None:
            with amp.autocast():
                outputs = model(images)
                loss_dict = criterion(outputs, targets)
                weight_dict = criterion.weight_dict
                loss = sum(loss_dict[k] * weight_dict[k]
                           for k in loss_dict.keys() if k in weight_dict)
            # Backward pass and optimizer step run outside autocast, as recommended.
            scaler.scale(loss).backward()
            # Step using scaler.step()
            scaler.step(optimizer)
            # Update for next iteration
            scaler.update()

        else:
            outputs = model(images)
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                       if k in weight_dict)
            loss.backward()
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        cnt += 1
        total_loss.update(loss.item())
        bbox_loss.update(loss_dict["loss_bbox"].item())
        giou_loss.update(loss_dict["loss_giou"].item())
        labels_loss.update(loss_dict["loss_ce"].item())

        batch_time_m.update(time.time() - batch_start)
        batch_start = time.time()

        if last_batch or batch_idx % log_interval == 0:  # If we reach the log interval
            print(
                "Batch Train Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  "
                .format(batch_time=batch_time_m, ))

        if num_batches is not None:
            if cnt >= num_batches:
                end_train_step = time.time()
                metrics["total_loss"] = total_loss.avg
                metrics["bbox_loss"] = bbox_loss.avg
                metrics["giou_loss"] = giou_loss.avg
                metrics["labels_loss"] = labels_loss.avg

                print(f"Done till {num_batches} train batches")
                print(
                    f"Time taken for Training step = {end_train_step - start_train_step} sec"
                )
                return metrics

    end_train_step = time.time()
    metrics["total_loss"] = total_loss.avg
    metrics["bbox_loss"] = bbox_loss.avg
    metrics["giou_loss"] = giou_loss.avg
    metrics["labels_loss"] = labels_loss.avg
    print(
        f"Time taken for Training step = {end_train_step - start_train_step} sec"
    )
    return metrics
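
A hedged sketch of how this step could be invoked with mixed precision; the model, loader and criterion are placeholders assumed to be built elsewhere (e.g. a DETR model and its matcher-based criterion):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
scaler = torch.cuda.amp.GradScaler() if device == "cuda" else None

# model, train_loader and criterion are placeholders defined elsewhere.
metrics = train_step(model, train_loader, criterion, device,
                     optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4),
                     num_batches=50, log_interval=10, scaler=scaler)
print(metrics["total_loss"], metrics["bbox_loss"])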
Exemplo n.º 25
0
def initialize_model(model: nn.Module, cfg: dict, padding_idx: int) -> None:
    """
    This initializes a model based on the provided config.

    All initializer configuration is part of the `model` section of the
    configuration file.
    For an example, see e.g. `https://github.com/joeynmt/joeynmt/
    blob/master/configs/iwslt_envi_xnmt.yaml#L47`

    The main initializer is set using the `initializer` key.
    Possible values are `xavier`, `uniform`, `normal` or `zeros`.
    (`xavier` is the default).

    When an initializer is set to `uniform`, then `init_weight` sets the
    range for the values (-init_weight, init_weight).

    When an initializer is set to `normal`, then `init_weight` sets the
    standard deviation for the weights (with mean 0).

    The word embedding initializer is set using `embed_initializer` and takes
    the same values. The default is `normal` with `embed_init_weight = 0.01`.

    Biases are initialized separately using `bias_initializer`.
    The default is `zeros`, but you can use the same initializers as
    the main initializer.

    Set `init_rnn_orthogonal` to True if you want RNN orthogonal initialization
    (for recurrent matrices). Default is False.

    `lstm_forget_gate` controls how the LSTM forget gate is initialized.
    Default is `1`.

    :param model: model to initialize
    :param cfg: the model configuration
    :param padding_idx: index of the padding token (its embedding row is zeroed)
    """

    # defaults: xavier, embeddings: normal 0.01, biases: zeros, no orthogonal
    gain = float(cfg.get("init_gain", 1.0))  # for xavier
    init = cfg.get("initializer", "xavier")
    init_weight = float(cfg.get("init_weight", 0.01))

    embed_init = cfg.get("embed_initializer", "normal")
    embed_init_weight = float(cfg.get("embed_init_weight", 0.01))
    embed_gain = float(cfg.get("embed_init_gain", 1.0))  # for xavier

    bias_init = cfg.get("bias_initializer", "zeros")
    bias_init_weight = float(cfg.get("bias_init_weight", 0.01))

    init_fn_ = _parse_init(init, init_weight, gain)
    embed_init_fn_ = _parse_init(embed_init, embed_init_weight, embed_gain)
    bias_init_fn_ = _parse_init(bias_init, bias_init_weight, gain)

    with torch.no_grad():
        for name, p in model.named_parameters():
            if "embed" in name:
                embed_init_fn_(p)
                # zero out paddings; assumes all fields have same pad
                p.data[padding_idx].zero_()

            elif "bias" in name:
                if bias_init != "xavier":
                    bias_init_fn_(p)

            elif len(p.size()) > 1:

                # RNNs combine multiple matrices in one, which messes up
                # xavier initialization
                if init == "xavier" and "rnn" in name:
                    if "encoders" in name:
                        rnn = next(iter(model.encoders.values())).rnn
                    elif "encoder" in name:  # matches "encoders" too...
                        rnn = model.encoder.rnn
                    elif "decoder" in name:
                        rnn = model.decoder.rnn
                    else:
                        rnn = None

                    if rnn is not None:
                        n = 4 if isinstance(rnn, nn.LSTM) else 3
                    else:
                        n = 1  # when would this come up?
                    xavier_uniform_n_(p.data, gain=gain, n=n)
                else:
                    init_fn_(p)

        orthogonal = cfg.get("init_rnn_orthogonal", False)
        lstm_forget_gate = cfg.get("lstm_forget_gate", 1.)

        # encoder rnn orthogonal initialization & LSTM forget gate
        if hasattr(model, "encoders"):
            encoders = list(model.encoders.values())
        else:
            encoders = [model.encoder]
        for encoder in encoders:
            if hasattr(encoder, "rnn"):

                if orthogonal:
                    orthogonal_rnn_init_(encoder.rnn)

                if isinstance(encoder.rnn, nn.LSTM):
                    lstm_forget_gate_init_(encoder.rnn, lstm_forget_gate)

        # decoder rnn orthogonal initialization & LSTM forget gate
        if hasattr(model.decoder, "rnn"):

            if orthogonal:
                orthogonal_rnn_init_(model.decoder.rnn)

            if isinstance(model.decoder.rnn, nn.LSTM):
                lstm_forget_gate_init_(model.decoder.rnn, lstm_forget_gate)
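
The configuration keys read above map onto a `model` section like the following; the values are illustrative defaults, not a recommendation, and the model object is assumed to exist:

cfg = {
    "initializer": "xavier",
    "init_gain": 1.0,
    "init_weight": 0.01,            # used by the "uniform"/"normal" initializers
    "embed_initializer": "normal",
    "embed_init_weight": 0.01,
    "bias_initializer": "zeros",
    "init_rnn_orthogonal": False,
    "lstm_forget_gate": 1.0,
}
initialize_model(model, cfg, padding_idx=1)  # model and padding index assumed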
Exemplo n.º 26
0
def save_checkpoint(model: nn.Module, buffer: SampleBuffer, save_dir: Path, tag):
    checkpoint_dict = {"model_state": model.state_dict(), "sample_buffer": buffer}
    torch.save(checkpoint_dict, save_dir / tag)
Exemplo n.º 27
0
def module_to_array(
    module: Module,
    bounds: Optional[ParameterBounds] = None,
    exclude: Optional[Set[str]] = None,
) -> Tuple[np.ndarray, Dict[str, TorchAttr], Optional[np.ndarray]]:
    r"""Extract named parameters from a module into a numpy array.

    Only extracts parameters with requires_grad, since it is meant for optimizing.

    Args:
        module: A module with parameters. May specify parameter constraints in
            a `named_parameters_and_constraints` method.
        bounds: A ParameterBounds dictionary mapping parameter names to tuples
            of lower and upper bounds. Bounds specified here take precedence
            over bounds on the same parameters specified in the constraints
            registered with the module.
        exclude: A set of parameter names to be excluded from extraction.

    Returns:
        3-element tuple containing
        - The parameter values as a numpy array.
        - An ordered dictionary with the name and tensor attributes of each
        parameter.
        - A `2 x n_params` numpy array with lower and upper bounds if at least
        one constraint is finite, and None otherwise.

    Example:
        >>> mll = ExactMarginalLogLikelihood(model.likelihood, model)
        >>> parameter_array, property_dict, bounds_out = module_to_array(mll)
    """
    x: List[np.ndarray] = []
    lower: List[np.ndarray] = []
    upper: List[np.ndarray] = []
    property_dict = OrderedDict()
    exclude = set() if exclude is None else exclude

    # get bounds specified in model (if any)
    bounds_: ParameterBounds = {}
    if hasattr(module, "named_parameters_and_constraints"):
        for param_name, _, constraint in module.named_parameters_and_constraints():
            if constraint is not None and not constraint.enforced:
                bounds_[param_name] = constraint.lower_bound, constraint.upper_bound

    # update with user-supplied bounds (overwrites if already exists)
    if bounds is not None:
        bounds_.update(bounds)

    for p_name, t in module.named_parameters():
        if p_name not in exclude and t.requires_grad:
            property_dict[p_name] = TorchAttr(
                shape=t.shape, dtype=t.dtype, device=t.device
            )
            x.append(t.detach().view(-1).cpu().double().clone().numpy())
            # construct bounds
            if bounds_:
                l_, u_ = bounds_.get(p_name, (-inf, inf))
                if torch.is_tensor(l_):
                    l_ = l_.cpu().detach()
                if torch.is_tensor(u_):
                    u_ = u_.cpu().detach()
                # check for Nones here b/c it may be passed in manually in bounds
                lower.append(np.full(t.nelement(), l_ if l_ is not None else -inf))
                upper.append(np.full(t.nelement(), u_ if u_ is not None else inf))

    x_out = np.concatenate(x)
    bounds_out = None
    if bounds_:
        if not all(np.isinf(b).all() for lu in (lower, upper) for b in lu):
            bounds_out = np.stack([np.concatenate(lower), np.concatenate(upper)])
    return x_out, property_dict, bounds_out
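
A sketch of how the returned array and bounds might feed a scipy optimizer, assuming the companion set_params_with_array helper from the same module and a hypothetical compute_loss objective (mll is assumed to exist):

import numpy as np
from scipy.optimize import minimize

x0, property_dict, bounds_out = module_to_array(mll)  # mll assumed to exist

def neg_objective(x: np.ndarray) -> float:
    # Hypothetical objective: write x back into the module and evaluate a loss.
    set_params_with_array(mll, x, property_dict)
    return float(compute_loss(mll))  # compute_loss is a placeholder

bounds = None if bounds_out is None else list(zip(*bounds_out))
res = minimize(neg_objective, x0, bounds=bounds, method="L-BFGS-B")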
Exemplo n.º 28
0
    def _training_step(self, model: nn.Module,
                       inputs: Dict[str, torch.Tensor]) -> float:

        loss = model.train_step(**inputs)

        return loss
Exemplo n.º 29
0
 def __getattr__(self, attr):
     if self._has_method(attr):
         return self._get_method(attr)
     return Module.__getattr__(self, attr)
Exemplo n.º 30
0
 def __getattr__(self, attr):
     if self._has_method(attr):
         return self._get_method(attr)
     return Module.__getattr__(self, attr)
Exemplo n.º 31
0
 def soft_update(self, source: nn.Module, target: nn.Module, tau: float) -> None:
     for source_param, target_param in zip(source.parameters(), target.parameters()):
         target_param.data.copy_(
             target_param.data * (1.0 - tau) + source_param.data * tau
         )
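
This is the usual Polyak-averaging update for a target network. A small self-contained sketch of the same formula (the networks below are illustrative):

import copy
import torch.nn as nn

policy_net = nn.Linear(4, 2)              # illustrative online network
target_net = copy.deepcopy(policy_net)    # target starts as an exact copy

tau = 0.005
# after each optimisation step, nudge the target towards the online weights
for src, tgt in zip(policy_net.parameters(), target_net.parameters()):
    tgt.data.copy_(tgt.data * (1.0 - tau) + src.data * tau)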
Exemplo n.º 32
0
def initialize_model(model: nn.Module, cfg: dict, txt_padding_idx: int) -> None:
    """
    This initializes a model based on the provided config.

    All initializer configuration is part of the `model` section of the
    configuration file.
    For an example, see e.g. `https://github.com/joeynmt/joeynmt/
    blob/master/configs/iwslt_envi_xnmt.yaml#L47`

    The main initializer is set using the `initializer` key.
    Possible values are `xavier`, `uniform`, `normal` or `zeros`.
    (`xavier` is the default).

    When an initializer is set to `uniform`, then `init_weight` sets the
    range for the values (-init_weight, init_weight).

    When an initializer is set to `normal`, then `init_weight` sets the
    standard deviation for the weights (with mean 0).

    The word embedding initializer is set using `embed_initializer` and takes
    the same values. The default is `normal` with `embed_init_weight = 0.01`.

    Biases are initialized separately using `bias_initializer`.
    The default is `zeros`, but you can use the same initializers as
    the main initializer.

    Set `init_rnn_orthogonal` to True if you want RNN orthogonal initialization
    (for recurrent matrices). Default is False.

    `lstm_forget_gate` controls how the LSTM forget gate is initialized.
    Default is `1`.

    :param model: model to initialize
    :param cfg: the model configuration
    :param txt_padding_idx: index of spoken language text padding token
    """

    # defaults: xavier, embeddings: normal 0.01, biases: zeros, no orthogonal
    gain = float(cfg.get("init_gain", 1.0))  # for xavier
    init = cfg.get("initializer", "xavier")
    init_weight = float(cfg.get("init_weight", 0.01))

    embed_init = cfg.get("embed_initializer", "normal")
    embed_init_weight = float(cfg.get("embed_init_weight", 0.01))
    embed_gain = float(cfg.get("embed_init_gain", 1.0))  # for xavier

    bias_init = cfg.get("bias_initializer", "zeros")
    bias_init_weight = float(cfg.get("bias_init_weight", 0.01))

    # pylint: disable=unnecessary-lambda, no-else-return
    def _parse_init(s, scale, _gain):
        scale = float(scale)
        assert scale > 0.0, "incorrect init_weight"
        if s.lower() == "xavier":
            return lambda p: nn.init.xavier_uniform_(p, gain=_gain)
        elif s.lower() == "uniform":
            return lambda p: nn.init.uniform_(p, a=-scale, b=scale)
        elif s.lower() == "normal":
            return lambda p: nn.init.normal_(p, mean=0.0, std=scale)
        elif s.lower() == "zeros":
            return lambda p: nn.init.zeros_(p)
        else:
            raise ValueError("unknown initializer")

    init_fn_ = _parse_init(init, init_weight, gain)
    embed_init_fn_ = _parse_init(embed_init, embed_init_weight, embed_gain)
    bias_init_fn_ = _parse_init(bias_init, bias_init_weight, gain)

    with torch.no_grad():
        for name, p in model.named_parameters():

            if "txt_embed" in name:
                if "lut" in name:
                    embed_init_fn_(p)

            elif "bias" in name:
                bias_init_fn_(p)

            elif len(p.size()) > 1:

                # RNNs combine multiple matrices in one, which messes up
                # xavier initialization
                if init == "xavier" and "rnn" in name:
                    n = 1
                    if "encoder" in name:
                        n = 4 if isinstance(model.encoder.rnn, nn.LSTM) else 3
                    elif "decoder" in name:
                        n = 4 if isinstance(model.decoder.rnn, nn.LSTM) else 3
                    xavier_uniform_n_(p.data, gain=gain, n=n)
                else:
                    init_fn_(p)

        # zero out paddings
        if model.txt_embed is not None:
            model.txt_embed.lut.weight.data[txt_padding_idx].zero_()

        orthogonal = cfg.get("init_rnn_orthogonal", False)
        lstm_forget_gate = cfg.get("lstm_forget_gate", 1.0)

        # encoder rnn orthogonal initialization & LSTM forget gate
        if hasattr(model.encoder, "rnn"):

            if orthogonal:
                orthogonal_rnn_init_(model.encoder.rnn)

            if isinstance(model.encoder.rnn, nn.LSTM):
                lstm_forget_gate_init_(model.encoder.rnn, lstm_forget_gate)

        # decoder rnn orthogonal initialization & LSTM forget gate
        if hasattr(model.decoder, "rnn"):

            if orthogonal:
                orthogonal_rnn_init_(model.decoder.rnn)

            if isinstance(model.decoder.rnn, nn.LSTM):
                lstm_forget_gate_init_(model.decoder.rnn, lstm_forget_gate)
Exemplo n.º 33
0
 def __dir__(self):
     return sorted(Module.__dir__(self) + self._method_names())
Exemplo n.º 34
0
def adjust_weight_to_zero(model: nn.Module, thresh):
    """Set weight values below `thresh` (e.g. 1e-6) to zero."""
    for name, param in model.named_parameters():
        if 'weight' in name:
            mask_value(param, thresh)
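
`mask_value` is not shown here; a hedged sketch of what such an in-place thresholding helper could look like (the name and exact semantics are assumptions based on the docstring):

import torch

def mask_value_sketch(param: torch.Tensor, thresh: float) -> None:
    # Zero out entries whose magnitude falls below the threshold, in place.
    with torch.no_grad():
        param[param.abs() < thresh] = 0.0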
Exemplo n.º 35
0
    def train_epoch(self,
                    model: nn.Module,
                    train_loader: DataLoader,
                    val_clean_loader: DataLoader,
                    val_triggered_loader: DataLoader,
                    epoch_num: int,
                    progress_bar_disable: bool = False,
                    use_amp: bool = False):
        """
        Runs one epoch of training on the specified model

        :param model: the model to train for one epoch
        :param train_loader: a DataLoader object pointing to the training dataset
        :param val_clean_loader: a DataLoader object pointing to the validation dataset that is clean
        :param val_triggered_loader: a DataLoader object pointing to the validation dataset that is triggered
        :param epoch_num: the epoch number that is being trained
        :param progress_bar_disable: if True, disables the progress bar
        :param use_amp: if True, uses automatic mixed precision (AMP) during training
        :return: a tuple of (EpochTrainStatistics, EpochValidationStatistics) for this epoch
        """

        pid = os.getpid()
        train_dataset_len = len(train_loader.dataset)
        loop = tqdm(train_loader, disable=progress_bar_disable)

        scaler = None
        if use_amp:
            scaler = torch.cuda.amp.GradScaler()

        train_n_correct, train_n_total = None, None
        sum_batchmean_train_loss = 0
        running_train_acc = 0
        num_batches = len(train_loader)
        model.train()
        for batch_idx, (x, y_truth) in enumerate(loop):
            x = x.to(self.device)
            y_truth = y_truth.to(self.device)
            # if use_amp:
            #     x = x.half()
            #     y_truth = y_truth.half()

            # put network into training mode & zero out previous gradient computations
            self.optimizer.zero_grad()

            # get predictions based on input & weights learned so far
            if use_amp:
                with torch.cuda.amp.autocast():
                    y_hat = model(x)
                    # compute metrics
                    batch_train_loss = self._eval_loss_function(y_hat, y_truth)
            else:
                y_hat = model(x)
                # compute metrics
                batch_train_loss = self._eval_loss_function(y_hat, y_truth)

            sum_batchmean_train_loss += batch_train_loss.item()

            running_train_acc, train_n_total, train_n_correct = _running_eval_acc(
                y_hat,
                y_truth,
                n_total=train_n_total,
                n_correct=train_n_correct,
                soft_to_hard_fn=self.soft_to_hard_fn,
                soft_to_hard_fn_kwargs=self.soft_to_hard_fn_kwargs)

            if np.isnan(sum_batchmean_train_loss) or np.isnan(
                    running_train_acc):
                _save_nandata(x, y_hat, y_truth, batch_train_loss,
                              sum_batchmean_train_loss, running_train_acc,
                              train_n_total, train_n_correct, model)

            # compute gradient
            if use_amp:
                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
                # Backward passes under autocast are not recommended.
                # Backward ops run in the same dtype autocast chose for corresponding forward ops.
                scaler.scale(batch_train_loss).backward()
            else:
                batch_train_loss.backward()

            # perform gradient clipping if configured
            if self.optimizer_cfg.training_cfg.clip_grad:
                if use_amp:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(self.optimizer)

                if self.optimizer_cfg.training_cfg.clip_type == 'norm':
                    # clip_grad_norm_ modifies gradients in place
                    #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                    torch_clip_grad.clip_grad_norm_(
                        model.parameters(),
                        self.optimizer_cfg.training_cfg.clip_val,
                        **self.optimizer_cfg.training_cfg.clip_kwargs)
                elif self.optimizer_cfg.training_cfg.clip_type == 'val':
                    # clip_grad_val_ modifies gradients in place
                    #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                    torch_clip_grad.clip_grad_value_(
                        model.parameters(),
                        self.optimizer_cfg.training_cfg.clip_val)
                else:
                    msg = "Unknown clipping type for gradient clipping!"
                    logger.error(msg)
                    raise ValueError(msg)

            if use_amp:
                # scaler.step() first unscales the gradients of the optimizer's assigned params.
                # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(self.optimizer)
                # Updates the scale for next iteration.
                scaler.update()
            else:
                self.optimizer.step()

            loop.set_description('Epoch {}/{}'.format(epoch_num + 1,
                                                      self.num_epochs))
            loop.set_postfix(avg_train_loss=batch_train_loss.item())

            # report batch statistics to tensorboard
            if self.tb_writer:
                try:
                    batch_num = int(epoch_num * num_batches + batch_idx)
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-train_loss',
                        batch_train_loss.item(),
                        global_step=batch_num)
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-running_train_acc',
                        running_train_acc,
                        global_step=batch_num)
                except Exception:
                    # TODO: catch specific exceptions
                    pass

            if batch_idx % self.num_batches_per_logmsg == 0:
                logger.info(
                    '{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tTrainLoss: {:.6f}\tTrainAcc: {:.6f}'
                    .format(pid, epoch_num, batch_idx * len(x),
                            train_dataset_len, 100. * batch_idx / num_batches,
                            batch_train_loss.item(), running_train_acc))

        train_stats = EpochTrainStatistics(
            running_train_acc, sum_batchmean_train_loss / float(num_batches))

        # if we have validation data, we compute on the validation dataset
        num_val_batches_clean = len(val_clean_loader)
        if num_val_batches_clean > 0:
            logger.info('Running Validation on Clean Data')
            running_val_clean_acc, _, _, val_clean_loss = \
                _eval_acc(val_clean_loader, model, self.device,
                          self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs, self._eval_loss_function)
        else:
            logger.info("No dataset computed for validation on clean dataset!")
            running_val_clean_acc = None
            val_clean_loss = None

        num_val_batches_triggered = len(val_triggered_loader)
        if num_val_batches_triggered > 0:
            logger.info('Running Validation on Triggered Data')
            running_val_triggered_acc, _, _, val_triggered_loss = \
                _eval_acc(val_triggered_loader, model, self.device,
                          self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs, self._eval_loss_function)
        else:
            logger.info(
                "No dataset computed for validation on triggered dataset!")
            running_val_triggered_acc = None
            val_triggered_loss = None

        validation_stats = EpochValidationStatistics(
            running_val_clean_acc, val_clean_loss, running_val_triggered_acc,
            val_triggered_loss)
        if num_val_batches_clean > 0:
            logger.info(
                '{}\tTrain Epoch: {} \tCleanValLoss: {:.6f}\tCleanValAcc: {:.6f}'
                .format(pid, epoch_num, val_clean_loss, running_val_clean_acc))
        if num_val_batches_triggered > 0:
            logger.info(
                '{}\tTrain Epoch: {} \tTriggeredValLoss: {:.6f}\tTriggeredValAcc: {:.6f}'
                .format(pid, epoch_num, val_triggered_loss,
                        running_val_triggered_acc))

        if self.tb_writer:
            try:
                batch_num = int((epoch_num + 1) * num_batches)
                if num_val_batches_clean > 0:
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-clean-val-loss',
                        val_clean_loss,
                        global_step=batch_num)
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-clean-val_acc',
                        running_val_clean_acc,
                        global_step=batch_num)
                if num_val_batches_triggered > 0:
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-triggered-val-loss',
                        val_triggered_loss,
                        global_step=batch_num)
                    self.tb_writer.add_scalar(
                        self.optimizer_cfg.reporting_cfg.experiment_name +
                        '-triggered-val_acc',
                        running_val_triggered_acc,
                        global_step=batch_num)
            except Exception:
                pass

        # update the lr-scheduler if necessary
        if self.lr_scheduler is not None:
            if self.optimizer_cfg.training_cfg.lr_scheduler_call_arg is None:
                self.lr_scheduler.step()
            elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower(
            ) == 'val_acc':
                val_acc = validation_stats.get_val_acc()
                if val_acc is not None:
                    self.lr_scheduler.step(val_acc)
                else:
                    msg = "val_clean_acc not defined b/c validation dataset is not defined! Ignoring LR step!"
                    logger.warning(msg)
            elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower(
            ) == 'val_loss':
                val_loss = validation_stats.get_val_loss()
                if val_loss is not None:
                    self.lr_scheduler.step(val_loss)
                else:
                    msg = "val_clean_loss not defined b/c validation dataset is not defined! Ignoring LR step!"
                    logger.warning(msg)
            else:
                msg = "Unknown mode for calling lr_scheduler!"
                logger.error(msg)
                raise ValueError(msg)

        return train_stats, validation_stats
Exemplo n.º 36
0
 def __dir__(self):
     return sorted(Module.__dir__(self) + self._method_names())
Exemplo n.º 37
0
    def tie_encoder_to_decoder_recursively(
        decoder_pointer: nn.Module,
        encoder_pointer: nn.Module,
        module_name: str,
        uninitialized_encoder_weights: List[str],
        depth=0,
    ):
        assert isinstance(decoder_pointer, nn.Module) and isinstance(
            encoder_pointer, nn.Module
        ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
        if hasattr(decoder_pointer, "weight"):
            assert hasattr(encoder_pointer, "weight")
            encoder_pointer.weight = decoder_pointer.weight
            if hasattr(decoder_pointer, "bias"):
                assert hasattr(encoder_pointer, "bias")
                encoder_pointer.bias = decoder_pointer.bias
            return

        encoder_modules = encoder_pointer._modules
        decoder_modules = decoder_pointer._modules

        # print("Encoder modules", " ".join([n for n in encoder_modules.keys()]))
        # print("Decoder modules", " ".join([n for n in decoder_modules.keys()]))

        if len(decoder_modules) > 0:
            # print("len(decoder_modules)", len(decoder_modules))
            assert (
                len(encoder_modules) > 0
            ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

            all_encoder_weights = set([
                module_name + "/" + sub_name
                for sub_name in encoder_modules.keys()
            ])
            encoder_layer_pos = 0
            for name, module in decoder_modules.items():
                if name.isdigit():
                    # print("name is digit", name)
                    encoder_name = str(int(name) + encoder_layer_pos)
                    decoder_name = name

                    # print("encoder_name", encoder_name)
                    # print("decoder_name", decoder_name)

                    if not isinstance(
                            decoder_modules[decoder_name],
                            type(encoder_modules[encoder_name])) and len(
                                encoder_modules) != len(decoder_modules):
                        # this can happen if the name corresponds to the position in a ModuleList of layers
                        # in this case the decoder has added a cross-attention layer that the encoder does not have
                        # thus skip this step and subtract one layer pos from the encoder
                        encoder_layer_pos -= 1
                        continue
                elif name not in encoder_modules:
                    continue
                elif depth > 500:
                    raise ValueError(
                        "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a "
                        "circular dependency between two or more `nn.Modules` of your model. "
                    )
                else:
                    # print("else")
                    decoder_name = encoder_name = name

                    # print("decoder_name = encoder_name = name")
                    # print("decoder_name:", decoder_name)

                tie_encoder_to_decoder_recursively(
                    decoder_modules[decoder_name],
                    encoder_modules[encoder_name],
                    module_name + "/" + name,
                    uninitialized_encoder_weights,
                    depth=depth + 1,
                )
                all_encoder_weights.remove(module_name + "/" + encoder_name)

            uninitialized_encoder_weights += list(all_encoder_weights)
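
A hedged sketch of how this recursion is typically started from the enclosing weight-tying routine; the encoder/decoder modules and the base name are illustrative:

uninitialized_encoder_weights = []
tie_encoder_to_decoder_recursively(decoder, encoder, "encoder",
                                   uninitialized_encoder_weights)
if uninitialized_encoder_weights:
    print("Encoder weights not tied to any decoder weight:",
          uninitialized_encoder_weights)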