def train(model: Process, args: Namespace, loader: DataLoader,
          val_loader: DataLoader,
          test_loader: DataLoader) -> Tuple[Process, dict]:
    """Train a model.

    Args:
        model: Model to be trained.
        args: Arguments for training.
        loader: The dataset for training.
        val_loader: The dataset for evaluation.
        test_loader: The dataset for testing.

    Returns:
        The best trained model from early stopping, together with a
        dictionary of the figure URLs logged during training.
    """
    if args.include_poisson:
        # Train the Poisson process rate with its own learning rate.
        processes = model.processes.keys()
        modules = []
        for p in processes:
            if p != 'poisson':
                modules.append(getattr(model, p))
        optimizer = Adam(
            [{'params': m.parameters()} for m in modules] +
            [{'params': model.alpha}] +
            [{'params': model.poisson.parameters(),
              'lr': args.lr_poisson_rate_init}],
            lr=args.lr_rate_init)
    else:
        optimizer = Adam(model.parameters(), lr=args.lr_rate_init)
    lr_scheduler = create_lr_scheduler(optimizer=optimizer, args=args)

    parameters = dict(model.named_parameters())
    lr_wait, cnt_wait, best_loss, best_epoch = 0, 0, 1e9, 0
    best_state = deepcopy(model.state_dict())
    train_dur, val_dur, images_urls = list(), list(), dict()
    images_urls['intensity'] = list()
    images_urls['src_attn'] = list()
    images_urls['tgt_attn'] = list()

    epochs = range(args.train_epochs)
    if args.verbose:
        epochs = tqdm(epochs)

    for epoch in epochs:
        t0, _ = time.time(), model.train()
        if args.lr_scheduler != 'plateau':
            lr_scheduler.step()
        for batch in tqdm(loader) if args.verbose else loader:
            optimizer.zero_grad()
            loss, loss_mask, _ = get_loss(model, batch=batch, args=args)  # [B]
            loss = loss * loss_mask
            loss = th.sum(loss)
            check_tensor(loss)
            loss.backward()
            optimizer.step()
        train_dur.append(time.time() - t0)

        train_metrics = evaluate(model, args=args, loader=loader)
        val_metrics = evaluate(model, args=args, loader=val_loader)
        val_dur.append(val_metrics["dur"])

        if args.lr_scheduler == 'plateau':
            lr_scheduler.step(metrics=val_metrics["loss"])

        # A validation loss only counts as a new best if it also improves
        # on the previous best by more than the relative tolerance.
        new_best = val_metrics["loss"] < best_loss
        if args.loss_relative_tolerance is not None:
            abs_rel_loss_diff = (val_metrics["loss"] - best_loss) / best_loss
            abs_rel_loss_diff = abs(abs_rel_loss_diff)
            above_numerical_tolerance = (abs_rel_loss_diff >
                                         args.loss_relative_tolerance)
            new_best = new_best and above_numerical_tolerance

        if new_best:
            best_loss, best_epoch = val_metrics["loss"], epoch
            cnt_wait, lr_wait = 0, 0
            best_state = deepcopy(model.state_dict())
        else:
            cnt_wait, lr_wait = cnt_wait + 1, lr_wait + 1

        if cnt_wait == args.patience:
            print("Early stopping!")
            break

        if epoch % args.save_model_freq == 0 and args.use_mlflow:
            # Log the best model so far, then restore the current state.
            current_state = deepcopy(model.state_dict())
            model.load_state_dict(best_state)
            epoch_str = get_epoch_str(epoch=epoch,
                                      max_epochs=args.train_epochs)
            mlflow.pytorch.log_model(model, "models/epoch_" + epoch_str)
            images_urls = log_figures(model=model,
                                      test_loader=test_loader,
                                      epoch=epoch,
                                      args=args,
                                      images_urls=images_urls)
            model.load_state_dict(current_state)

        lr = optimizer.param_groups[0]['lr']
        train_metrics["lr"] = lr
        if args.include_poisson:
            lr_poisson = optimizer.param_groups[-1]['lr']
        else:
            lr_poisson = lr
        status = get_status(args=args,
                            epoch=epoch,
                            lr=lr,
                            lr_poisson=lr_poisson,
                            parameters=parameters,
                            train_loss=train_metrics["loss"],
                            val_metrics=val_metrics,
                            cnt_wait=cnt_wait)
        print(status)

        if args.use_mlflow and epoch % args.logging_frequency == 0:
            loss_metrics = {
                "lr": train_metrics["lr"],
                "train_loss": train_metrics["loss"],
                "train_loss_per_time": train_metrics["loss_per_time"],
                "valid_loss": val_metrics["loss"],
                "valid_loss_per_time": val_metrics["loss_per_time"]}
            log_metrics(model=model,
                        metrics=loss_metrics,
                        val_metrics=val_metrics,
                        args=args,
                        epoch=epoch)

    model.load_state_dict(best_state)
    return model, images_urls
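# A minimal, self-contained sketch of the early-stopping rule used in
# `train` above. The helper name `_is_new_best` is hypothetical and only
# illustrates the bookkeeping: a validation loss counts as a new best
# only if it beats the best loss by more than the relative tolerance.
from typing import Optional


def _is_new_best(val_loss: float, best_loss: float,
                 relative_tolerance: Optional[float] = None) -> bool:
    new_best = val_loss < best_loss
    if relative_tolerance is not None:
        rel_diff = abs((val_loss - best_loss) / best_loss)
        new_best = new_best and rel_diff > relative_tolerance
    return new_best


# Example: a 0.05% improvement does not count as a new best under a
# relative tolerance of 1e-3, so the patience counter keeps growing.
assert _is_new_best(0.9, 1.0, relative_tolerance=1e-3)
assert not _is_new_best(0.9995, 1.0, relative_tolerance=1e-3)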
def forward(
        self,
        events: Events,
        query: th.Tensor,
        prev_times: th.Tensor,
        prev_times_idxs: th.Tensor,
        pos_delta_mask: th.Tensor,
        is_event: th.Tensor,
        representations: th.Tensor,
        representations_mask: Optional[th.Tensor] = None,
        artifacts: Optional[dict] = None
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Dict]:
    """Compute the intensities for each query time given event
    representations.

    Args:
        events: [B,L] Times and labels of events.
        query: [B,T] Times to evaluate the intensity function.
        prev_times: [B,T] Times of events directly preceding queries.
        prev_times_idxs: [B,T] Indexes of times of events directly
            preceding queries. These indexes are of window-prepended
            events.
        pos_delta_mask: [B,T] A mask indicating if the time difference
            `query - prev_times` is strictly positive.
        is_event: [B,T] A mask indicating whether the time given by
            `prev_times_idxs` corresponds to an event or not (a 1
            indicates an event and a 0 indicates a window boundary).
        representations: [B,L+1,D] Representations of each event.
        representations_mask: [B,L+1] Mask indicating which
            representations are well-defined. If `None`, there is no
            mask. Defaults to `None`.
        artifacts: A dictionary of whatever else you might want to
            return.

    Returns:
        log_intensity: [B,T,M] The log-intensities for each query time
            for each mark (class).
        intensity_integrals: [B,T,M] The integral of the intensity from
            the most recent event to the query time for each mark.
        intensities_mask: [B,T] Which intensities are valid for further
            computation based on e.g. sufficient history available.
        artifacts: A dictionary of whatever else you might want to
            return.
    """
    marked_log_intensity, intensity_mask, artifacts = self.log_intensity(
        events=events,
        query=query,
        prev_times=prev_times,
        prev_times_idxs=prev_times_idxs,
        pos_delta_mask=pos_delta_mask,
        is_event=is_event,
        representations=representations,
        representations_mask=representations_mask,
        artifacts=artifacts)  # [B,T,M], [B,T], dict

    # Estimate the intensity integral with Monte Carlo: for each query,
    # sample uniform times between the previous event and the query, and
    # sort them.
    n_est = int(self.mc_prop_est)
    mc_times_samples = th.rand(
        query.shape[0], query.shape[1], n_est, device=query.device) * \
        (query - prev_times).unsqueeze(-1) + prev_times.unsqueeze(-1)
    mc_times_samples = th.sort(mc_times_samples, dim=-1).values
    mc_times_samples = mc_times_samples.reshape(
        mc_times_samples.shape[0], -1)  # [B,TxN]

    mc_marked_log_intensity, _, _ = self.log_intensity(
        events=events,
        query=mc_times_samples,
        prev_times=th.repeat_interleave(prev_times, n_est, dim=-1),
        prev_times_idxs=th.repeat_interleave(prev_times_idxs, n_est, dim=-1),
        pos_delta_mask=th.repeat_interleave(pos_delta_mask, n_est, dim=-1),
        is_event=th.repeat_interleave(is_event, n_est, dim=-1),
        representations=representations,
        representations_mask=representations_mask)  # [B,TxN,M]

    mc_marked_log_intensity = mc_marked_log_intensity.reshape(
        query.shape[0], query.shape[1], n_est, self.marks)  # [B,T,N,M]
    mc_marked_log_intensity = mc_marked_log_intensity * \
        intensity_mask.unsqueeze(-1).unsqueeze(-1)  # [B,T,N,M]
    marked_intensity_mc = th.exp(mc_marked_log_intensity)

    # (b - a) / N * sum_i lambda(u_i): the Monte Carlo integral estimate.
    intensity_integrals = (query - prev_times).unsqueeze(-1) * \
        marked_intensity_mc.sum(-2) / float(n_est)  # [B,T,M]

    check_tensor(marked_log_intensity)
    check_tensor(intensity_integrals * intensity_mask.unsqueeze(-1),
                 positive=True)

    return (marked_log_intensity,
            intensity_integrals,
            intensity_mask,
            artifacts)  # [B,T,M], [B,T,M], [B,T], Dict
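# A minimal, self-contained sketch of the Monte Carlo estimator used in
# `forward` above, on a toy intensity function (a stand-in, not one of
# the model's decoders). The integral int_a^b lambda(t) dt is estimated
# by (b - a) / N * sum_i lambda(u_i) with u_i ~ Uniform(a, b), exactly
# as the decoder does per (prev_time, query) pair.
import torch as th

th.manual_seed(0)


def toy_intensity(t: th.Tensor) -> th.Tensor:
    return 1.0 + 0.5 * th.sin(t)


a, b, n_est = 0.0, 2.0, 10_000
samples = th.rand(n_est) * (b - a) + a  # u_i ~ Uniform(a, b)
mc_integral = (b - a) * toy_intensity(samples).sum() / n_est
# Closed form: int_0^2 (1 + 0.5 sin t) dt = 2 + 0.5 * (1 - cos 2) ~ 2.708.
assert abs(mc_integral.item() - 2.708) < 0.05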
def log_intensity(
        self,
        events: Events,
        query: th.Tensor,
        prev_times: th.Tensor,
        prev_times_idxs: th.Tensor,
        pos_delta_mask: th.Tensor,
        is_event: th.Tensor,
        representations: th.Tensor,
        representations_mask: Optional[th.Tensor] = None,
        artifacts: Optional[dict] = None
) -> Tuple[th.Tensor, th.Tensor, Dict]:
    """Compute the log-intensity and a mask for each query time.

    Args:
        events: [B,L] Times and labels of events.
        query: [B,T] Times to evaluate the intensity function.
        prev_times: [B,T] Times of events directly preceding queries.
        prev_times_idxs: [B,T] Indexes of times of events directly
            preceding queries. These indexes are of window-prepended
            events.
        pos_delta_mask: [B,T] A mask indicating if the time difference
            `query - prev_times` is strictly positive.
        is_event: [B,T] A mask indicating whether the time given by
            `prev_times_idxs` corresponds to an event or not (a 1
            indicates an event and a 0 indicates a window boundary).
        representations: [B,L+1,D] Representations of each event.
        representations_mask: [B,L+1] Mask indicating which
            representations are well-defined. If `None`, there is no
            mask. Defaults to `None`.
        artifacts: A dictionary of whatever else you might want to
            return.

    Returns:
        log_intensity: [B,T,M] The log-intensities for each query time
            for each mark (class).
        intensities_mask: [B,T] Which intensities are valid for further
            computation based on e.g. sufficient history available.
        artifacts: A dictionary of whatever else you might want to
            return.
    """
    batch_size, query_length = query.size()
    query_representations, intensity_mask = self.get_query_representations(
        events=events,
        query=query,
        prev_times=prev_times,
        prev_times_idxs=prev_times_idxs,
        pos_delta_mask=pos_delta_mask,
        is_event=is_event,
        representations=representations,
        representations_mask=representations_mask)  # [B,T,D], [B,T]
    history_representations = take_3_by_2(
        representations, index=prev_times_idxs)  # [B,T,D]

    # Zero out invalid queries so the decay step sees well-defined deltas.
    query = query * intensity_mask
    prev_times = prev_times * intensity_mask

    h_seq = th.zeros(query_length, batch_size, self.units_rnn,
                     dtype=th.float, device=representations.device)
    h_d = th.zeros(batch_size, self.units_rnn,
                   dtype=th.float, device=representations.device)
    c_d = th.zeros(batch_size, self.units_rnn,
                   dtype=th.float, device=representations.device)
    c_bar = th.zeros(batch_size, self.units_rnn,
                     dtype=th.float, device=representations.device)

    for t in range(query_length):
        # `delta_t` is the decay rate produced by the recurrence; the
        # elapsed time since the last event is `query - prev_times`.
        c, new_c_bar, o_t, delta_t = self.recurrence(
            history_representations[:, t], h_d, c_d, c_bar)
        new_c_d, new_h_d = self.decay(
            c, new_c_bar, o_t, delta_t, query[:, t] - prev_times[:, t])
        # Only update the state where the query is valid.
        mask = intensity_mask[:, t].unsqueeze(-1)
        h_d = new_h_d * mask + h_d * (1. - mask)
        c_d = new_c_d * mask + c_d * (1. - mask)
        c_bar = new_c_bar * mask + c_bar * (1. - mask)
        h_seq[t] = h_d

    hidden = h_seq.transpose(0, 1)
    hidden = F.normalize(hidden, dim=-1, p=2)
    outputs = self.mlp(hidden)  # [B,T,output_size]
    check_tensor(outputs, positive=True, strict=True)
    log = Log.apply
    outputs = log(outputs)
    return outputs, intensity_mask, artifacts
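# A minimal sketch of the continuous-time LSTM decay step that
# `self.decay` above is assumed to implement, following the Neural
# Hawkes process (Mei & Eisner, 2017): the cell state decays from `c`
# towards its target `c_bar` at rate `delta`, and the hidden state is
# read out from the decayed cell. Shapes are an assumption: `c`,
# `c_bar`, `o`, `delta` are [B,U] and `dt` is [B].
from typing import Tuple

import torch as th


def decay(c: th.Tensor, c_bar: th.Tensor, o: th.Tensor,
          delta: th.Tensor, dt: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
    # c(t) = c_bar + (c - c_bar) * exp(-delta * (t - t_i))
    c_d = c_bar + (c - c_bar) * th.exp(-delta * dt.unsqueeze(-1))
    # h(t) = o * tanh(c(t))
    h_d = o * th.tanh(c_d)
    return c_d, h_d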
def forward(
        self,
        events: Events,
        query: th.Tensor,
        prev_times: th.Tensor,
        prev_times_idxs: th.Tensor,
        pos_delta_mask: th.Tensor,
        is_event: th.Tensor,
        representations: th.Tensor,
        representations_mask: Optional[th.Tensor] = None,
        artifacts: Optional[dict] = None
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Dict]:
    """Compute the intensities for each query time given event
    representations.

    Args:
        events: [B,L] Times and labels of events.
        query: [B,T] Times to evaluate the intensity function.
        prev_times: [B,T] Times of events directly preceding queries.
        prev_times_idxs: [B,T] Indexes of times of events directly
            preceding queries. These indexes are of window-prepended
            events.
        pos_delta_mask: [B,T] A mask indicating if the time difference
            `query - prev_times` is strictly positive.
        is_event: [B,T] A mask indicating whether the time given by
            `prev_times_idxs` corresponds to an event or not (a 1
            indicates an event and a 0 indicates a window boundary).
        representations: [B,L+1,D] Representations of each event.
        representations_mask: [B,L+1] Mask indicating which
            representations are well-defined. If `None`, there is no
            mask. Defaults to `None`.
        artifacts: A dictionary of whatever else you might want to
            return.

    Returns:
        log_intensity: [B,T,M] The log-intensities for each query time
            for each mark (class).
        intensity_integrals: [B,T,M] The integral of the intensity from
            the most recent event to the query time for each mark.
        intensities_mask: [B,T] Which intensities are valid for further
            computation based on e.g. sufficient history available.
        artifacts: A dictionary of whatever else you might want to
            return.
    """
    # Track gradients of the query so the intensity can be recovered as
    # the derivative of the cumulative intensity.
    query.requires_grad = True

    intensity_integrals_q, intensity_mask_q, artifacts = \
        self.cum_intensity(
            events=events,
            query=query,
            prev_times=prev_times,
            prev_times_idxs=prev_times_idxs,
            pos_delta_mask=pos_delta_mask,
            is_event=is_event,
            representations=representations,
            representations_mask=representations_mask,
            artifacts=artifacts,
            update_running_stats=False)

    # Remove masked values.
    intensity_integrals_q = \
        intensity_integrals_q * intensity_mask_q.unsqueeze(-1)

    # Optional zero subtraction: evaluate the cumulative intensity at the
    # previous event times and subtract it, so the integral starts at zero.
    if self.do_zero_subtraction:
        (intensity_integrals_z, intensity_mask_z,
         artifacts_zero) = self.cum_intensity(
             events=events,
             query=prev_times,
             prev_times=prev_times,
             prev_times_idxs=prev_times_idxs,
             pos_delta_mask=pos_delta_mask,
             is_event=is_event,
             representations=representations,
             representations_mask=representations_mask,
             artifacts=artifacts)
        intensity_integrals_z = \
            intensity_integrals_z * intensity_mask_z.unsqueeze(-1)
        intensity_integrals_q = th.clamp(
            intensity_integrals_q - intensity_integrals_z,
            min=0.) + intensity_integrals_z
        # Add a small epsilon, scaled by the query time, for stability.
        intensity_integrals_q = intensity_integrals_q + epsilon(
            eps=1e-3,
            dtype=intensity_integrals_q.dtype,
            device=intensity_integrals_q.device) * query.unsqueeze(-1)
        if self.model_log_cm:
            intensity_integrals = subtract_exp(
                intensity_integrals_q, intensity_integrals_z)
        else:
            intensity_integrals = \
                intensity_integrals_q - intensity_integrals_z
        intensity_mask = intensity_mask_q * intensity_mask_z
    else:
        intensity_integrals_q = intensity_integrals_q + epsilon(
            eps=1e-3,
            dtype=intensity_integrals_q.dtype,
            device=intensity_integrals_q.device) * query.unsqueeze(-1)
        intensity_mask = intensity_mask_q
        if self.model_log_cm:
            intensity_integrals = th.exp(intensity_integrals_q)
        else:
            intensity_integrals = intensity_integrals_q

    check_tensor(intensity_integrals * intensity_mask.unsqueeze(-1),
                 positive=True)

    # Compute the derivative of the integral w.r.t. the query times via a
    # double backward pass, which keeps the [B,T,M] output shape.
    grad_outputs = th.zeros_like(intensity_integrals_q, requires_grad=True)
    grad_inputs = th.autograd.grad(
        outputs=intensity_integrals_q,
        inputs=query,
        grad_outputs=grad_outputs,
        retain_graph=True,
        create_graph=True)[0]
    marked_intensity = th.autograd.grad(
        outputs=grad_inputs,
        inputs=grad_outputs,
        grad_outputs=th.ones_like(grad_inputs),
        retain_graph=True,
        create_graph=True)[0]
    query.requires_grad = False

    check_tensor(marked_intensity, positive=True, strict=True)

    log = Log.apply
    if self.model_log_cm:
        marked_log_intensity = \
            log(marked_intensity) + intensity_integrals_q
    else:
        marked_log_intensity = log(marked_intensity)

    artifacts_decoder = {
        "intensity_integrals": intensity_integrals,
        "marked_intensity": marked_intensity,
        "marked_log_intensity": marked_log_intensity,
        "intensity_mask": intensity_mask}
    if artifacts is None:
        artifacts = {'decoder': artifacts_decoder}
    else:
        if 'decoder' in artifacts:
            if 'attention_weights' in artifacts['decoder']:
                artifacts_decoder['attention_weights'] = \
                    artifacts['decoder']['attention_weights']
        artifacts['decoder'] = artifacts_decoder

    return (marked_log_intensity,
            intensity_integrals,
            intensity_mask,
            artifacts)  # [B,T,M], [B,T,M], [B,T], Dict
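# A minimal sketch of the double-backward trick used above to
# differentiate a multi-channel integral w.r.t. the query times while
# keeping the output shape. For y = f(x), the first grad call computes
# v^T (dy/dx) with a symbolic v; differentiating that result w.r.t. v
# recovers dy/dx broadcast back to y's shape. The toy function here is
# an assumption for illustration only.
import torch as th

x = th.linspace(0.0, 1.0, 4, requires_grad=True)  # "query", [T]
y = th.stack([x ** 2, x ** 3], dim=-1)            # "integrals", [T,2]

v = th.zeros_like(y, requires_grad=True)
vjp = th.autograd.grad(outputs=y, inputs=x, grad_outputs=v,
                       create_graph=True)[0]          # v^T dy/dx, [T]
dydx = th.autograd.grad(outputs=vjp, inputs=v,
                        grad_outputs=th.ones_like(vjp))[0]  # [T,2]
assert th.allclose(dydx, th.stack([2 * x, 3 * x ** 2], dim=-1).detach())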
def forward(
        self,
        events: Events,
        query: th.Tensor,
        prev_times: th.Tensor,
        prev_times_idxs: th.LongTensor,
        pos_delta_mask: th.Tensor,
        is_event: th.Tensor,
        representations: th.Tensor,
        representations_mask: Optional[th.Tensor] = None,
        artifacts: Optional[dict] = None
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Dict]:
    """Compute the intensities for each query time given event
    representations.

    Args:
        events: [B,L] Times and labels of events.
        query: [B,T] Times to evaluate the intensity function.
        prev_times: [B,T] Times of events directly preceding queries.
        prev_times_idxs: [B,T] Indexes of times of events directly
            preceding queries. These indexes are of window-prepended
            events.
        pos_delta_mask: [B,T] A mask indicating if the time difference
            `query - prev_times` is strictly positive.
        is_event: [B,T] A mask indicating whether the time given by
            `prev_times_idxs` corresponds to an event or not (a 1
            indicates an event and a 0 indicates a window boundary).
        representations: [B,L+1,D] Representations of window start and
            each event.
        representations_mask: [B,L+1] Mask indicating which
            representations are well-defined. If `None`, there is no
            mask. Defaults to `None`.
        artifacts: A dictionary of whatever else you might want to
            return.

    Returns:
        log_intensity: [B,T,M] The log-intensities for each query time
            for each mark (class).
        intensity_integrals: [B,T,M] The integral of the intensity from
            the most recent event to the query time for each mark.
        intensities_mask: [B,T] Which intensities are valid for further
            computation based on e.g. sufficient history available.
        artifacts: A dictionary of whatever else you might want to
            return.
    """
    # Track gradients of the query so the density can be recovered as
    # the derivative of the CDF.
    query.requires_grad = True

    query_representations = take_3_by_2(
        representations, index=prev_times_idxs)  # [B,T,D]

    delta_t = query - prev_times  # [B,T]
    delta_t = delta_t.unsqueeze(-1)  # [B,T,1]
    delta_t = th.relu(delta_t)
    delta_t = delta_t + (delta_t == 0).float() * epsilon(
        dtype=delta_t.dtype, device=delta_t.device)

    # Mixture parameters: means, scales and weights of K log-normals.
    mu = self.mu(query_representations)  # [B,T,K]
    std = th.exp(self.s(query_representations))
    w = th.softmax(self.w(query_representations), dim=-1)

    if self.multi_labels:
        p_m = th.sigmoid(
            self.marks2(th.tanh(self.marks1(query_representations))))
    else:
        p_m = th.softmax(
            self.marks2(th.tanh(self.marks1(query_representations))),
            dim=-1)

    # CDF of the log-normal mixture evaluated at delta_t.
    cum_f = w * 0.5 * (1 + th.erf(
        (th.log(delta_t) - mu) / (std * math.sqrt(2))))
    cum_f = th.sum(cum_f, dim=-1)
    one_min_cum_f = 1. - cum_f
    one_min_cum_f = th.relu(one_min_cum_f) + epsilon(
        dtype=cum_f.dtype, device=cum_f.device)

    # Density as the derivative of the CDF w.r.t. the query times.
    f = th.autograd.grad(
        outputs=cum_f,
        inputs=query,
        grad_outputs=th.ones_like(cum_f),
        retain_graph=True,
        create_graph=True)[0]
    query.requires_grad = False
    f = f + epsilon(dtype=f.dtype, device=f.device)

    # Hazard: lambda(t) = f(t) / (1 - F(t)); its integral is -log(1 - F).
    base_log_intensity = th.log(f / one_min_cum_f)
    marked_log_intensity = base_log_intensity.unsqueeze(dim=-1)  # [B,T,1]
    marked_log_intensity = marked_log_intensity + th.log(p_m)  # [B,T,M]

    base_intensity_itg = -th.log(one_min_cum_f)
    marked_intensity_itg = base_intensity_itg.unsqueeze(dim=-1)  # [B,T,1]
    marked_intensity_itg = marked_intensity_itg * p_m  # [B,T,M]

    intensity_mask = pos_delta_mask  # [B,T]
    if representations_mask is not None:
        history_representations_mask = take_2_by_2(
            representations_mask, index=prev_times_idxs)  # [B,T]
        intensity_mask = intensity_mask * history_representations_mask

    artifacts_decoder = {
        "base_log_intensity": base_log_intensity,
        "base_intensity_integral": base_intensity_itg,
        "mark_probability": p_m}
    if artifacts is None:
        artifacts = {'decoder': artifacts_decoder}
    else:
        artifacts['decoder'] = artifacts_decoder

    check_tensor(marked_log_intensity)
    check_tensor(marked_intensity_itg * intensity_mask.unsqueeze(-1),
                 positive=True)

    return (marked_log_intensity,
            marked_intensity_itg,
            intensity_mask,
            artifacts)  # [B,T,M], [B,T,M], [B,T], Dict
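# A minimal numeric check of the density-via-autograd step above, for a
# single log-normal component (K = 1, w = 1): differentiating the CDF
# w.r.t. the query time recovers the analytic log-normal density. The
# parameter values are assumptions for illustration only.
import math

import torch as th

mu, std = 0.0, 1.0
t = th.tensor(1.5, requires_grad=True)
cum_f = 0.5 * (1 + th.erf((th.log(t) - mu) / (std * math.sqrt(2))))
f = th.autograd.grad(cum_f, t)[0]
# Analytic log-normal density for comparison.
pdf = th.exp(-(th.log(t) - mu) ** 2 / (2 * std ** 2)) / \
    (t * std * math.sqrt(2 * math.pi))
assert th.allclose(f, pdf.detach())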
def forward(
        self,
        events: Events,
        query: th.Tensor,
        prev_times: th.Tensor,
        prev_times_idxs: th.LongTensor,
        pos_delta_mask: th.Tensor,
        is_event: th.Tensor,
        representations: th.Tensor,
        representations_mask: Optional[th.Tensor] = None,
        artifacts: Optional[dict] = None
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Dict]:
    """Compute the intensities for each query time given event
    representations.

    Args:
        events: [B,L] Times and labels of events.
        query: [B,T] Times to evaluate the intensity function.
        prev_times: [B,T] Times of events directly preceding queries.
        prev_times_idxs: [B,T] Indexes of times of events directly
            preceding queries. These indexes are of window-prepended
            events.
        pos_delta_mask: [B,T] A mask indicating if the time difference
            `query - prev_times` is strictly positive.
        is_event: [B,T] A mask indicating whether the time given by
            `prev_times_idxs` corresponds to an event or not (a 1
            indicates an event and a 0 indicates a window boundary).
        representations: [B,L+1,D] Representations of window start and
            each event.
        representations_mask: [B,L+1] Mask indicating which
            representations are well-defined. If `None`, there is no
            mask. Defaults to `None`.
        artifacts: A dictionary of whatever else you might want to
            return.

    Returns:
        log_intensity: [B,T,M] The log-intensities for each query time
            for each mark (class).
        intensity_integrals: [B,T,M] The integral of the intensity from
            the most recent event to the query time for each mark.
        intensities_mask: [B,T] Which intensities are valid for further
            computation based on e.g. sufficient history available.
        artifacts: A dictionary of whatever else you might want to
            return.
    """
    query_representations = take_3_by_2(
        representations, index=prev_times_idxs)  # [B,T,D]
    v_h_t = query_representations[:, :, 0]  # [B,T]
    v_h_m = query_representations[:, :, 1:]  # [B,T,M]
    w_delta_t = self.w * (query - prev_times)  # [B,T]
    base_log_intensity = v_h_t + w_delta_t  # [B,T]

    if self.multi_labels:
        p_m = th.sigmoid(v_h_m)  # [B,T,M]
    else:
        p_m = th.softmax(v_h_m, dim=-1)  # [B,T,M]
    regulariser = epsilon(dtype=p_m.dtype, device=p_m.device)
    p_m = p_m + regulariser

    marked_log_intensity = base_log_intensity.unsqueeze(dim=-1)  # [B,T,1]
    marked_log_intensity = marked_log_intensity + th.log(p_m)  # [B,T,M]

    intensity_mask = pos_delta_mask  # [B,T]
    if representations_mask is not None:
        history_representations_mask = take_2_by_2(
            representations_mask, index=prev_times_idxs)  # [B,T]
        intensity_mask = intensity_mask * history_representations_mask

    # Closed-form integral of exp(v + w * dt): (exp_1 - exp_2) / w.
    exp_1, exp_2 = v_h_t + w_delta_t, v_h_t  # [B,T]
    # Avoid exponentiating masked values to infinity.
    exp_1, exp_2 = exp_1 * intensity_mask, exp_2 * intensity_mask  # [B,T]
    base_intensity_itg = subtract_exp(exp_1, exp_2)
    base_intensity_itg = base_intensity_itg / self.w  # [B,T]
    base_intensity_itg = th.relu(base_intensity_itg)
    marked_intensity_itg = base_intensity_itg.unsqueeze(dim=-1)  # [B,T,1]
    marked_intensity_itg = marked_intensity_itg * p_m  # [B,T,M]

    artifacts_decoder = {
        "base_log_intensity": base_log_intensity,
        "base_intensity_integral": base_intensity_itg,
        "mark_probability": p_m}
    if artifacts is None:
        artifacts = {'decoder': artifacts_decoder}
    else:
        artifacts['decoder'] = artifacts_decoder

    check_tensor(marked_log_intensity)
    check_tensor(marked_intensity_itg * intensity_mask.unsqueeze(-1),
                 positive=True)

    return (marked_log_intensity,
            marked_intensity_itg,
            intensity_mask,
            artifacts)  # [B,T,M], [B,T,M], [B,T], Dict
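# A minimal numeric check of the closed-form integral used above: for an
# intensity lambda(s) = exp(v + w * (s - t_prev)), the integral from
# t_prev to the query is (exp(v + w * dt) - exp(v)) / w, which the code
# evaluates via `subtract_exp` for numerical stability. The parameter
# values are assumptions for illustration only.
import torch as th

v, w, dt = th.tensor(0.3), th.tensor(0.8), th.tensor(1.2)
closed_form = (th.exp(v + w * dt) - th.exp(v)) / w

# Trapezoidal quadrature of the same intensity for comparison.
s = th.linspace(0., float(dt), 10_001)
numeric = th.trapz(th.exp(v + w * s), s)
assert th.allclose(closed_form, numeric, atol=1e-4)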