def get_test_events_query(
        batch_size=16, seq_len=16, n_queries=16,
        device=th.device('cpu'), dtype=th.float32):
    # Build a random single-mark batch of event and query times, together with
    # fixed Hawkes parameters (alpha, beta, mu).
    marks = 1
    padding_id = -1.
    times = np.random.uniform(
        low=0.01, high=1., size=[batch_size, seq_len]).astype(np.float32)
    query = np.random.uniform(
        low=0.01, high=1., size=[batch_size, n_queries]).astype(np.float32)
    mask = times != padding_id
    times, query = th.from_numpy(times), th.from_numpy(query)
    times, query = times.type(dtype), query.type(dtype)
    mask = th.from_numpy(mask).type(times.dtype)
    times, query, mask = times.to(device), query.to(device), mask.to(device)
    window_start, window_end = get_window(times=times, window=1.)
    events = get_events(
        times=times, mask=mask,
        window_start=window_start, window_end=window_end)
    (prev_times, _), is_event, _ = get_prev_times(
        query=query, events=events, allow_window=True)
    alpha = th.from_numpy(np.array([[0.1]], dtype=np.float32))
    beta = th.from_numpy(np.array([[1.0]], dtype=np.float32))
    mu = th.from_numpy(np.array([0.05], dtype=np.float32))
    return marks, query, events, prev_times, is_event, alpha, beta, mu
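# Hedged usage sketch for the helper above: draw a toy batch and inspect the
# per-query tensors it returns. The shape comment is an expectation, not a
# guarantee from the library.
marks, query, events, prev_times, is_event, alpha, beta, mu = \
    get_test_events_query(batch_size=4, seq_len=8, n_queries=3)
# query, prev_times and is_event are expected to share a
# [batch_size, n_queries] layout.
print(query.shape, prev_times.shape, is_event.shape)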
def test_setup():
    padding_id = -1.
    times = th.Tensor([[1, 2, 3]]).type(th.float32)
    labels = th.Tensor([[1, 0, 0]]).type(th.long)
    mask = (times != padding_id).type(times.dtype).to(times.device)
    window_start, window_end = get_window(times=times, window=4.)
    events = get_events(
        times=times, mask=mask, labels=labels,
        window_start=window_start, window_end=window_end, marks=2)
    query = th.Tensor([[2.5]]).type(th.float32)
    return events, query
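# Hedged sketch: in the fixture above the single query at t=2.5 falls between
# the events at t=2 and t=3, so get_prev_times (used elsewhere in these tests)
# should report 2.0 as the previous event time.
events, query = test_setup()
(prev_times, _), is_event, _ = get_prev_times(
    query=query, events=events, allow_window=True)
print(prev_times, is_event)  # previous event time expected to be 2.0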
def get_fast_slow_results():
    padding_id = -1.
    times = th.Tensor([[1, 2, -1., -1.]]).type(th.float32)
    labels = th.Tensor([[1, 0, 0, 1]]).type(th.long)
    mask = (times != padding_id).type(times.dtype).to(times.device)
    marks = 2
    query = th.Tensor([[2.5, 7.]]).type(th.float32)
    window_start, window_end = get_window(times=times, window=4.)
    beta = th.Tensor([2, 1, 1, 3]).reshape(marks, marks).float()
    alpha = th.Tensor([1, 2, 1, 1]).reshape(marks, marks).float()
    mu = th.zeros(size=[marks], dtype=th.float32) + 3.00001
    events = get_events(
        times=times, mask=mask, labels=labels,
        window_start=window_start, window_end=window_end, marks=2)
    (prev_times, _), is_event, _ = get_prev_times(
        query=query, events=events, allow_window=True)
    results_fast = decoder_fast(
        events=events, query=query, prev_times=prev_times, is_event=is_event,
        alpha=alpha, beta=beta, mu=mu, marks=marks)
    results_slow = decoder_slow(
        events=events, query=query, prev_times=prev_times, is_event=is_event,
        alpha=alpha, beta=beta, mu=mu, marks=marks)
    return results_fast, results_slow
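# Hedged consistency check: the fast and slow decoder paths above should agree
# numerically. The structure of the returned results (a single tensor, or a
# flat tuple/list of tensors) is an assumption here.
results_fast, results_slow = get_fast_slow_results()
fast = results_fast if isinstance(results_fast, (tuple, list)) else (results_fast,)
slow = results_slow if isinstance(results_slow, (tuple, list)) else (results_slow,)
for f, s in zip(fast, slow):
    print(float((f - s).abs().max()))  # should be close to zero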
def get_loss(
        model: Process,
        batch: Dict[str, th.Tensor],
        args: Namespace,
        eval_metrics: Optional[bool] = False,
        dynamic_batch_length: Optional[bool] = True,
) -> Tuple[th.Tensor, th.Tensor, Dict]:
    times, labels = batch["times"], batch["labels"]
    labels = (labels != 0).type(labels.dtype)  # binarise label indicators
    if dynamic_batch_length:
        # Trim padding down to the longest sequence in the batch.
        seq_lens = batch["seq_lens"]
        max_seq_len = seq_lens.max()
        times, labels = times[:, :max_seq_len], labels[:, :max_seq_len]
    mask = (times != args.padding_id).type(times.dtype)
    times = times * args.time_scale
    window_start, window_end = get_window(times=times, window=args.window)
    events = get_events(
        times=times, mask=mask, labels=labels,
        window_start=window_start, window_end=window_end)
    loss, loss_mask, artifacts = model.neg_log_likelihood(events=events)  # [B]
    if eval_metrics:
        events_times = events.get_times()
        log_p, y_pred_mask = model.log_density(
            query=events_times, events=events)  # [B,L,M], [B,L]
        if args.multi_labels:
            y_pred = log_p  # [B,L,M]
            labels = events.labels
        else:
            y_pred = log_p.argmax(-1).type(log_p.dtype)  # [B,L]
            labels = events.labels.argmax(-1).type(events.labels.dtype)
        artifacts['y_pred'] = y_pred
        artifacts['y_true'] = labels
        artifacts['y_pred_mask'] = y_pred_mask
    return loss, loss_mask, artifacts
def get_test_events_query(marks=2,
                          batch_size=16,
                          max_seq_len=16,
                          queries=4,
                          padding_id=-1.,
                          device=th.device('cpu'),
                          dtype=th.float32):
    # Random variable-length sequences, sorted in time and padded to a common
    # length, plus a sorted batch of query times.
    seq_lens = th.randint(low=1, high=max_seq_len, size=[batch_size])
    times = [th.rand(size=[seq_len]) for seq_len in seq_lens]
    labels = [
        th.randint(low=0, high=marks, size=[seq_len]) for seq_len in seq_lens
    ]
    sort_idx = [th.argsort(x) for x in times]
    times = [x[idx] for x, idx in zip(times, sort_idx)]
    labels = [x[idx] for x, idx in zip(labels, sort_idx)]
    times = pad(times, value=padding_id).type(dtype)
    labels = pad(labels, value=0)
    times, labels = times.to(device), labels.to(device)
    mask = (times != padding_id).type(times.dtype).to(times.device)
    window_start, window_end = get_window(times=times, window=1.)
    events = get_events(
        times=times, mask=mask, labels=labels,
        window_start=window_start, window_end=window_end, marks=marks)
    query = th.rand(size=[batch_size, queries])
    query = th.sort(query, dim=-1).values
    query = query.to(device)
    return events, query
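# Minimal sketch for the padded-batch variant above: generate one random batch
# and confirm the query times come back sorted along the last dimension.
th.manual_seed(0)
events, query = get_test_events_query(marks=3, batch_size=2, max_seq_len=6, queries=3)
print(events.times.shape, query.shape)
print(bool(th.all(query[:, 1:] >= query[:, :-1])))  # queries sorted per row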
import torch as th

from torch import nn
from pprint import pprint
from tqdm import tqdm

from tpp.models.base.enc_dec import EncDecProcess
from tpp.models.encoders.mlp_variable import MLPVariableEncoder
from tpp.models.decoders.self_attention_cm import SelfAttentionCmDecoder
from tpp.models.decoders.mlp_cm import MLPCmDecoder
from tpp.utils.events import get_events, get_window

th.manual_seed(0)

times = th.Tensor([1, 2, 6]).float().reshape(1, -1)
query = th.linspace(start=0.0, end=10.1, steps=10).float().reshape(1, -1)
window_start, window_end = get_window(times=times, window=10.)
events = get_events(
    times=times, mask=th.ones_like(times),
    window_start=window_start, window_end=window_end)

# dec = SelfAttentionCmDecoder(
#     encoding="temporal",
#     units_mlp=[32, 1],
#     constraint_mlp="nonneg",
#     activation_final_mlp="softplus",
#     attn_activation="sigmoid")
dec = MLPCmDecoder(
    encoding="times_only",
    units_mlp=[32, 1],
    constraint_mlp="nonneg",
    activation_final_mlp="softplus")
def log_figures(model: Process,
                test_loader: DataLoader,
                args: Namespace,
                epoch: Optional[int] = None,
                images_urls: Optional[dict] = None,
                save_on_mlflow: Optional[bool] = True):
    models = dict()
    models[model.name.replace("_", "-")] = model
    if args.load_from_dir in [None, "hawkes"]:
        # Include the ground-truth Hawkes process alongside the trained model.
        true_model = HawkesProcess(marks=args.marks)
        true_model.alpha.data = th.tensor(args.alpha)
        true_model.beta.data = th.tensor(args.beta)
        true_model.mu.data = th.tensor(args.mu)
        true_model.to(args.device)
        models["ground truth"] = true_model

    # Plot the longest sequence in the test batch, truncated to 20 events.
    batch = next(iter(test_loader))
    times, labels = batch["times"], batch["labels"]
    times, labels = times.to(args.device), labels.to(args.device)
    length = (times != args.padding_id).sum(-1)
    i = th.argmax(length)
    times = times[i][:20].reshape(1, -1)
    labels = labels[i][:20].reshape(1, -1, args.marks)
    mask = (times != args.padding_id).type(times.dtype)
    times = times * args.time_scale
    window_start, window_end = get_window(times=times, window=args.window)
    events = get_events(
        times=times, mask=mask, labels=labels,
        window_start=window_start, window_end=window_end)

    if args.window is not None:
        query = th.linspace(start=0.001, end=args.window, steps=500)
    else:
        query = th.linspace(
            start=0.001, end=float(events.window_end[0]), steps=500)
    query = query.reshape(1, -1)
    query = query.to(device=args.device)

    event_times = events.times.cpu().detach().numpy().reshape(-1)
    event_labels = events.labels.cpu().detach().numpy().reshape(
        event_times.shape[0], -1)
    idx_times = np.where(event_labels == 1.)[0]
    event_times = event_times[idx_times]
    event_labels = np.where(event_labels == 1.)[1]
    unpadded = event_times != args.padding_id
    event_times, event_labels = event_times[unpadded], event_labels[unpadded]

    model_intensities = {
        k: m.intensity(query=query, events=events)
        for k, m in models.items()}
    model_intensities = {
        k: (ints.cpu().detach().numpy()[0], mask.cpu().detach().numpy()[0])
        for k, (ints, mask) in model_intensities.items()}
    model_artifacts = {
        k: m.artifacts(query=query, events=events)
        for k, m in models.items()}
    model_cumulative_intensities = {
        k: (v[1], v[2]) for k, v in model_artifacts.items()}
    model_cumulative_intensities = {
        k: (ints.cpu().detach().numpy()[0], mask.cpu().detach().numpy()[0])
        for k, (ints, mask) in model_cumulative_intensities.items()}
    model_artifacts = {k: v[3] for k, v in model_artifacts.items()}

    # Map label indices to human-readable code names for plotting.
    with open(os.path.join(args.save_dir,
                           'int_to_codes_to_plot.json'), 'r') as h:
        int_to_codes_to_plot = json.load(h)
    with open(os.path.join(args.save_dir, 'int_to_codes.json'), 'r') as h:
        int_to_codes = json.load(h)
    with open(os.path.join(args.save_dir, 'codes_to_names.json'), 'r') as h:
        codes_to_names = json.load(h)
    int_to_names_to_plot = {
        k: codes_to_names[v] for k, v in int_to_codes_to_plot.items()}
    int_to_names = {k: codes_to_names[v] for k, v in int_to_codes.items()}

    query = query.cpu().detach().numpy()[0]
    model_intensities = {
        k: (filter_by_mask(query, mask=mask), filter_by_mask(ints, mask=mask))
        for k, (ints, mask) in model_intensities.items()}
    model_intensities = {
        k: (q, ints[:, [int(i) for i in int_to_names_to_plot.keys()]])
        for k, (q, ints) in model_intensities.items()}
    model_cumulative_intensities = {
        k: (filter_by_mask(query, mask=mask), filter_by_mask(ints, mask=mask))
        for k, (ints, mask) in model_cumulative_intensities.items()}
    model_cumulative_intensities = {
        k: (q, ints[:, [int(i) for i in int_to_names_to_plot.keys()]])
        for k, (q, ints) in model_cumulative_intensities.items()}

    images_urls = plot_attn_weights(
        model_artifacts=model_artifacts,
        event_times=event_times,
        event_labels=event_labels,
        idx_times=idx_times,
        query=query,
        args=args,
        class_names=list(int_to_names.values()),
        epoch=epoch,
        images_urls=images_urls)
    f, a = fig_hawkes(
        intensities=model_intensities,
        cumulative_intensities=model_cumulative_intensities,
        event_times=event_times,
        event_labels=event_labels,
        class_names=int_to_names_to_plot,
        epoch=epoch)

    if epoch is not None:
        epoch_str = get_epoch_str(epoch=epoch, max_epochs=args.train_epochs)
    intensity_dir = os.path.join(args.plots_dir, "intensity")
    if epoch is not None:
        plot_path = os.path.join(intensity_dir, "epoch_" + epoch_str + ".jpg")
    else:
        plot_path = intensity_dir + ".jpg"
    f.savefig(plot_path, dpi=300, bbox_inches='tight')
    if save_on_mlflow:
        assert epoch is not None, "Epoch must not be None with mlflow active"
        mlflow.log_artifact(plot_path, "intensity/epoch_" + epoch_str)
    if images_urls is not None:
        images_urls['intensity'].append(plot_path)
    return images_urls