Example #1
def assemble_tensors(data, seq_len):
    x_enc_coarse = data['x_enc_coarse'][:, :seq_len]
    x_enc_coarse = nan_to_value(x_enc_coarse, value=0.0)
    x_enc_fine = data['x_enc_fine'][:, :seq_len]
    x_enc_fine = nan_to_value(x_enc_fine, value=0.0)
    y_enc_fine = data['y_enc_fine'][:, :seq_len]
    y_enc_fine_cat, y_enc_fine_num = pull_parts(y_enc_fine, convert_nans=True)
    y_enc_coarse = data['y_enc_coarse'][:, :seq_len]
    y_enc_coarse_cat, y_enc_coarse_num = pull_parts(y_enc_coarse,
                                                    convert_nans=True)
    input_num_steps = data['effective_num_steps']
    tensors = [
        x_enc_fine, x_enc_coarse, input_num_steps, y_enc_fine_cat,
        y_enc_fine_num, y_enc_coarse_cat, y_enc_coarse_num
    ]
    return tensors
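# A minimal sketch of the nan_to_value helper assumed above (hypothetical; the
# actual implementation may differ, e.g. in dtype handling). It replaces NaNs in
# a numpy array with a constant, optionally in place. Assumes numpy is imported
# as np, as elsewhere in this file.
def nan_to_value(x, value=0.0, inplace=True):
    x = x if inplace else x.copy()
    x[np.isnan(x)] = value
    return x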
def compute_scaler(x_cat, x_num, quantile_range, to_index=True):
    """x_cat is (batch_size, *, num_actions) and x_num is (batch_size, *, 1)."""
    if to_index:
        x_cat = one_hot_to_index(x_cat)
    x = np.concatenate([x_cat, x_num], axis=-1)
    x = nan_to_value(x, value=-1.0, inplace=False)
    x_unique = np.unique(x.reshape(-1, x.shape[-1]), axis=0)
    x_unique_num = x_unique[..., -1:]
    x_unique_num[x_unique_num == -1.0] = np.nan
    _, scaler = normalise(x_unique_num, strategy='robust', with_centering=False, quantile_range=quantile_range)
    return scaler
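# Hedged sketch of the one_hot_to_index helper assumed by compute_scaler (an
# assumption; the real helper may handle all-NaN rows differently). It collapses
# the one-hot action dimension to a single index column so the result can be
# concatenated with x_num along the last axis.
def one_hot_to_index(x_one_hot):
    return np.argmax(x_one_hot, axis=-1)[..., None].astype(float)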
Example #3
def test_baselines(args):
    checkpoint = torch.load(args.checkpoint)
    fine_labels_path = args.fine_labels_path
    coarse_labels_path = args.coarse_labels_path
    fine_action_to_id = read_action_dictionary(args.fine_action_to_id)
    fine_id_to_action = {
        action_id: action
        for action, action_id in fine_action_to_id.items()
    }
    coarse_action_to_id = read_action_dictionary(args.coarse_action_to_id)
    coarse_id_to_action = {
        action_id: action
        for action, action_id in coarse_action_to_id.items()
    }
    fraction_observed = args.observed_fraction
    ignore_silence_action = args.ignore_silence_action
    do_error_analysis = args.do_error_analysis
    print_coarse_results = args.print_coarse_results
    seq_len = checkpoint['seq_len']
    # Load model
    baseline_type = checkpoint['baseline_type']
    action_level = checkpoint['action_level']
    Baseline = {0: Baseline0, 1: Baseline1, 2: Baseline2}[baseline_type]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Baseline(**checkpoint['model_creation_args']).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    observed_fine_actions_per_video = []
    unobserved_fine_actions_per_video = []
    predicted_fine_steps_per_video = []
    predicted_fine_actions_per_video = []

    observed_coarse_actions_per_video = []
    unobserved_coarse_actions_per_video = []
    predicted_coarse_steps_per_video = []
    predicted_coarse_actions_per_video = []

    num_frames_per_video = []
    fine_label_files = set(os.listdir(fine_labels_path))
    coarse_label_files = set(os.listdir(coarse_labels_path))
    label_files = sorted(fine_label_files & coarse_label_files)
    for label_file in label_files:
        with open(os.path.join(fine_labels_path, label_file), mode='r') as f:
            fine_actions_per_frame = [line.rstrip() for line in f]
        with open(os.path.join(coarse_labels_path, label_file), mode='r') as f:
            coarse_actions_per_frame = [line.rstrip() for line in f]
        if ignore_silence_action is not None:
            fine_actions_per_frame = [
                fine_action for fine_action in fine_actions_per_frame
                if fine_action != ignore_silence_action
            ]
            coarse_actions_per_frame = [
                coarse_action for coarse_action in coarse_actions_per_frame
                if coarse_action != ignore_silence_action
            ]
        fine_actions_per_frame, coarse_actions_per_frame = \
            extend_smallest_list(fine_actions_per_frame, coarse_actions_per_frame)
        num_frames = len(fine_actions_per_frame)
        num_frames_per_video.append(num_frames)
        num_frames_to_grab = round(num_frames * fraction_observed)
        observed_fine_actions = fine_actions_per_frame[:num_frames_to_grab]
        observed_fine_actions_per_video.append(observed_fine_actions)
        unobserved_fine_actions = fine_actions_per_frame[num_frames_to_grab:]
        observed_coarse_actions = coarse_actions_per_frame[:num_frames_to_grab]
        observed_coarse_actions_per_video.append(observed_coarse_actions)
        unobserved_coarse_actions = coarse_actions_per_frame[num_frames_to_grab:]

        tensors, steps, last_action_obs_length = generate_test_datum(
            observed_fine_actions,
            observed_coarse_actions,
            seq_len=seq_len,
            fine_action_to_id=fine_action_to_id,
            coarse_action_to_id=coarse_action_to_id,
            num_frames=num_frames)
        tensors = [nan_to_value(tensor, value=0.0) for tensor in tensors]
        tensors = numpy_to_torch(*tensors, device=device)
        steps = torch.tensor([steps], device=device)
        predictions = predict_future_actions(
            model,
            tensors,
            effective_steps=steps,
            fine_id_to_action=fine_id_to_action,
            coarse_id_to_action=coarse_id_to_action,
            num_frames=num_frames,
            maximum_prediction_length=len(unobserved_fine_actions),
            baseline_type=baseline_type,
            action_level=action_level,
            last_action_obs_length=last_action_obs_length)
        predicted_actions, predicted_steps = predictions
        predicted_fine_actions, predicted_coarse_actions = predicted_actions
        predicted_fine_steps, predicted_coarse_steps = predicted_steps
        predicted_fine_steps_per_video.append(predicted_fine_steps)
        predicted_coarse_steps_per_video.append(predicted_coarse_steps)
        _update_level_predictions(predicted_fine_actions,
                                  predicted_fine_actions_per_video,
                                  unobserved_fine_actions,
                                  unobserved_fine_actions_per_video)
        _update_level_predictions(predicted_coarse_actions,
                                  predicted_coarse_actions_per_video,
                                  unobserved_coarse_actions,
                                  unobserved_coarse_actions_per_video)
    # Performance and Error Analysis
    f1_results_dict = {}
    moc_results_dict = {}
    unobserved_fractions = [0.1, 0.2, 0.3, 0.5, 0.7, 0.8]
    unobserved_fractions = [
        unobserved_fraction for unobserved_fraction in unobserved_fractions
        if fraction_observed + unobserved_fraction <= 1.0
    ]
    for unobserved_fraction in unobserved_fractions:
        save_analysis_path = os.path.join(
            args.checkpoint[:-4],
            str(fraction_observed) + '_' + str(unobserved_fraction))
        if os.path.exists(save_analysis_path):
            clean_directory(save_analysis_path)
        predicted_fine_actions_per_video_sub, unobserved_fine_actions_per_video_sub = [], []
        predicted_coarse_actions_per_video_sub, unobserved_coarse_actions_per_video_sub = [], []
        f1_per_video_fine = []  # file_name, input-level_0.5_f1
        f1_per_video_coarse = []
        sequence_metrics_per_video_fine = []  # file_name, precision, recall, f1 (regardless of length/class)
        sequence_metrics_per_video_coarse = []
        num_videos = len(predicted_fine_actions_per_video)
        for i in range(num_videos):
            predicted_fine_actions = predicted_fine_actions_per_video[i]
            unobserved_fine_actions = unobserved_fine_actions_per_video[i]
            num_frames_to_grab = num_frames_per_video[i] * unobserved_fraction
            num_frames_to_grab = round(num_frames_to_grab)
            predicted_fine_actions_sub = predicted_fine_actions[:num_frames_to_grab]
            predicted_fine_actions_per_video_sub.append(predicted_fine_actions_sub)
            unobserved_fine_actions_sub = unobserved_fine_actions[:num_frames_to_grab]
            unobserved_fine_actions_per_video_sub.append(unobserved_fine_actions_sub)
            predicted_coarse_actions = predicted_coarse_actions_per_video[i]
            unobserved_coarse_actions = unobserved_coarse_actions_per_video[i]
            predicted_coarse_actions_sub = predicted_coarse_actions[:num_frames_to_grab]
            predicted_coarse_actions_per_video_sub.append(predicted_coarse_actions_sub)
            unobserved_coarse_actions_sub = unobserved_coarse_actions[:num_frames_to_grab]
            unobserved_coarse_actions_per_video_sub.append(unobserved_coarse_actions_sub)
            if do_error_analysis:
                if baseline_type == 0:
                    if action_level == 'coarse':
                        observed_actions = observed_coarse_actions_per_video[i]
                        predicted_steps = predicted_coarse_steps_per_video[i]
                        unobserved_actions = unobserved_coarse_actions_sub.tolist()
                        predicted_actions_sub = predicted_coarse_actions_sub
                        unobserved_actions_sub = unobserved_coarse_actions_sub
                        action_to_id = coarse_action_to_id
                    else:
                        observed_actions = observed_fine_actions_per_video[i]
                        predicted_steps = predicted_fine_steps_per_video[i]
                        unobserved_actions = unobserved_fine_actions_sub.tolist()
                        predicted_actions_sub = predicted_fine_actions_sub
                        unobserved_actions_sub = unobserved_fine_actions_sub
                        action_to_id = fine_action_to_id
                    steps_to_grab = compute_steps_to_grab(
                        predicted_steps, num_frames_to_grab)
                    predicted_steps = predicted_steps[:steps_to_grab]
                    analyse_single_level_observations_and_predictions_per_step(
                        predicted_steps,
                        observed_actions,
                        unobserved_actions,
                        num_frames=num_frames_per_video[i],
                        save_path=save_analysis_path,
                        save_file_name=label_files[i])
                    _, f1_scores = compute_metrics([predicted_actions_sub],
                                                   [unobserved_actions_sub],
                                                   action_to_id=action_to_id)
                    if action_level == 'coarse':
                        f1_per_video_coarse.append(
                            [label_files[i], f1_scores[-1]])
                    else:
                        f1_per_video_fine.append(
                            [label_files[i], f1_scores[-1]])
                    precision, recall, f1 = \
                        action_sequence_metrics(aggregate_actions_and_lengths(predicted_actions_sub.tolist())[0],
                                                aggregate_actions_and_lengths(unobserved_actions_sub.tolist())[0])
                    if action_level == 'coarse':
                        sequence_metrics_per_video_coarse.append(
                            [label_files[i], precision, recall, f1])
                    else:
                        sequence_metrics_per_video_fine.append(
                            [label_files[i], precision, recall, f1])
                else:
                    observed_actions = [
                        coarse_action + '/' + fine_action
                        for coarse_action, fine_action in zip(
                            observed_coarse_actions_per_video[i],
                            observed_fine_actions_per_video[i])
                    ]
                    predicted_fine_steps = predicted_fine_steps_per_video[i]
                    steps_to_grab = compute_steps_to_grab(
                        predicted_fine_steps, num_frames_to_grab)
                    predicted_fine_steps = predicted_fine_steps[:steps_to_grab]
                    predicted_coarse_steps = predicted_coarse_steps_per_video[i][:steps_to_grab]
                    predicted_steps = [
                        (coarse_step[0] + '/' + fine_step[0], fine_step[1])
                        for coarse_step, fine_step in zip(
                            predicted_coarse_steps, predicted_fine_steps)
                    ]
                    unobserved_actions = [
                        coarse_action + '/' + fine_action
                        for coarse_action, fine_action in zip(
                            unobserved_coarse_actions_sub.tolist(),
                            unobserved_fine_actions_sub.tolist())
                    ]
                    analyse_single_level_observations_and_predictions_per_step(
                        predicted_steps,
                        observed_actions,
                        unobserved_actions,
                        num_frames=num_frames_per_video[i],
                        save_path=save_analysis_path,
                        save_file_name=label_files[i])
                    _, f1_scores_fine = compute_metrics(
                        [predicted_fine_actions_sub],
                        [unobserved_fine_actions_sub],
                        action_to_id=fine_action_to_id)
                    f1_per_video_fine.append(
                        [label_files[i], f1_scores_fine[-1]])
                    _, f1_scores_coarse = compute_metrics(
                        [predicted_coarse_actions_sub],
                        [unobserved_coarse_actions_sub],
                        action_to_id=coarse_action_to_id)
                    f1_per_video_coarse.append(
                        [label_files[i], f1_scores_coarse[-1]])
                    precision, recall, f1 = \
                        action_sequence_metrics(aggregate_actions_and_lengths(predicted_fine_actions_sub.tolist())[0],
                                                aggregate_actions_and_lengths(unobserved_fine_actions_sub.tolist())[0])
                    sequence_metrics_per_video_fine.append(
                        [label_files[i], precision, recall, f1])
                    precision, recall, f1 = \
                        action_sequence_metrics(aggregate_actions_and_lengths(predicted_coarse_actions_sub.tolist())[0],
                                                aggregate_actions_and_lengths(unobserved_coarse_actions_sub.tolist())[0])
                    sequence_metrics_per_video_coarse.append(
                        [label_files[i], precision, recall, f1])
        if do_error_analysis:
            if f1_per_video_fine:
                write_results_per_video(f1_per_video_fine,
                                        order_by=None,
                                        metric_name='f1-0.5-fine',
                                        save_path=save_analysis_path)
                write_sequence_results_per_video(
                    sequence_metrics_per_video_fine,
                    save_analysis_path,
                    level='fine')
            if f1_per_video_coarse:
                write_results_per_video(f1_per_video_coarse,
                                        order_by=None,
                                        metric_name='f1-0.5-coarse',
                                        save_path=save_analysis_path)
                write_sequence_results_per_video(
                    sequence_metrics_per_video_coarse,
                    save_analysis_path,
                    level='coarse')
        print('\nObserved fraction: %.2f | Unobserved fraction: %.2f' %
              (fraction_observed, unobserved_fraction))
        if baseline_type == 0 and action_level == 'coarse':
            predicted_actions_per_video_sub = predicted_coarse_actions_per_video_sub
            unobserved_actions_per_video_sub = unobserved_coarse_actions_per_video_sub
            action_to_id = coarse_action_to_id
            print('Coarse')
        else:
            predicted_actions_per_video_sub = predicted_fine_actions_per_video_sub
            unobserved_actions_per_video_sub = unobserved_fine_actions_per_video_sub
            action_to_id = fine_action_to_id
            print('Fine')
        moc, _, _ = compute_moc(
            np.concatenate(predicted_actions_per_video_sub),
            np.concatenate(unobserved_actions_per_video_sub),
            action_to_id=action_to_id)
        if baseline_type == 0 and action_level == 'coarse':
            moc_results_dict[
                f'coarse-moc-{fraction_observed}_{unobserved_fraction}'] = moc
        else:
            moc_results_dict[
                f'fine-moc-{fraction_observed}_{unobserved_fraction}'] = moc
        overlaps, f1_overlap_scores = compute_metrics(
            predicted_actions_per_video_sub,
            unobserved_actions_per_video_sub,
            action_to_id=action_to_id)
        for overlap, overlap_f1_score in zip(overlaps, f1_overlap_scores):
            print('F1@%.2f: %.4f' % (overlap, overlap_f1_score))
            if baseline_type == 0 and action_level == 'coarse':
                f1_results_dict[
                    f'coarse-{fraction_observed}_{unobserved_fraction}_{overlap}'] = overlap_f1_score
            else:
                f1_results_dict[
                    f'fine-{fraction_observed}_{unobserved_fraction}_{overlap}'] = overlap_f1_score
        if baseline_type > 0:
            moc, _, _ = compute_moc(
                np.concatenate(predicted_coarse_actions_per_video_sub),
                np.concatenate(unobserved_coarse_actions_per_video_sub),
                action_to_id=coarse_action_to_id)
            moc_results_dict[
                f'coarse-moc-{fraction_observed}_{unobserved_fraction}'] = moc
            overlaps, f1_overlap_scores = \
                compute_metrics(predicted_coarse_actions_per_video_sub, unobserved_coarse_actions_per_video_sub,
                                action_to_id=coarse_action_to_id)
            for overlap, overlap_f1_score in zip(overlaps, f1_overlap_scores):
                f1_results_dict[
                    f'coarse-{fraction_observed}_{unobserved_fraction}_{overlap}'] = overlap_f1_score
                if print_coarse_results:
                    print('Coarse')
                    print('F1@%.2f: %.4f' % (overlap, overlap_f1_score))
    results_dict = {**f1_results_dict, **moc_results_dict}
    return results_dict
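# Hedged sketch of compute_steps_to_grab as used above (an assumption, not
# necessarily the actual implementation): count how many predicted
# (action, length) steps are needed to cover num_frames_to_grab frames,
# treating placeholder (None, None) steps as covering zero frames.
def compute_steps_to_grab(predicted_steps, num_frames_to_grab):
    frames_covered, steps_to_grab = 0, 0
    for _, step_length in predicted_steps:
        steps_to_grab += 1
        if step_length is not None:
            frames_covered += step_length
        if frames_covered >= num_frames_to_grab:
            break
    return steps_to_grab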
def predict_future_actions(model, input_tensors, fine_id_to_action, coarse_id_to_action, disable_parent_input,
                           num_frames, maximum_prediction_length, observed_fine_actions, observed_coarse_actions,
                           fine_action_to_id, coarse_action_to_id, scalers=None):
    x_enc_fine, x_enc_coarse, dx_enc, dx_enc_layer_zero, x_tra_fine, x_tra_coarse = input_tensors
    dx = [dx_enc, dx_enc_layer_zero]
    with torch.no_grad():
        _, hx = model.encoder_net(x_enc_fine, x_enc_coarse, dx=dx, hx=None)
        hx_tra = [hl[0] for hl in hx] if isinstance(model.encoder_net.encoder_hmgru, HMLSTM) else hx
        (_, y_tra_coarse_rem_prop), _ = model.transition_net(x_tra_fine, x_tra_coarse, hx=hx_tra)
    coarse_la_id = torch.argmax(x_tra_coarse[..., :-2], dim=-1).item()
    coarse_la_label = coarse_id_to_action[coarse_la_id]
    y_tra_coarse_rem_prop = maybe_denormalise(y_tra_coarse_rem_prop.cpu().numpy(),
                                              scaler=scalers.get('y_tra_coarse_scaler'))
    coarse_la_rem_len = round(y_tra_coarse_rem_prop.item() * num_frames)
    predicted_coarse_actions = [coarse_la_label] * coarse_la_rem_len
    predicted_coarse_steps = [(coarse_la_label, coarse_la_rem_len)]

    # Generate input tensors again.
    new_observed_coarse_actions = observed_coarse_actions + predicted_coarse_actions
    input_seq_len = x_enc_fine.size(1)
    input_tensors = generate_test_datum(observed_fine_actions, new_observed_coarse_actions, input_seq_len=input_seq_len,
                                        fine_action_to_id=fine_action_to_id, coarse_action_to_id=coarse_action_to_id,
                                        disable_parent_input=disable_parent_input,
                                        num_frames=num_frames, scalers=scalers, coarse_is_complete=True)
    input_tensors = [nan_to_value(tensor, value=0.0) for tensor in input_tensors]
    input_tensors = numpy_to_torch(*input_tensors, device=x_enc_fine.device)
    x_enc_fine, x_enc_coarse, dx_enc, dx_enc_layer_zero, x_tra_fine, x_tra_coarse = input_tensors
    dx = [dx_enc, dx_enc_layer_zero]
    with torch.no_grad():
        _, hx, hxs = model.encoder_net(x_enc_fine, x_enc_coarse, dx=dx, hx=None, return_all_hidden_states=True)
        hx_tra = [hl[0] for hl in hx] if isinstance(model.encoder_net.encoder_hmgru, HMLSTM) else hx
        (y_tra_fine_rem_rel_prop, _), hx_tra = model.transition_net(x_tra_fine, x_tra_coarse, hx=hx_tra)
        # Models without the disable_transition_layer attribute are treated as
        # having the transition layer enabled.
        if not getattr(model, 'disable_transition_layer', False):
            if isinstance(model.encoder_net.encoder_hmgru, HMLSTM):
                for i, hl in enumerate(hx_tra):
                    hx[i][0] = hl
            else:
                hx = hx_tra
            hxs[0] = torch.cat([hxs[0], hx_tra[0].unsqueeze(1)], dim=1)
            hxs[1] = torch.cat([hxs[1], hx_tra[1].unsqueeze(1)], dim=1)

    num_coarse_actions = len(coarse_action_to_id)
    if disable_parent_input:
        fine_la_id = torch.argmax(x_tra_fine[..., :-2], dim=-1).item()
    else:
        fine_la_id = torch.argmax(x_tra_fine[..., num_coarse_actions:-2], dim=-1).item()
    fine_la_label = fine_id_to_action[fine_la_id]
    y_tra_fine_rem_rel_prop = maybe_denormalise(y_tra_fine_rem_rel_prop.cpu().numpy(),
                                                scaler=scalers.get('y_tra_fine_scaler'))
    coarse_tra_len_prop = x_tra_coarse[..., -1].item() + y_tra_coarse_rem_prop.item()
    fine_la_rem_len = round(y_tra_fine_rem_rel_prop.item() * coarse_tra_len_prop * num_frames)
    predicted_fine_actions = [fine_la_label] * fine_la_rem_len
    predicted_fine_steps = [(fine_la_label, fine_la_rem_len)]
    # Decoder
    dtype, device = x_enc_fine.dtype, x_enc_fine.device
    x_dec_cat_coarse = x_tra_coarse[..., :-2]
    x_dec_num_coarse = x_tra_coarse[..., -2:-1] + torch.tensor(y_tra_coarse_rem_prop, dtype=dtype, device=device)
    x_dec_coarse = torch.cat([x_dec_cat_coarse, x_dec_num_coarse], dim=-1)

    x_dec_cat_fine = x_tra_fine[..., :-2]
    acc_fine_proportion = x_tra_fine[..., -2:-1] + torch.tensor(y_tra_fine_rem_rel_prop, dtype=dtype, device=device)
    acc_fine_proportion = acc_fine_proportion.item()
    x_dec_num_fine = torch.tensor([[acc_fine_proportion]], dtype=dtype, device=device)
    x_dec_fine = torch.cat([x_dec_cat_fine, x_dec_num_fine], dim=-1)

    coarse_la_obs_prop = maybe_denormalise(x_tra_coarse[..., -1:].cpu().numpy(),
                                           scaler=scalers.get('x_tra_coarse_scaler'))
    coarse_la_prop = coarse_la_obs_prop.item() + y_tra_coarse_rem_prop.item()
    d_fine, d_fines = 0.0, []
    decoder_net, output_seq_len = model.decoder_net, model.decoder_net.output_seq_len
    coarse_exceed_first_time, total_coarse_length = True, 0
    with torch.no_grad():
        for t in range(output_seq_len):
            # Predict
            if model.model_v2:
                x_dec_fine_ = x_dec_fine[0]
                hx_ = [hx[0][0], hx[1][0]]
                y_dec_fine_logits, y_dec_fine_rel_prop, hx_fine = \
                    decoder_net.single_step_fine(x_dec_fine_, d_fine, hx_)
                hx[0] = hx_fine.unsqueeze(0)
            else:
                y_dec_fine_logits, y_dec_fine_rel_prop, hx[0] = \
                    decoder_net.single_step_fine(x_dec_fine, d_fine, hx)
            # Process Prediction
            fine_na_label, _ = next_action_info(y_dec_fine_logits, y_dec_fine_rel_prop, fine_id_to_action, num_frames)
            if acc_fine_proportion >= 1.0 or fine_na_label is None:
                acc_fine_proportion, d_fine = 0.0, 1.0
                if model.model_v2:
                    x_dec_coarse_ = x_dec_coarse[0]
                    hx_ = [hx[0][0], hx[1][0]]
                    y_dec_coarse_logits, y_dec_coarse_prop, hx_coarse = \
                        decoder_net.single_step_coarse(x_dec_coarse_, hx_)
                    hx[1] = hx_coarse.unsqueeze(0)
                else:
                    y_dec_coarse_logits, y_dec_coarse_prop, hx[1] = \
                        decoder_net.single_step_coarse(x_dec_coarse, d_fine, hx)
                y_dec_coarse_prop = maybe_denormalise(y_dec_coarse_prop.cpu().numpy(),
                                                      scaler=scalers.get('y_dec_coarse_scaler'))
                coarse_na_label, coarse_na_len = next_action_info(y_dec_coarse_logits, y_dec_coarse_prop,
                                                                  coarse_id_to_action, num_frames)
                if coarse_na_label is None:
                    break
                predicted_coarse_actions += [coarse_na_label] * coarse_na_len
                predicted_coarse_steps.append((coarse_na_label, coarse_na_len))
                coarse_la_prop = y_dec_coarse_prop.item()
                x_dec_cat_coarse = logit2one_hot(y_dec_coarse_logits)
                if model.with_final_action:
                    x_dec_cat_coarse = x_dec_cat_coarse[..., :-1]
                x_dec_coarse[..., :-1] = x_dec_cat_coarse
                x_dec_coarse[..., -1] += coarse_la_prop
                predicted_fine_steps.append((None, None))
                if model.model_v3 and decoder_net.input_soft_parent:  # Prepare x_dec_cat_coarse for fine steps
                    if model.with_final_action:
                        x_dec_cat_coarse = torch.softmax(y_dec_coarse_logits[..., :-1], dim=-1)
                    else:
                        x_dec_cat_coarse = torch.softmax(y_dec_coarse_logits, dim=-1)
            else:
                y_dec_fine_rel_prop = maybe_denormalise(y_dec_fine_rel_prop.cpu().numpy(),
                                                        scaler=scalers.get('y_dec_fine_scaler'))
                excess = 0.0
                fine_na_label, fine_na_len = next_action_info(y_dec_fine_logits, y_dec_fine_rel_prop - excess,
                                                              fine_id_to_action, num_frames,
                                                              parent_la_prop=coarse_la_prop)
                predicted_fine_actions += [fine_na_label] * fine_na_len
                predicted_fine_steps.append((fine_na_label, fine_na_len))
                acc_fine_proportion += y_dec_fine_rel_prop.item()
                predicted_coarse_steps.append((None, None))
                d_fine = 0.0
            # Post-process
            d_fines.append(d_fine)
            if isinstance(model.encoder_net.encoder_hmgru, HMLSTM):
                hxs[0] = torch.cat([hxs[0], hx[0][0].unsqueeze(1)], dim=1)
                hxs[1] = torch.cat([hxs[1], hx[1][0].unsqueeze(1)], dim=1)
            else:
                hxs[0] = torch.cat([hxs[0], hx[0].unsqueeze(1)], dim=1)
                hxs[1] = torch.cat([hxs[1], hx[1].unsqueeze(1)], dim=1)
            x_dec_cat_fine = logit2one_hot(y_dec_fine_logits)
            if model.with_final_action:
                x_dec_cat_fine = x_dec_cat_fine[..., :-1]
            x_dec_cat_fine = x_dec_cat_fine * float(acc_fine_proportion > 0.0)
            if not disable_parent_input:
                x_dec_cat_fine = torch.cat([x_dec_cat_coarse, x_dec_cat_fine], dim=-1)
            x_dec_num_fine = torch.tensor([[acc_fine_proportion]], dtype=dtype, device=device)
            x_dec_fine = torch.cat([x_dec_cat_fine, x_dec_num_fine], dim=-1)
            coarse_exceed = len(predicted_coarse_actions) >= maximum_prediction_length
            fine_exceed = len(predicted_fine_actions) >= maximum_prediction_length
            if coarse_exceed and fine_exceed:
                break
            if coarse_exceed:
                if coarse_exceed_first_time:
                    coarse_exceed_first_time = False
                    total_coarse_length = len(predicted_coarse_actions)
                elif len(predicted_coarse_actions) > total_coarse_length:
                    predicted_coarse_steps = predicted_coarse_steps[:-1]
                    predicted_fine_steps = predicted_fine_steps[:-1]
                    break
    if model.with_final_action:
        fine_steps = [(None, None)] + predicted_fine_steps
        coarse_steps = predicted_coarse_steps[:1] + [(None, None)] + predicted_coarse_steps[1:]
        coarse_steps = maybe_rebalance_steps(coarse_steps, maximum_prediction_length)
        predicted_fine_steps, predicted_coarse_steps = fix_steps(fine_steps, coarse_steps)
        predicted_fine_actions = actions_from_steps(predicted_fine_steps)
        predicted_coarse_actions = actions_from_steps(predicted_coarse_steps)
    predicted_actions = predicted_fine_actions, predicted_coarse_actions
    predicted_steps = predicted_fine_steps, predicted_coarse_steps
    return predicted_actions, predicted_steps, d_fines
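# Hedged sketch of actions_from_steps as used in predict_future_actions
# (an assumption): expand (label, length) steps back into a per-frame action
# list, ignoring placeholder (None, None) entries.
def actions_from_steps(steps):
    actions = []
    for label, length in steps:
        if label is not None and length is not None:
            actions += [label] * length
    return actions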
def test_hera(args):
    checkpoint = torch.load(args.checkpoint)
    fine_labels_path = args.fine_labels_path
    coarse_labels_path = args.coarse_labels_path
    fine_action_to_id = read_action_dictionary(args.fine_action_to_id)
    fine_id_to_action = {action_id: action for action, action_id in fine_action_to_id.items()}
    coarse_action_to_id = read_action_dictionary(args.coarse_action_to_id)
    coarse_id_to_action = {action_id: action for action, action_id in coarse_action_to_id.items()}
    fraction_observed = args.observed_fraction
    ignore_silence_action = args.ignore_silence_action
    do_error_analysis = args.do_error_analysis
    do_future_performance_analysis = args.do_future_performance_analysis
    do_flush_analysis = args.do_flush_analysis
    input_seq_len = checkpoint['input_seq_len']
    scalers = checkpoint.get('scalers', None)
    disable_parent_input = checkpoint['disable_parent_input']
    # Load model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = HERA(**checkpoint['model_creation_args']).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    observed_fine_actions_per_video, observed_coarse_actions_per_video = [], []
    fine_transition_action_per_video, coarse_transition_action_per_video = [], []
    flushes_per_video, ground_truth_flushes_per_video = [], []
    predicted_fine_actions_per_video, predicted_coarse_actions_per_video = [], []
    predicted_fine_steps_per_video, predicted_coarse_steps_per_video = [], []
    unobserved_fine_actions_per_video, unobserved_coarse_actions_per_video = [], []
    fine_label_files = set(os.listdir(fine_labels_path))
    coarse_label_files = set(os.listdir(coarse_labels_path))
    label_files = sorted(fine_label_files & coarse_label_files)
    for label_file in label_files:
        with open(os.path.join(fine_labels_path, label_file), mode='r') as f:
            fine_actions_per_frame = [line.rstrip() for line in f]
        with open(os.path.join(coarse_labels_path, label_file), mode='r') as f:
            coarse_actions_per_frame = [line.rstrip() for line in f]
        if ignore_silence_action is not None:
            fine_actions_per_frame = [fine_action for fine_action in fine_actions_per_frame
                                      if fine_action != ignore_silence_action]
            coarse_actions_per_frame = [coarse_action for coarse_action in coarse_actions_per_frame
                                        if coarse_action != ignore_silence_action]
        fine_actions_per_frame, coarse_actions_per_frame = \
            extend_smallest_list(fine_actions_per_frame, coarse_actions_per_frame)
        observed_fine_actions, unobserved_fine_actions = split_observed_actions(fine_actions_per_frame,
                                                                                fraction_observed=fraction_observed)
        observed_fine_actions_per_video.append(observed_fine_actions)
        fine_transition_action_per_video.append(observed_fine_actions[-1])
        observed_coarse_actions, unobserved_coarse_actions = split_observed_actions(coarse_actions_per_frame,
                                                                                    fraction_observed=fraction_observed)
        observed_coarse_actions_per_video.append(observed_coarse_actions)
        coarse_transition_action_per_video.append(observed_coarse_actions[-1])
        tensors = generate_test_datum(observed_fine_actions, observed_coarse_actions, input_seq_len=input_seq_len,
                                      fine_action_to_id=fine_action_to_id, coarse_action_to_id=coarse_action_to_id,
                                      disable_parent_input=disable_parent_input,
                                      num_frames=len(fine_actions_per_frame), scalers=scalers, coarse_is_complete=False)
        tensors = [nan_to_value(tensor, value=0.0) for tensor in tensors]
        tensors = numpy_to_torch(*tensors, device=device)
        predicted_actions, predicted_steps, dx_dec_fine = \
            predict_future_actions(model, tensors, fine_id_to_action=fine_id_to_action,
                                   coarse_id_to_action=coarse_id_to_action,
                                   disable_parent_input=disable_parent_input,
                                   num_frames=len(fine_actions_per_frame),
                                   maximum_prediction_length=len(unobserved_fine_actions),
                                   observed_fine_actions=observed_fine_actions,
                                   observed_coarse_actions=observed_coarse_actions,
                                   fine_action_to_id=fine_action_to_id, coarse_action_to_id=coarse_action_to_id,
                                   scalers=scalers)
        flushes_per_video.append(dx_dec_fine)
        ground_truth_flushes = compute_ground_truth_flushes(observed_coarse_actions[-1], observed_fine_actions[-1],
                                                            unobserved_coarse_actions, unobserved_fine_actions)
        ground_truth_flushes_per_video.append(ground_truth_flushes)
        predicted_fine_steps, predicted_coarse_steps = predicted_steps
        predicted_fine_steps_per_video.append(predicted_fine_steps)
        predicted_coarse_steps_per_video.append(predicted_coarse_steps)
        predicted_fine_actions, predicted_coarse_actions = predicted_actions
        if not predicted_fine_actions:
            predicted_fine_actions = ['FAILED_TO_PREDICT']
        predicted_fine_actions = extend_or_trim_predicted_actions(predicted_fine_actions, unobserved_fine_actions)
        predicted_fine_actions = np.array(predicted_fine_actions)
        predicted_fine_actions_per_video.append(predicted_fine_actions)
        unobserved_fine_actions = np.array(unobserved_fine_actions)
        unobserved_fine_actions_per_video.append(unobserved_fine_actions)
        if not predicted_coarse_actions:
            predicted_coarse_actions = ['FAILED_TO_PREDICT']
        predicted_coarse_actions = extend_or_trim_predicted_actions(predicted_coarse_actions, unobserved_coarse_actions)
        predicted_coarse_actions = np.array(predicted_coarse_actions)
        predicted_coarse_actions_per_video.append(predicted_coarse_actions)
        unobserved_coarse_actions = np.array(unobserved_coarse_actions)
        unobserved_coarse_actions_per_video.append(unobserved_coarse_actions)
    # Performance and Error Analysis
    f1_results_dict = {}
    moc_results_dict = {}
    unobserved_fractions = [0.1, 0.2, 0.3, 0.5, 0.7, 0.8]
    unobserved_fractions = [unobserved_fraction for unobserved_fraction in unobserved_fractions
                            if fraction_observed + unobserved_fraction <= 1.0]
    for unobserved_fraction in unobserved_fractions:
        save_analysis_path = os.path.join(args.checkpoint[:-4], str(fraction_observed) + '_' + str(unobserved_fraction))
        global_fraction_unobserved = 1.0 - fraction_observed
        predicted_fine_actions_per_video_sub, unobserved_fine_actions_per_video_sub = [], []
        predicted_coarse_actions_per_video_sub, unobserved_coarse_actions_per_video_sub = [], []
        f1_per_video = []  # file_name, coarse_0.5_f1, fine_0.5_f1
        for i, (predicted_fine_actions, unobserved_fine_actions, predicted_coarse_actions, unobserved_coarse_actions) in \
                enumerate(zip(predicted_fine_actions_per_video, unobserved_fine_actions_per_video,
                              predicted_coarse_actions_per_video, unobserved_coarse_actions_per_video)):
            num_frames_to_grab = (len(unobserved_fine_actions) / global_fraction_unobserved) * unobserved_fraction
            num_frames_to_grab = round(num_frames_to_grab)
            predicted_fine_actions_sub = predicted_fine_actions[:num_frames_to_grab]
            predicted_fine_actions_per_video_sub.append(predicted_fine_actions_sub)
            unobserved_fine_actions_sub = unobserved_fine_actions[:num_frames_to_grab]
            unobserved_fine_actions_per_video_sub.append(unobserved_fine_actions_sub)
            predicted_coarse_actions_sub = predicted_coarse_actions[:num_frames_to_grab]
            predicted_coarse_actions_per_video_sub.append(predicted_coarse_actions_sub)
            unobserved_coarse_actions_sub = unobserved_coarse_actions[:num_frames_to_grab]
            unobserved_coarse_actions_per_video_sub.append(unobserved_coarse_actions_sub)
            if do_error_analysis:
                predicted_fine_steps = predicted_fine_steps_per_video[i]
                steps_to_grab = compute_steps_to_grab(predicted_fine_steps, num_frames_to_grab)
                predicted_fine_steps = predicted_fine_steps[:steps_to_grab]
                predicted_coarse_steps = predicted_coarse_steps_per_video[i][:steps_to_grab]
                coarse_actions_per_frame = (observed_coarse_actions_per_video[i] +
                                            unobserved_coarse_actions_per_video[i].tolist())
                analyse_hierarchical_observations_and_predictions(predicted_fine_steps,
                                                                  predicted_coarse_steps,
                                                                  observed_fine_actions_per_video[i],
                                                                  observed_coarse_actions_per_video[i],
                                                                  unobserved_fine_actions_sub,
                                                                  unobserved_coarse_actions_sub,
                                                                  coarse_actions_per_frame_full=coarse_actions_per_frame,
                                                                  save_path=save_analysis_path,
                                                                  save_file_name=label_files[i])
                _, f1_fine_scores = compute_metrics([predicted_fine_actions_sub],
                                                    [unobserved_fine_actions_sub],
                                                    action_to_id=fine_action_to_id)
                _, f1_coarse_scores = compute_metrics([predicted_coarse_actions_sub],
                                                      [unobserved_coarse_actions_sub],
                                                      action_to_id=coarse_action_to_id)
                f1_per_video.append([label_files[i], f1_coarse_scores[-1], f1_fine_scores[-1]])
        if do_error_analysis:
            write_results_per_video(f1_per_video, order_by='coarse', metric_name='f1-0.5', save_path=save_analysis_path)
            write_results_per_video(f1_per_video, order_by='fine', metric_name='f1-0.5', save_path=save_analysis_path)
        if do_future_performance_analysis:
            analyse_performance_per_future_action(predicted_coarse_actions_per_video_sub,
                                                  unobserved_coarse_actions_per_video_sub,
                                                  transition_action_per_video=coarse_transition_action_per_video,
                                                  save_path=save_analysis_path, extra_str='Coarse')
            analyse_performance_per_future_action(predicted_fine_actions_per_video_sub,
                                                  unobserved_fine_actions_per_video_sub,
                                                  transition_action_per_video=fine_transition_action_per_video,
                                                  save_path=save_analysis_path, mode='a', extra_str='Fine')
        print('\nObserved fraction: %.2f | Unobserved fraction: %.2f' % (fraction_observed, unobserved_fraction))
        print('-> Fine')
        overlaps, f1_overlap_scores = compute_metrics(predicted_fine_actions_per_video_sub,
                                                      unobserved_fine_actions_per_video_sub,
                                                      action_to_id=fine_action_to_id)
        for overlap, overlap_f1_score in zip(overlaps, f1_overlap_scores):
            print('F1@%.2f: %.4f' % (overlap, overlap_f1_score))
            f1_results_dict[f'fine-{fraction_observed}_{unobserved_fraction}_{overlap}'] = overlap_f1_score
        fine_moc, _, _ = compute_moc(np.concatenate(predicted_fine_actions_per_video_sub),
                                     np.concatenate(unobserved_fine_actions_per_video_sub),
                                     action_to_id=fine_action_to_id)
        print(f'MoC: {fine_moc:.4f}')
        moc_results_dict[f'fine-moc-{fraction_observed}_{unobserved_fraction}'] = fine_moc
        print('-> Coarse')
        overlaps, f1_overlap_scores = compute_metrics(predicted_coarse_actions_per_video_sub,
                                                      unobserved_coarse_actions_per_video_sub,
                                                      action_to_id=coarse_action_to_id)
        for overlap, overlap_f1_score in zip(overlaps, f1_overlap_scores):
            print('F1@%.2f: %.4f' % (overlap, overlap_f1_score))
            f1_results_dict[f'coarse-{fraction_observed}_{unobserved_fraction}_{overlap}'] = overlap_f1_score
        coarse_moc, _, _ = compute_moc(np.concatenate(predicted_coarse_actions_per_video_sub),
                                       np.concatenate(unobserved_coarse_actions_per_video_sub),
                                       action_to_id=coarse_action_to_id)
        print(f'MoC: {coarse_moc:.4f}')
        moc_results_dict[f'coarse-moc-{fraction_observed}_{unobserved_fraction}'] = coarse_moc
    if do_flush_analysis:
        analyse_flushes_hierarchical(flushes_per_video, ground_truth_flushes_per_video,
                                     label_files, model.decoder_net.output_seq_len,
                                     save_path=args.checkpoint[:-4], encoder=False)
    results_dict = {**f1_results_dict, **moc_results_dict}
    return results_dict
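# Hedged sketches of two helpers used in test_hera (assumptions; the actual
# helpers may round or pad differently).
def split_observed_actions(actions_per_frame, fraction_observed):
    # Split per-frame actions into an observed prefix and an unobserved suffix.
    num_observed = round(len(actions_per_frame) * fraction_observed)
    return actions_per_frame[:num_observed], actions_per_frame[num_observed:]
def extend_or_trim_predicted_actions(predicted_actions, unobserved_actions):
    # Make the prediction exactly as long as the ground truth by repeating the
    # last predicted action or trimming the excess frames.
    target_length = len(unobserved_actions)
    if len(predicted_actions) < target_length:
        num_missing = target_length - len(predicted_actions)
        predicted_actions = list(predicted_actions) + [predicted_actions[-1]] * num_missing
    return predicted_actions[:target_length]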