Beispiel #1
0
def _build_optimizer(model: Process, args: Namespace):
    """Create the Adam optimizer used by :func:`train`.

    Without ``args.include_poisson`` every model parameter shares
    ``args.lr_rate_init``. With it, each non-Poisson process module and
    ``model.alpha`` use ``args.lr_rate_init``, while ``model.poisson`` gets
    its own learning rate ``args.lr_poisson_rate_init``. The Poisson group is
    appended last, so it is ``optimizer.param_groups[-1]``.
    """
    if not args.include_poisson:
        return Adam(model.parameters(), lr=args.lr_rate_init)

    param_groups = [{'params': getattr(model, name).parameters()}
                    for name in model.processes if name != 'poisson']
    param_groups.append({'params': model.alpha})
    param_groups.append({
        'params': model.poisson.parameters(),
        'lr': args.lr_poisson_rate_init
    })
    return Adam(param_groups, lr=args.lr_rate_init)


def _is_new_best(val_loss, best_loss, relative_tolerance):
    """Return whether ``val_loss`` counts as an improvement over ``best_loss``.

    The loss must be strictly lower. If ``relative_tolerance`` is not None,
    the relative change ``|val_loss - best_loss| / |best_loss|`` must also
    exceed it. A previous best of exactly 0 has no relative scale, so any
    strict decrease from it counts as an improvement (instead of dividing by
    zero).
    """
    if not val_loss < best_loss:
        return False
    if relative_tolerance is None:
        return True
    scale = abs(best_loss)
    if scale == 0:
        return True
    return abs(val_loss - best_loss) / scale > relative_tolerance


def train(model: Process, args: Namespace, loader: DataLoader,
          val_loader: DataLoader,
          test_loader: DataLoader) -> Tuple[Process, dict]:
    """Train a model with early stopping on the validation loss.

    Each epoch runs one optimisation pass over ``loader``, evaluates on the
    training and validation sets, and keeps a copy of the weights with the
    best validation loss. Training stops once ``args.patience`` consecutive
    epochs pass without improvement.

    When ``args.use_mlflow`` is set:
      * the best weights so far and figures on ``test_loader`` are logged
        every ``args.save_model_freq`` epochs;
      * loss metrics are logged every ``args.logging_frequency`` epochs.

    Args:
        model: Model to be trained.
        args: Arguments for training.
        loader: The dataset for training.
        val_loader: The dataset for evaluation.
        test_loader: The dataset for testing (only used for logged figures).

    Returns:
        A tuple ``(model, images_urls)``:
          * ``model`` is loaded with the best weights found by early
            stopping;
          * ``images_urls`` is the dictionary with keys ``'intensity'``,
            ``'src_attn'`` and ``'tgt_attn'``, as last returned by
            ``log_figures``.
    """
    optimizer = _build_optimizer(model=model, args=args)
    lr_scheduler = create_lr_scheduler(optimizer=optimizer, args=args)

    parameters = dict(model.named_parameters())

    cnt_wait, best_loss, best_epoch = 0, 1e9, 0
    best_state = deepcopy(model.state_dict())
    # Per-epoch timings; collected here but not currently returned.
    train_dur, val_dur = [], []
    images_urls = {'intensity': [], 'src_attn': [], 'tgt_attn': []}

    epochs = range(args.train_epochs)
    if args.verbose:
        epochs = tqdm(epochs)

    for epoch in epochs:
        t0 = time.time()
        model.train()

        # Non-plateau schedulers step once per epoch, up front; the plateau
        # scheduler steps on the validation loss further below.
        if args.lr_scheduler != 'plateau':
            lr_scheduler.step()

        for batch in (tqdm(loader) if args.verbose else loader):
            optimizer.zero_grad()
            loss, loss_mask, _ = get_loss(model, batch=batch, args=args)  # [B]
            loss = th.sum(loss * loss_mask)
            check_tensor(loss)
            loss.backward()
            optimizer.step()
        train_dur.append(time.time() - t0)

        train_metrics = evaluate(model, args=args, loader=loader)
        val_metrics = evaluate(model, args=args, loader=val_loader)
        val_loss = val_metrics["loss"]
        val_dur.append(val_metrics["dur"])

        if args.lr_scheduler == 'plateau':
            lr_scheduler.step(metrics=val_loss)

        if _is_new_best(val_loss, best_loss, args.loss_relative_tolerance):
            best_loss, best_epoch = val_loss, epoch
            cnt_wait = 0
            best_state = deepcopy(model.state_dict())
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping!")
            break

        # Check use_mlflow first so save_model_freq is only used when logging.
        if args.use_mlflow and epoch % args.save_model_freq == 0:
            # Log a snapshot of the best weights so far, then restore the
            # current weights so training continues from where it was.
            current_state = deepcopy(model.state_dict())
            model.load_state_dict(best_state)
            epoch_str = get_epoch_str(epoch=epoch,
                                      max_epochs=args.train_epochs)

            mlflow.pytorch.log_model(model, "models/epoch_" + epoch_str)
            images_urls = log_figures(model=model,
                                      test_loader=test_loader,
                                      epoch=epoch,
                                      args=args,
                                      images_urls=images_urls)
            model.load_state_dict(current_state)

        lr = optimizer.param_groups[0]['lr']
        train_metrics["lr"] = lr
        # With include_poisson, the Poisson group is the last param group.
        lr_poisson = (optimizer.param_groups[-1]['lr']
                      if args.include_poisson else lr)

        status = get_status(args=args,
                            epoch=epoch,
                            lr=lr,
                            lr_poisson=lr_poisson,
                            parameters=parameters,
                            train_loss=train_metrics["loss"],
                            val_metrics=val_metrics,
                            cnt_wait=cnt_wait)
        print(status)

        if args.use_mlflow and epoch % args.logging_frequency == 0:
            loss_metrics = {
                "lr": train_metrics["lr"],
                "train_loss": train_metrics["loss"],
                "train_loss_per_time": train_metrics["loss_per_time"],
                "valid_loss": val_metrics["loss"],
                "valid_loss_per_time": val_metrics["loss_per_time"]
            }
            log_metrics(model=model,
                        metrics=loss_metrics,
                        val_metrics=val_metrics,
                        args=args,
                        epoch=epoch)

    model.load_state_dict(best_state)
    return model, images_urls
Beispiel #2
0
    def forward(
        self,
        events: Events,
        query: th.Tensor,
        prev_times: th.Tensor,
        prev_times_idxs: th.Tensor,
        pos_delta_mask: th.Tensor,
        is_event: th.Tensor,
        representations: th.Tensor,
        representations_mask: Optional[th.Tensor] = None,
        artifacts: Optional[dict] = None
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Dict]:
        """Compute intensities at the query times and their integrals.

        The log-intensity at each query comes from ``self.log_intensity``.
        The integral from ``prev_times`` to ``query`` is estimated by Monte
        Carlo:
          * draw ``int(self.mc_prop_est)`` uniform times in each interval;
          * evaluate the intensity at those times;
          * average and multiply by the interval length.

        Args:
            events: [B,L] Times and labels of events.
            query: [B,T] Times to evaluate the intensity function.
            prev_times: [B,T] Times of events directly preceding queries.
            prev_times_idxs: [B,T] Indexes of times of events directly
                preceding queries. These indexes are of window-prepended
                events.
            pos_delta_mask: [B,T] A mask indicating if the time difference
                `query - prev_times` is strictly positive.
            is_event: [B,T] A mask indicating whether the time given by
                `prev_times_idxs` corresponds to an event or not (a 1 indicates
                an event and a 0 indicates a window boundary).
            representations: [B,L+1,D] Representations of each event.
            representations_mask: [B,L+1] Mask indicating which representations
                are well-defined. If `None`, there is no mask. Defaults to
                `None`.
            artifacts: A dictionary of whatever else you might want to return.

        Returns:
            log_intensity: [B,T,M] The intensities for each query time for
                each mark (class).
            intensity_integrals: [B,T,M] Monte Carlo estimate of the integral
                of the intensity from the most recent event to the query time
                for each mark.
            intensities_mask: [B,T] Which intensities are valid for further
                computation based on e.g. sufficient history available.
            artifacts: The artifacts dictionary returned by
                ``self.log_intensity``.
        """
        marked_log_intensity, intensity_mask, artifacts = self.log_intensity(
            events=events,
            query=query,
            prev_times=prev_times,
            prev_times_idxs=prev_times_idxs,
            pos_delta_mask=pos_delta_mask,
            is_event=is_event,
            representations=representations,
            representations_mask=representations_mask,
            artifacts=artifacts)  # [B,T,M], [B,T], dict

        batch_size, n_queries = query.shape[0], query.shape[1]
        n_samples = int(self.mc_prop_est)
        span = query - prev_times  # [B,T] length of each integration interval

        # Uniform sample times inside [prev_times, query], sorted per query.
        unit_draws = th.rand(batch_size, n_queries, n_samples,
                             device=query.device)
        sample_times = unit_draws * span.unsqueeze(-1) + \
            prev_times.unsqueeze(-1)  # [B,T,N]
        sample_times = sample_times.sort(dim=-1).values
        sample_times = sample_times.reshape(batch_size, -1)  # [B, TxN]

        def per_sample(per_query: th.Tensor) -> th.Tensor:
            # Repeat each query's value N times so it lines up with the
            # flattened [B, TxN] sample layout.
            return th.repeat_interleave(per_query, n_samples, dim=-1)

        sample_log_intensity, _, _ = self.log_intensity(
            events=events,
            query=sample_times,
            prev_times=per_sample(prev_times),
            prev_times_idxs=per_sample(prev_times_idxs),
            pos_delta_mask=per_sample(pos_delta_mask),
            is_event=per_sample(is_event),
            representations=representations,
            representations_mask=representations_mask)  # [B,TxN,M]

        sample_log_intensity = sample_log_intensity.reshape(
            batch_size, n_queries, n_samples, self.marks)  # [B,T,N,M]
        sample_log_intensity = \
            sample_log_intensity * intensity_mask[..., None, None]
        sample_intensity = th.exp(sample_log_intensity)  # [B,T,N,M]

        # Interval length times the mean sampled intensity.
        intensity_integrals = span.unsqueeze(-1) * \
            sample_intensity.sum(-2) / float(n_samples)  # [B,T,M]

        check_tensor(marked_log_intensity)
        check_tensor(intensity_integrals * intensity_mask.unsqueeze(-1),
                     positive=True)
        return (marked_log_intensity, intensity_integrals, intensity_mask,
                artifacts)  # [B,T,M], [B,T,M], [B,T], Dict
Beispiel #3
0
    def log_intensity(
            self,
            events: Events,
            query: th.Tensor,
            prev_times: th.Tensor,
            prev_times_idxs: th.Tensor,
            pos_delta_mask: th.Tensor,
            is_event: th.Tensor,
            representations: th.Tensor,
            representations_mask: Optional[th.Tensor] = None,
            artifacts: Optional[dict] = None
    ) -> Tuple[th.Tensor, th.Tensor, Dict]:
        """Compute the log-intensity at each query time and a validity mask.

        A recurrent state ``(h_d, c_d, c_bar)`` is carried across the query
        axis. At each step:
          * ``self.recurrence`` consumes the representation of the preceding
            event;
          * ``self.decay`` advances the state by ``query - prev_times``.

        Positions where ``intensity_mask`` is 0 keep the previous state. The
        collected hidden states are L2-normalised and mapped through
        ``self.mlp``, whose output must be strictly positive, then logged.

        Args:
            events: [B,L] Times and labels of events.
            query: [B,T] Times to evaluate the intensity function.
            prev_times: [B,T] Times of events directly preceding queries.
            prev_times_idxs: [B,T] Indexes of times of events directly
                preceding queries. These indexes are of window-prepended
                events.
            pos_delta_mask: [B,T] A mask indicating if the time difference
                `query - prev_times` is strictly positive.
            is_event: [B,T] A mask indicating whether the time given by
                `prev_times_idxs` corresponds to an event or not (a 1 indicates
                an event and a 0 indicates a window boundary).
            representations: [B,L+1,D] Representations of each event.
            representations_mask: [B,L+1] Mask indicating which representations
                are well-defined. If `None`, there is no mask. Defaults to
                `None`.
            artifacts: A dictionary of whatever else you might want to return;
                passed through unchanged.

        Returns:
            log_intensity: [B,T,M] The intensities for each query time for
                each mark (class).
            intensities_mask: [B,T] Which intensities are valid for further
                computation based on e.g. sufficient history available.
            artifacts: The ``artifacts`` argument, unchanged.
        """
        n_batch, n_steps = query.size()
        # Only the mask is used here; the query representations are discarded.
        _, intensity_mask = self.get_query_representations(
            events=events,
            query=query,
            prev_times=prev_times,
            prev_times_idxs=prev_times_idxs,
            pos_delta_mask=pos_delta_mask,
            is_event=is_event,
            representations=representations,
            representations_mask=representations_mask)  # [B,T,D], [B,T]

        event_reprs = take_3_by_2(representations, index=prev_times_idxs)
        query = query * intensity_mask
        prev_times = prev_times * intensity_mask
        elapsed = query - prev_times  # [B,T] decay duration per step

        def zeros(*shape: int) -> th.Tensor:
            return th.zeros(*shape,
                            dtype=th.float,
                            device=representations.device)

        def blend(new: th.Tensor, old: th.Tensor,
                  keep_new: th.Tensor) -> th.Tensor:
            # Take ``new`` where the mask is 1, keep ``old`` where it is 0.
            return new * keep_new + old * (1. - keep_new)

        hidden_per_step = zeros(n_steps, n_batch, self.units_rnn)  # [T,B,U]
        h_d = zeros(n_batch, self.units_rnn)
        c_d = zeros(n_batch, self.units_rnn)
        c_bar = zeros(n_batch, self.units_rnn)

        for step in range(n_steps):
            cell, next_c_bar, out_gate, decay_rate = self.recurrence(
                event_reprs[:, step], h_d, c_d, c_bar)
            next_c_d, next_h_d = self.decay(cell, next_c_bar, out_gate,
                                            decay_rate, elapsed[:, step])
            keep_new = intensity_mask[:, step].unsqueeze(-1)
            h_d = blend(next_h_d, h_d, keep_new)
            c_d = blend(next_c_d, c_d, keep_new)
            c_bar = blend(next_c_bar, c_bar, keep_new)
            hidden_per_step[step] = h_d

        hidden = F.normalize(hidden_per_step.transpose(0, 1), dim=-1, p=2)
        intensities = self.mlp(hidden)  # [B,T,M]
        check_tensor(intensities, positive=True, strict=True)
        return Log.apply(intensities), intensity_mask, artifacts
Beispiel #4
0
    def forward(
        self,
        events: Events,
        query: th.Tensor,
        prev_times: th.Tensor,
        prev_times_idxs: th.Tensor,
        pos_delta_mask: th.Tensor,
        is_event: th.Tensor,
        representations: th.Tensor,
        representations_mask: Optional[th.Tensor] = None,
        artifacts: Optional[dict] = None
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Dict]:
        """Compute the intensities for each query time given event
        representations.

        The decoder models the cumulative intensity via ``self.cum_intensity``.
        The intensity itself is recovered by differentiating that output with
        respect to ``query`` using autograd. ``query``'s ``requires_grad``
        flag is restored to its original value before returning, including
        when an exception is raised.

        Args:
            events: [B,L] Times and labels of events.
            query: [B,T] Times to evaluate the intensity function.
            prev_times: [B,T] Times of events directly preceding queries.
            prev_times_idxs: [B,T] Indexes of times of events directly
                preceding queries. These indexes are of window-prepended
                events.
            pos_delta_mask: [B,T] A mask indicating if the time difference
                `query - prev_times` is strictly positive.
            is_event: [B,T] A mask indicating whether the time given by
                `prev_times_idxs` corresponds to an event or not (a 1 indicates
                an event and a 0 indicates a window boundary).
            representations: [B,L+1,D] Representations of each event.
            representations_mask: [B,L+1] Mask indicating which representations
                are well-defined. If `None`, there is no mask. Defaults to
                `None`.
            artifacts: A dictionary of whatever else you might want to return.

        Returns:
            log_intensity: [B,T,M] The intensities for each query time for
                each mark (class).
            intensity_integrals: [B,T,M] The integral of the intensity from
                the most recent event to the query time for each mark.
            intensities_mask: [B,T] Which intensities are valid for further
                computation based on e.g. sufficient history available.
            artifacts: Dictionary whose ``'decoder'`` entry holds the
                intermediate decoder tensors (and any previously stored
                ``'attention_weights'``).
        """
        # Autograd must track ``query`` to differentiate the cumulative
        # intensity. Only enable it if it is not already on, and undo that in
        # ``finally`` so the caller's tensor is left as it was passed in.
        enable_query_grad = not query.requires_grad
        if enable_query_grad:
            query.requires_grad = True
        try:
            intensity_integrals_q, intensity_mask_q, artifacts = \
                self.cum_intensity(
                    events=events,
                    query=query,
                    prev_times=prev_times,
                    prev_times_idxs=prev_times_idxs,
                    pos_delta_mask=pos_delta_mask,
                    is_event=is_event,
                    representations=representations,
                    representations_mask=representations_mask,
                    artifacts=artifacts,
                    update_running_stats=False)

            # Zero out masked positions.
            intensity_integrals_q = \
                intensity_integrals_q * intensity_mask_q.unsqueeze(-1)

            # Optional zero subtraction: also evaluate at ``prev_times`` and
            # take the difference as the integral over [prev_times, query].
            if self.do_zero_subtraction:
                (intensity_integrals_z, intensity_mask_z,
                 _) = self.cum_intensity(
                     events=events,
                     query=prev_times,
                     prev_times=prev_times,
                     prev_times_idxs=prev_times_idxs,
                     pos_delta_mask=pos_delta_mask,
                     is_event=is_event,
                     representations=representations,
                     representations_mask=representations_mask,
                     artifacts=artifacts)

                intensity_integrals_z = \
                    intensity_integrals_z * intensity_mask_z.unsqueeze(-1)
                # Keep the value at ``query`` no lower than at ``prev_times``.
                intensity_integrals_q = th.clamp(
                    intensity_integrals_q - intensity_integrals_z,
                    min=0.) + intensity_integrals_z
                # Add a small term linear in ``query``. This adds a constant
                # to the derivative below, presumably to keep it strictly
                # positive.
                intensity_integrals_q = intensity_integrals_q + epsilon(
                    eps=1e-3,
                    dtype=intensity_integrals_q.dtype,
                    device=intensity_integrals_q.device) * query.unsqueeze(-1)
                if self.model_log_cm:
                    intensity_integrals = subtract_exp(intensity_integrals_q,
                                                       intensity_integrals_z)
                else:
                    intensity_integrals = \
                        intensity_integrals_q - intensity_integrals_z
                intensity_mask = intensity_mask_q * intensity_mask_z

            else:
                # Same small linear term as in the zero-subtraction branch.
                intensity_integrals_q = intensity_integrals_q + epsilon(
                    eps=1e-3,
                    dtype=intensity_integrals_q.dtype,
                    device=intensity_integrals_q.device) * query.unsqueeze(-1)
                intensity_mask = intensity_mask_q
                if self.model_log_cm:
                    intensity_integrals = th.exp(intensity_integrals_q)
                else:
                    intensity_integrals = intensity_integrals_q

            check_tensor(intensity_integrals * intensity_mask.unsqueeze(-1),
                         positive=True)

            # Per-output derivative w.r.t. ``query`` via the double-backward
            # trick.
            #   1. The first ``grad`` computes a vector-Jacobian product
            #      with a dummy vector ``grad_outputs``.
            #   2. Differentiating that w.r.t. the dummy vector, with ones,
            #      gives the Jacobian applied to ones.
            # The result has the [B,T,M] shape of ``intensity_integrals_q``.
            # Assumes each output depends only on its own query time —
            # TODO confirm.
            grad_outputs = th.zeros_like(intensity_integrals_q,
                                         requires_grad=True)
            grad_inputs = th.autograd.grad(outputs=intensity_integrals_q,
                                           inputs=query,
                                           grad_outputs=grad_outputs,
                                           retain_graph=True,
                                           create_graph=True)[0]
            marked_intensity = th.autograd.grad(
                outputs=grad_inputs,
                inputs=grad_outputs,
                grad_outputs=th.ones_like(grad_inputs),
                retain_graph=True,
                create_graph=True)[0]
        finally:
            if enable_query_grad:
                query.requires_grad = False

        check_tensor(marked_intensity, positive=True, strict=True)
        log = Log.apply
        if self.model_log_cm:
            # The model outputs log of the cumulative intensity, so
            # d/dq exp(g) = exp(g) * dg/dq and log intensity = log(dg/dq) + g.
            marked_log_intensity = \
                log(marked_intensity) + intensity_integrals_q
        else:
            marked_log_intensity = log(marked_intensity)

        artifacts_decoder = {
            "intensity_integrals": intensity_integrals,
            "marked_intensity": marked_intensity,
            "marked_log_intensity": marked_log_intensity,
            "intensity_mask": intensity_mask
        }
        if artifacts is None:
            artifacts = {'decoder': artifacts_decoder}
        else:
            # Preserve attention weights stored by an earlier stage.
            if 'decoder' in artifacts:
                if 'attention_weights' in artifacts['decoder']:
                    artifacts_decoder['attention_weights'] = \
                        artifacts['decoder']['attention_weights']
            artifacts['decoder'] = artifacts_decoder

        return (marked_log_intensity, intensity_integrals, intensity_mask,
                artifacts)  # [B,T,M], [B,T,M], [B,T], Dict

    def forward(
        self,
        events: Events,
        query: th.Tensor,
        prev_times: th.Tensor,
        prev_times_idxs: th.LongTensor,
        pos_delta_mask: th.Tensor,
        is_event: th.Tensor,
        representations: th.Tensor,
        representations_mask: Optional[th.Tensor] = None,
        artifacts: Optional[dict] = None
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Dict]:
        """Compute the intensities for each query time given event
        representations.

        The inter-event time ``dt = query - prev_times`` is modelled with the
        CDF ``F(dt) = sum_k w_k * Phi((log dt - mu_k) / std_k)``, a mixture of
        K log-normal components. From it:
          * the density is ``f = dF/d(query)``, obtained with autograd;
          * the ground intensity is ``f / (1 - F)``;
          * its integral is ``-log(1 - F)``.
        Both are split across marks by ``p_m``. ``query``'s ``requires_grad``
        flag is restored before returning, including when an exception is
        raised.

        Args:
            events: [B,L] Times and labels of events.
            query: [B,T] Times to evaluate the intensity function.
            prev_times: [B,T] Times of events directly preceding queries.
            prev_times_idxs: [B,T] Indexes of times of events directly
                preceding queries. These indexes are of window-prepended
                events.
            pos_delta_mask: [B,T] A mask indicating if the time difference
                `query - prev_times` is strictly positive.
            is_event: [B,T] A mask indicating whether the time given by
                `prev_times_idxs` corresponds to an event or not (a 1 indicates
                an event and a 0 indicates a window boundary).
            representations: [B,L+1,D] Representations of window start and
                each event.
            representations_mask: [B,L+1] Mask indicating which representations
                are well-defined. If `None`, there is no mask. Defaults to
                `None`.
            artifacts: A dictionary of whatever else you might want to return.

        Returns:
            log_intensity: [B,T,M] The intensities for each query time for
                each mark (class).
            intensity_integrals: [B,T,M] The integral of the intensity from
                the most recent event to the query time for each mark.
            intensities_mask: [B,T] Which intensities are valid for further
                computation based on e.g. sufficient history available.
            artifacts: A dictionary of whatever else you might want to return.
        """
        query_representations = take_3_by_2(representations,
                                            index=prev_times_idxs)  # [B,T,D]

        # Mixture parameters: K components per query.
        mu = self.mu(query_representations)  # [B,T,K]
        std = th.exp(self.s(query_representations))
        w = th.softmax(self.w(query_representations), dim=-1)
        if self.multi_labels:
            p_m = th.sigmoid(
                self.marks2(th.tanh(self.marks1(query_representations))))
        else:
            p_m = th.softmax(self.marks2(
                th.tanh(self.marks1(query_representations))),
                             dim=-1)

        # The density is d(cum_f)/d(query), so autograd must track ``query``.
        # Only enable it if needed and restore the caller's flag afterwards.
        enable_query_grad = not query.requires_grad
        if enable_query_grad:
            query.requires_grad = True
        try:
            delta_t = query - prev_times  # [B,T]
            delta_t = delta_t.unsqueeze(-1)  # [B,T,1]
            delta_t = th.relu(delta_t)
            # Replace exact zeros by epsilon so that log(delta_t) is finite.
            delta_t = delta_t + (delta_t == 0).float() * epsilon(
                dtype=delta_t.dtype, device=delta_t.device)

            cum_f = w * 0.5 * (1 + th.erf(
                (th.log(delta_t) - mu) / (std * math.sqrt(2))))
            cum_f = th.sum(cum_f, dim=-1)  # [B,T]

            f = th.autograd.grad(outputs=cum_f,
                                 inputs=query,
                                 grad_outputs=th.ones_like(cum_f),
                                 retain_graph=True,
                                 create_graph=True)[0]
        finally:
            if enable_query_grad:
                query.requires_grad = False

        one_min_cum_f = 1. - cum_f
        one_min_cum_f = th.relu(one_min_cum_f) + epsilon(dtype=cum_f.dtype,
                                                         device=cum_f.device)
        f = f + epsilon(dtype=f.dtype, device=f.device)

        base_log_intensity = th.log(f / one_min_cum_f)
        marked_log_intensity = base_log_intensity.unsqueeze(dim=-1)  # [B,T,1]
        marked_log_intensity = marked_log_intensity + th.log(p_m)  # [B,T,M]

        base_intensity_itg = -th.log(one_min_cum_f)
        marked_intensity_itg = base_intensity_itg.unsqueeze(dim=-1)  # [B,T,1]
        marked_intensity_itg = marked_intensity_itg * p_m  # [B,T,M]

        intensity_mask = pos_delta_mask  # [B,T]
        if representations_mask is not None:
            history_representations_mask = take_2_by_2(
                representations_mask, index=prev_times_idxs)  # [B,T]
            intensity_mask = intensity_mask * history_representations_mask

        artifacts_decoder = {
            "base_log_intensity": base_log_intensity,
            "base_intensity_integral": base_intensity_itg,
            "mark_probability": p_m
        }
        if artifacts is None:
            artifacts = {'decoder': artifacts_decoder}
        else:
            artifacts['decoder'] = artifacts_decoder

        check_tensor(marked_log_intensity)
        check_tensor(marked_intensity_itg * intensity_mask.unsqueeze(-1),
                     positive=True)
        return (marked_log_intensity, marked_intensity_itg, intensity_mask,
                artifacts)  # [B,T,M], [B,T,M], [B,T], Dict
Beispiel #6
0
    def forward(
        self,
        events: Events,
        query: th.Tensor,
        prev_times: th.Tensor,
        prev_times_idxs: th.LongTensor,
        pos_delta_mask: th.Tensor,
        is_event: th.Tensor,
        representations: th.Tensor,
        representations_mask: Optional[th.Tensor] = None,
        artifacts: Optional[dict] = None
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Dict]:
        """Compute the intensities for each query time given event
        representations.

        Feature 0 of the preceding event's representation is ``v``. The
        ground log-intensity is ``v + self.w * (query - prev_times)``. Its
        integral is ``subtract_exp(v + w * dt, v) / self.w``; presumably
        ``subtract_exp(a, b)`` computes ``exp(a) - exp(b)``, so this would be
        ``(exp(v + w*dt) - exp(v)) / w`` (TODO confirm), clipped at zero. The
        remaining features give the mark distribution ``p_m`` (sigmoid or
        softmax), which splits both quantities per mark.

        Args:
            events: [B,L] Times and labels of events.
            query: [B,T] Times to evaluate the intensity function.
            prev_times: [B,T] Times of events directly preceding queries.
            prev_times_idxs: [B,T] Indexes of times of events directly
                preceding queries. These indexes are of window-prepended
                events.
            pos_delta_mask: [B,T] A mask indicating if the time difference
                `query - prev_times` is strictly positive.
            is_event: [B,T] A mask indicating whether the time given by
                `prev_times_idxs` corresponds to an event or not (a 1 indicates
                an event and a 0 indicates a window boundary).
            representations: [B,L+1,D] Representations of window start and
                each event.
            representations_mask: [B,L+1] Mask indicating which representations
                are well-defined. If `None`, there is no mask. Defaults to
                `None`.
            artifacts: A dictionary of whatever else you might want to return.

        Returns:
            log_intensity: [B,T,M] The intensities for each query time for
                each mark (class).
            intensity_integrals: [B,T,M] The integral of the intensity from
                the most recent event to the query time for each mark.
            intensities_mask: [B,T] Which intensities are valid for further
                computation based on e.g. sufficient history available.
            artifacts: A dictionary of whatever else you might want to return.
        """
        history = take_3_by_2(representations,
                              index=prev_times_idxs)  # [B,T,D]
        ground_feature = history[:, :, 0]  # [B,T]
        mark_logits = history[:, :, 1:]  # [B,T,M]

        time_term = self.w * (query - prev_times)  # [B,T]
        base_log_intensity = ground_feature + time_term  # [B,T]

        mark_probs = (th.sigmoid(mark_logits) if self.multi_labels
                      else th.softmax(mark_logits, dim=-1))  # [B,T,M]
        mark_probs = mark_probs + epsilon(dtype=mark_probs.dtype,
                                          device=mark_probs.device)

        marked_log_intensity = \
            base_log_intensity[..., None] + th.log(mark_probs)  # [B,T,M]

        intensity_mask = pos_delta_mask  # [B,T]
        if representations_mask is not None:
            intensity_mask = intensity_mask * take_2_by_2(
                representations_mask, index=prev_times_idxs)  # [B,T]

        # Zero the exponents at masked positions so they cannot blow up
        # when exponentiated.
        upper_exponent = (ground_feature + time_term) * intensity_mask
        lower_exponent = ground_feature * intensity_mask
        base_intensity_itg = th.relu(
            subtract_exp(upper_exponent, lower_exponent) / self.w)  # [B,T]

        marked_intensity_itg = \
            base_intensity_itg[..., None] * mark_probs  # [B,T,M]

        if artifacts is None:
            artifacts = {}
        artifacts['decoder'] = {
            "base_log_intensity": base_log_intensity,
            "base_intensity_integral": base_intensity_itg,
            "mark_probability": mark_probs
        }

        check_tensor(marked_log_intensity)
        check_tensor(marked_intensity_itg * intensity_mask.unsqueeze(-1),
                     positive=True)
        return (marked_log_intensity, marked_intensity_itg, intensity_mask,
                artifacts)  # [B,T,M], [B,T,M], [B,T], Dict