Пример #1
0
    def forward(self,
                x_enc_fine,
                x_enc_coarse,
                hx=None,
                teacher_prob=None,
                test_mode=False,
                effective_num_steps=None):
        """Encode the observed sequence and predict fine/coarse actions.

        The coarse input (with its last feature column dropped) and the fine input
        are concatenated along the last dimension, embedded with ``self.embedding``
        and run through ``self.rnn``.

        Args:
            x_enc_fine: Fine-level input tensor. In test mode it is indexed as
                ``[example, step, feature]``; its last feature at the last valid step
                is added to the predicted fine length to seed the rollout's numeric
                input feature.
            x_enc_coarse: Coarse-level input tensor; its last feature column is
                not fed to the network.
            hx: Optional initial hidden state forwarded to ``self.rnn``.
            teacher_prob: Unused; kept for interface compatibility.
            test_mode: If False, predict from every RNN output position in a single
                pass. If True, for each example roll out ``seq_len`` steps with
                ``self.single_step``, feeding predictions back as inputs.
            effective_num_steps: Required when ``test_mode`` is True. Per-example
                number of valid input steps; the rollout for example ``i`` starts
                from the RNN output at position ``effective_num_steps[i] - 1``.

        Returns:
            Tuple ``(y_logits_fine, y_lens_fine, y_logits_coarse)``. In the non-test
            path the coarse output is ``log_softmax`` over the last dimension. In
            test mode every tensor is stacked over rollout steps on dim 1 and
            concatenated over examples on dim 0.

        Raises:
            ValueError: If ``test_mode`` is True and ``effective_num_steps`` is None.
        """
        # Fail fast (before the RNN pass) instead of a TypeError deep inside the loop.
        if test_mode and effective_num_steps is None:
            raise ValueError('effective_num_steps must be provided when test_mode=True.')

        x_enc = torch.cat([x_enc_coarse[..., :-1], x_enc_fine], dim=-1)
        x_enc = self.embedding(x_enc)
        output, _ = self.rnn(x_enc, hx=hx)
        if not test_mode:
            y_logits_coarse = F.log_softmax(self.coarse_action(output), dim=-1)
            y_logits_fine, y_lens_fine = self.fine_action(output)
            return y_logits_fine, y_lens_fine, y_logits_coarse

        def build_step_input(logits_coarse, logits_fine, lens_fine, prev_fine_num):
            """Assemble the next rollout input from one step's predictions."""
            # Predictions are detached so the fed-back input carries no gradient.
            coarse_cat = logit2one_hot(logits_coarse.detach())
            fine_cat = logit2one_hot(logits_fine.detach())
            if self.with_final_action:
                # The last class column is not part of the input encoding.
                coarse_cat = coarse_cat[..., :-1]
                fine_cat = fine_cat[..., :-1]
            # Accumulate the predicted fine length onto the running numeric feature.
            fine_num = prev_fine_num + lens_fine.detach()
            return torch.cat([coarse_cat, fine_cat, fine_num], dim=-1)

        batch_size, seq_len = x_enc.size(0), x_enc.size(1)
        y_logits_coarse, y_logits_fine, y_lens_fine = [], [], []
        for i in range(batch_size):
            steps = int(effective_num_steps[i].item())
            # The RNN output at the last valid input step seeds the rollout.
            step_hx = output[i:i + 1][:, steps - 1]
            step_logits_coarse = F.log_softmax(self.coarse_action(step_hx), dim=-1)
            step_logits_fine, step_lens_fine = self.fine_action(step_hx)
            x_enc_step = build_step_input(step_logits_coarse, step_logits_fine, step_lens_fine,
                                          x_enc_fine[i:i + 1, steps - 1, -1:])

            ex_logits_coarse, ex_logits_fine, ex_lens_fine = [], [], []
            for _ in range(seq_len):
                step_logits_coarse, step_logits_fine, step_lens_fine, step_hx = \
                    self.single_step(x_enc_step, hx_step=step_hx)
                ex_logits_coarse.append(step_logits_coarse)
                ex_logits_fine.append(step_logits_fine)
                ex_lens_fine.append(step_lens_fine)
                x_enc_step = build_step_input(step_logits_coarse, step_logits_fine, step_lens_fine,
                                              x_enc_step[..., -1:])
            y_logits_coarse.append(torch.stack(ex_logits_coarse, dim=1))
            y_logits_fine.append(torch.stack(ex_logits_fine, dim=1))
            y_lens_fine.append(torch.stack(ex_lens_fine, dim=1))

        y_logits_coarse = torch.cat(y_logits_coarse, dim=0)
        y_logits_fine = torch.cat(y_logits_fine, dim=0)
        y_lens_fine = torch.cat(y_lens_fine, dim=0)
        return y_logits_fine, y_lens_fine, y_logits_coarse
def predict_future_actions(model, input_tensors, fine_id_to_action, coarse_id_to_action, disable_parent_input,
                           num_frames, maximum_prediction_length, observed_fine_actions, observed_coarse_actions,
                           fine_action_to_id, coarse_action_to_id, scalers=None):
    """Predict future fine and coarse actions for one observed sequence.

    The procedure has three stages:

    1. Encode the observations and run the transition network to estimate the
       remaining length of the last observed coarse action.
    2. Regenerate the inputs with that coarse remainder appended
       (``coarse_is_complete=True``), re-encode, and run the transition network
       again to estimate the remaining length of the last observed fine action.
    3. Roll out the decoder one step at a time. Each step either emits a new fine
       action, or (when the accumulated fine proportion reaches 1.0 or no fine
       label is produced) emits a new coarse action.

    Args:
        model: Model exposing ``encoder_net``, ``transition_net`` and ``decoder_net``.
        input_tensors: Sequence unpacked as ``(x_enc_fine, x_enc_coarse, dx_enc,
            dx_enc_layer_zero, x_tra_fine, x_tra_coarse)``. The ``.item()`` calls
            below require a single example.
        fine_id_to_action: Maps predicted fine class ids to labels.
        coarse_id_to_action: Maps predicted coarse class ids to labels.
        disable_parent_input: If True, fine inputs do not carry the coarse
            one-hot prefix.
        num_frames: Scale used to convert predicted proportions to frame counts.
        maximum_prediction_length: Rollout limit. The loop stops once both
            per-frame prediction lists reach this length, or when the coarse list
            grows again after first reaching it.
        observed_fine_actions: Observed fine labels, used to regenerate inputs.
        observed_coarse_actions: Observed coarse labels, used to regenerate inputs.
        fine_action_to_id: Fine label-to-id map.
        coarse_action_to_id: Coarse label-to-id map; its size is the width of the
            coarse prefix in fine inputs.
        scalers: Optional mapping from scaler names to scalers passed to
            ``maybe_denormalise``. ``None`` is treated as "no scalers".

    Returns:
        ``(predicted_actions, predicted_steps, d_fines)`` where
        ``predicted_actions`` is ``(fine_actions, coarse_actions)`` (per-frame
        label lists), ``predicted_steps`` is ``(fine_steps, coarse_steps)`` (lists
        of ``(label, length)`` tuples with ``(None, None)`` at decoder steps where
        the other level advanced), and ``d_fines`` holds 1.0 for decoder steps
        that emitted a coarse action and 0.0 otherwise.
    """
    # `scalers` defaults to None but is queried with `.get` throughout; look up
    # scalers in an empty mapping instead of crashing with AttributeError.
    scaler_map = scalers if scalers is not None else {}
    is_hmlstm = isinstance(model.encoder_net.encoder_hmgru, HMLSTM)

    # Stage 1: remaining length of the last observed coarse action.
    x_enc_fine, x_enc_coarse, dx_enc, dx_enc_layer_zero, x_tra_fine, x_tra_coarse = input_tensors
    dx = [dx_enc, dx_enc_layer_zero]
    with torch.no_grad():
        _, hx = model.encoder_net(x_enc_fine, x_enc_coarse, dx=dx, hx=None)
        hx_tra = [hl[0] for hl in hx] if is_hmlstm else hx
        (_, y_tra_coarse_rem_prop), _ = model.transition_net(x_tra_fine, x_tra_coarse, hx=hx_tra)
    coarse_la_id = torch.argmax(x_tra_coarse[..., :-2], dim=-1).item()
    coarse_la_label = coarse_id_to_action[coarse_la_id]
    y_tra_coarse_rem_prop = maybe_denormalise(y_tra_coarse_rem_prop.cpu().numpy(),
                                              scaler=scaler_map.get('y_tra_coarse_scaler'))
    coarse_la_rem_len = round(y_tra_coarse_rem_prop.item() * num_frames)
    predicted_coarse_actions = [coarse_la_label] * coarse_la_rem_len
    predicted_coarse_steps = [(coarse_la_label, coarse_la_rem_len)]

    # Stage 2: regenerate inputs with the coarse remainder appended, re-encode,
    # and estimate the remaining length of the last observed fine action.
    new_observed_coarse_actions = observed_coarse_actions + predicted_coarse_actions
    input_seq_len = x_enc_fine.size(1)
    input_tensors = generate_test_datum(observed_fine_actions, new_observed_coarse_actions, input_seq_len=input_seq_len,
                                        fine_action_to_id=fine_action_to_id, coarse_action_to_id=coarse_action_to_id,
                                        disable_parent_input=disable_parent_input,
                                        num_frames=num_frames, scalers=scalers, coarse_is_complete=True)
    input_tensors = [nan_to_value(tensor, value=0.0) for tensor in input_tensors]
    input_tensors = numpy_to_torch(*input_tensors, device=x_enc_fine.device)
    x_enc_fine, x_enc_coarse, dx_enc, dx_enc_layer_zero, x_tra_fine, x_tra_coarse = input_tensors
    dx = [dx_enc, dx_enc_layer_zero]
    with torch.no_grad():
        _, hx, hxs = model.encoder_net(x_enc_fine, x_enc_coarse, dx=dx, hx=None, return_all_hidden_states=True)
        hx_tra = [hl[0] for hl in hx] if is_hmlstm else hx
        (y_tra_fine_rem_rel_prop, _), hx_tra = model.transition_net(x_tra_fine, x_tra_coarse, hx=hx_tra)
        # A model without a `disable_transition_layer` attribute is treated as
        # having the transition layer enabled.
        if not getattr(model, 'disable_transition_layer', False):
            # Continue decoding from the transition network's hidden state.
            if is_hmlstm:
                for i, hl in enumerate(hx_tra):
                    hx[i][0] = hl
            else:
                hx = hx_tra
            hxs[0] = torch.cat([hxs[0], hx_tra[0].unsqueeze(1)], dim=1)
            hxs[1] = torch.cat([hxs[1], hx_tra[1].unsqueeze(1)], dim=1)

    # Fine inputs are prefixed with the coarse one-hot unless parent input is disabled.
    fine_offset = 0 if disable_parent_input else len(coarse_action_to_id)
    fine_la_id = torch.argmax(x_tra_fine[..., fine_offset:-2], dim=-1).item()
    fine_la_label = fine_id_to_action[fine_la_id]
    y_tra_fine_rem_rel_prop = maybe_denormalise(y_tra_fine_rem_rel_prop.cpu().numpy(),
                                                scaler=scaler_map.get('y_tra_fine_scaler'))
    # Total proportion of the current coarse action = observed + predicted remainder.
    # The observed part is denormalised, like the remainder, so both terms are in
    # the same units; the same value is used as the fine decoder's parent proportion.
    coarse_la_obs_prop = maybe_denormalise(x_tra_coarse[..., -1:].cpu().numpy(),
                                           scaler=scaler_map.get('x_tra_coarse_scaler'))
    coarse_la_prop = coarse_la_obs_prop.item() + y_tra_coarse_rem_prop.item()
    # The fine remainder is predicted relative to its parent coarse action's length.
    fine_la_rem_len = round(y_tra_fine_rem_rel_prop.item() * coarse_la_prop * num_frames)
    predicted_fine_actions = [fine_la_label] * fine_la_rem_len
    predicted_fine_steps = [(fine_la_label, fine_la_rem_len)]

    # Stage 3: decoder inputs, seeded from the transition inputs plus predicted remainders.
    dtype, device = x_enc_fine.dtype, x_enc_fine.device
    x_dec_cat_coarse = x_tra_coarse[..., :-2]
    x_dec_num_coarse = x_tra_coarse[..., -2:-1] + torch.tensor(y_tra_coarse_rem_prop, dtype=dtype, device=device)
    x_dec_coarse = torch.cat([x_dec_cat_coarse, x_dec_num_coarse], dim=-1)

    x_dec_cat_fine = x_tra_fine[..., :-2]
    acc_fine_proportion = x_tra_fine[..., -2:-1] + torch.tensor(y_tra_fine_rem_rel_prop, dtype=dtype, device=device)
    acc_fine_proportion = acc_fine_proportion.item()
    x_dec_num_fine = torch.tensor([[acc_fine_proportion]], dtype=dtype, device=device)
    x_dec_fine = torch.cat([x_dec_cat_fine, x_dec_num_fine], dim=-1)

    d_fine, d_fines = 0.0, []
    decoder_net = model.decoder_net
    output_seq_len = decoder_net.output_seq_len
    coarse_exceed_first_time, total_coarse_length = True, 0
    with torch.no_grad():
        for _ in range(output_seq_len):
            # Fine decoder step.
            if model.model_v2:
                x_dec_fine_ = x_dec_fine[0]
                hx_ = [hx[0][0], hx[1][0]]
                y_dec_fine_logits, y_dec_fine_rel_prop, hx_fine = decoder_net.single_step_fine(x_dec_fine_, d_fine,
                                                                                               hx_)
                hx[0] = hx_fine.unsqueeze(0)
            else:
                y_dec_fine_logits, y_dec_fine_rel_prop, hx[0] = \
                    decoder_net.single_step_fine(x_dec_fine, d_fine, hx)
            fine_na_label, _ = next_action_info(y_dec_fine_logits, y_dec_fine_rel_prop, fine_id_to_action, num_frames)

            if acc_fine_proportion >= 1.0 or fine_na_label is None:
                # The current coarse action is exhausted: advance the coarse level.
                acc_fine_proportion, d_fine = 0.0, 1.0
                if model.model_v2:
                    x_dec_coarse_ = x_dec_coarse[0]
                    hx_ = [hx[0][0], hx[1][0]]
                    y_dec_coarse_logits, y_dec_coarse_prop, hx_coarse = \
                        decoder_net.single_step_coarse(x_dec_coarse_, hx_)
                    hx[1] = hx_coarse.unsqueeze(0)
                else:
                    y_dec_coarse_logits, y_dec_coarse_prop, hx[1] = \
                        decoder_net.single_step_coarse(x_dec_coarse, d_fine, hx)
                y_dec_coarse_prop = maybe_denormalise(y_dec_coarse_prop.cpu().numpy(),
                                                      scaler=scaler_map.get('y_dec_coarse_scaler'))
                coarse_na_label, coarse_na_len = next_action_info(y_dec_coarse_logits, y_dec_coarse_prop,
                                                                  coarse_id_to_action, num_frames)
                if coarse_na_label is None:
                    break
                predicted_coarse_actions += [coarse_na_label] * coarse_na_len
                predicted_coarse_steps.append((coarse_na_label, coarse_na_len))
                coarse_la_prop = y_dec_coarse_prop.item()
                x_dec_cat_coarse = logit2one_hot(y_dec_coarse_logits)
                if model.with_final_action:
                    x_dec_cat_coarse = x_dec_cat_coarse[..., :-1]
                x_dec_coarse[..., :-1] = x_dec_cat_coarse
                x_dec_coarse[..., -1] += coarse_la_prop
                predicted_fine_steps.append((None, None))
                if model.model_v3 and decoder_net.input_soft_parent:
                    # Fine steps are conditioned on the soft coarse distribution instead.
                    if model.with_final_action:
                        x_dec_cat_coarse = torch.softmax(y_dec_coarse_logits[..., :-1], dim=-1)
                    else:
                        x_dec_cat_coarse = torch.softmax(y_dec_coarse_logits, dim=-1)
            else:
                # Emit a fine action within the current coarse action.
                y_dec_fine_rel_prop = maybe_denormalise(y_dec_fine_rel_prop.cpu().numpy(),
                                                        scaler=scaler_map.get('y_dec_fine_scaler'))
                fine_na_label, fine_na_len = next_action_info(y_dec_fine_logits, y_dec_fine_rel_prop,
                                                              fine_id_to_action, num_frames,
                                                              parent_la_prop=coarse_la_prop)
                predicted_fine_actions += [fine_na_label] * fine_na_len
                predicted_fine_steps.append((fine_na_label, fine_na_len))
                acc_fine_proportion += y_dec_fine_rel_prop.item()
                predicted_coarse_steps.append((None, None))
                d_fine = 0.0

            # Record this step's hidden states and build the next fine input.
            d_fines.append(d_fine)
            h_fine, h_coarse = (hx[0][0], hx[1][0]) if is_hmlstm else (hx[0], hx[1])
            hxs[0] = torch.cat([hxs[0], h_fine.unsqueeze(1)], dim=1)
            hxs[1] = torch.cat([hxs[1], h_coarse.unsqueeze(1)], dim=1)
            x_dec_cat_fine = logit2one_hot(y_dec_fine_logits)
            if model.with_final_action:
                x_dec_cat_fine = x_dec_cat_fine[..., :-1]
            # Zero the fine one-hot right after a coarse transition (accumulator reset).
            x_dec_cat_fine = x_dec_cat_fine * float(acc_fine_proportion > 0.0)
            if not disable_parent_input:
                x_dec_cat_fine = torch.cat([x_dec_cat_coarse, x_dec_cat_fine], dim=-1)
            x_dec_num_fine = torch.tensor([[acc_fine_proportion]], dtype=dtype, device=device)
            x_dec_fine = torch.cat([x_dec_cat_fine, x_dec_num_fine], dim=-1)

            # Stopping criteria.
            coarse_exceed = len(predicted_coarse_actions) >= maximum_prediction_length
            fine_exceed = len(predicted_fine_actions) >= maximum_prediction_length
            if coarse_exceed and fine_exceed:
                break
            if coarse_exceed:
                if coarse_exceed_first_time:
                    coarse_exceed_first_time = False
                    total_coarse_length = len(predicted_coarse_actions)
                elif len(predicted_coarse_actions) > total_coarse_length:
                    # A further coarse action was started after the limit: drop it.
                    predicted_coarse_steps = predicted_coarse_steps[:-1]
                    predicted_fine_steps = predicted_fine_steps[:-1]
                    break

    if model.with_final_action:
        fine_steps = [(None, None)] + predicted_fine_steps
        coarse_steps = predicted_coarse_steps[:1] + [(None, None)] + predicted_coarse_steps[1:]
        coarse_steps = maybe_rebalance_steps(coarse_steps, maximum_prediction_length)
        predicted_fine_steps, predicted_coarse_steps = fix_steps(fine_steps, coarse_steps)
        predicted_fine_actions = actions_from_steps(predicted_fine_steps)
        predicted_coarse_actions = actions_from_steps(predicted_coarse_steps)
    predicted_actions = predicted_fine_actions, predicted_coarse_actions
    predicted_steps = predicted_fine_steps, predicted_coarse_steps
    return predicted_actions, predicted_steps, d_fines