Exemplo n.º 1
0
def load_model(model_dir, env, ts=100):
    """Restore a MATS model and its hyperparameters from ``model_dir``.

    The weights saved for iteration ``ts`` are loaded through a CPU model
    registrar, and the hyperparameters are read from the ``config.json`` file
    stored in the same directory. ``env`` is accepted for signature parity
    with the other loaders but is not used here.

    Returns:
        A ``(mats, hyperparams)`` tuple.
    """
    registrar = ModelRegistrar(model_dir, 'cpu')
    registrar.load_models(ts)

    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'r') as fh:
        config = json.load(fh)

    return MATS(registrar, config, None, 'cpu'), config
Exemplo n.º 2
0
def load_model(model_dir, env, ts=100):
    """Restore a Trajectron model on the CPU and bind it to ``env``.

    Weights for iteration ``ts`` are loaded through a CPU model registrar and
    the hyperparameters come from ``<model_dir>/config.json``. The model is
    attached to ``env`` and its annealing parameters are initialised before
    it is returned.

    Returns:
        A ``(trajectron, hyperparams)`` tuple.
    """
    registrar = ModelRegistrar(model_dir, 'cpu')
    registrar.load_models(ts)

    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'r') as fh:
        config = json.load(fh)

    model = Trajectron(registrar, config, None, 'cpu')
    model.set_environment(env)
    model.set_annealing_params()
    return model, config
def load_model(model_dir, env, ts=99):
    """Restore a SpatioTemporalGraphCVAEModel and attach ``env`` as its scene graph.

    Weights for iteration ``ts`` are loaded through a CPU model registrar, and
    the hyperparameters come from ``<model_dir>/config.json``. The model
    itself is created on ``'cuda:0'``. In the returned hyperparameters,
    ``incl_robot_node`` is forced to ``False``.

    Returns:
        A ``(stg, hyperparams)`` tuple.
    """
    registrar = ModelRegistrar(model_dir, 'cpu')
    registrar.load_models(ts)

    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'r') as fh:
        config = json.load(fh)

    # NOTE(review): the registrar is on 'cpu' while the model is built on
    # 'cuda:0' -- confirm the weights end up on the model's device.
    model = SpatioTemporalGraphCVAEModel(registrar, config, None, 'cuda:0')
    # Overridden after construction; presumably the model keeps a reference to
    # this same dict, so set_scene_graph() sees the new value -- TODO confirm.
    config['incl_robot_node'] = False

    model.set_scene_graph(env)
    model.set_annealing_params()
    return model, config
Exemplo n.º 4
0
def load_model(model_dir, env, ts=100, weight=0.0, seed=None):
    """Restore a Trajectron checkpoint saved under a weight/seed-specific prefix.

    The checkpoint prefix is ``'w-<weight:.4f>-s-<seed:d>'``, so ``seed`` must
    be an integer. The default ``None`` cannot be formatted with ``{:d}`` and
    is rejected with an explicit error.

    Args:
        model_dir: directory holding the checkpoints and ``config.json``.
        env: environment the restored model is bound to.
        ts: checkpoint iteration to load.
        weight: weight value embedded in the checkpoint prefix.
        seed: integer seed embedded in the checkpoint prefix (required).

    Returns:
        A ``(trajectron, hyperparams)`` tuple.

    Raises:
        TypeError: if ``seed`` is ``None``.
    """
    # Validate before creating the registrar: '{:d}'.format(None) would
    # otherwise fail later with an opaque "unsupported format string" error.
    if seed is None:
        raise TypeError("load_model() requires an integer seed to build the "
                        "checkpoint prefix 'w-<weight>-s-<seed>'")

    model_registrar = ModelRegistrar(model_dir, 'cpu')

    prefix = 'w-{:.4f}-s-{:d}'.format(weight, seed)
    model_registrar.load_models(ts, prefix)

    with open(os.path.join(model_dir, 'config.json'), 'r') as config_json:
        hyperparams = json.load(config_json)

    trajectron = Trajectron(model_registrar, hyperparams, None, 'cpu')

    trajectron.set_environment(env)
    trajectron.set_annealing_params()
    return trajectron, hyperparams
Exemplo n.º 5
0
def load_model(model_dir, env, ts=3999, device='cpu'):
    """Restore a Trajectron model on ``device`` and bind it to ``env``.

    Weights for iteration ``ts`` are loaded through a registrar on ``device``.
    Hyperparameters are read from ``<model_dir>/config.json`` and adjusted
    before the model is built:
    ``map_enc_dropout`` is set to 0.0, and ``incl_robot_node`` defaults to
    ``False`` when the config lacks it.

    Returns:
        A ``(stg, hyperparams)`` tuple.
    """
    registrar = ModelRegistrar(model_dir, device)
    registrar.load_models(ts)

    with open(os.path.join(model_dir, 'config.json'), 'r') as fh:
        config = json.load(fh)

    config['map_enc_dropout'] = 0.0
    config.setdefault('incl_robot_node', False)

    model = Trajectron(registrar, config, None, device)
    model.set_environment(env)
    model.set_annealing_params()
    return model, config
Exemplo n.º 6
0
def extract_our_and_sgan_preds(dataset_name, hyperparams, args, data_precondition='all'):
    """Run Social GAN and our STG model on the same test scenes and align them.

    The SGAN generator checkpoint for ``dataset_name`` and our trained model
    (from ``../sgan-dataset/logs/<dataset_name>``) are loaded. Then
    ``num_runs`` random scenes are sampled and predictions are collected from
    both methods. Agents are matched across the two methods by comparing
    their observed histories. Finally, every prediction/ground-truth dict is
    pruned to the matched agents.

    Args:
        dataset_name: dataset name, used to build the SGAN and our model paths.
        hyperparams: base hyperparameter dict. It is mutated in place with
            dataset-specific values read from the evaluation data.
        args: parsed CLI namespace (SGAN paths, devices, ``num_samples``,
            ``num_runs``, edge settings, ...).
        data_precondition: forwarded to ``get_sgan_data_format`` as
            ``what_to_check``.

    Returns:
        A tuple ``(our_preds_most_likely_list, our_preds_list,
        sgan_preds_list, sgan_gt_list, eval_inputs, eval_data_dict, data_ids,
        t_predicts, random_scene_idxs, num_runs)``. Each ``*_list`` holds one
        dict per run, keyed by ``STGNode``.

    Raises:
        FileNotFoundError: if our trained-model log directory does not exist.
            Everything after SGAN inference depends on our evaluation data and
            model, so no result can be produced without it.
        ValueError: if the evaluation input dict contains no ``STGNode`` key.
    """
    print('At %s dataset' % dataset_name)

    ### SGAN LOADING ###
    sgan_model_path = os.path.join(args.sgan_models_path, '_'.join([dataset_name, '12', 'model.pt']))

    checkpoint = torch.load(sgan_model_path, map_location='cpu')
    generator = get_generator(checkpoint)
    _args = AttrDict(checkpoint['args'])
    path = get_dset_path(_args.dataset_name, args.sgan_dset_type)
    print('Evaluating', sgan_model_path, 'on', _args.dataset_name, args.sgan_dset_type)

    _, sgan_data_loader = data_loader(_args, path)

    ### OUR METHOD LOADING ###
    data_dir = '../sgan-dataset/data'
    eval_data_dict_name = '%s_test.pkl' % dataset_name
    log_dir = '../sgan-dataset/logs/%s' % dataset_name
    # The evaluation inputs, agent matching and pruning below all need our
    # model and evaluation data. Fail here with a clear message instead of
    # hitting a NameError much later.
    if not os.path.isdir(log_dir):
        raise FileNotFoundError('Trained model log directory not found: %s' % log_dir)

    trained_model_dir = os.path.join(log_dir, get_our_model_dir(dataset_name))
    eval_data_path = os.path.join(data_dir, eval_data_dict_name)
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(eval_data_path, 'rb') as f:
        eval_data_dict = pickle.load(f, encoding='latin1')
    eval_dt = eval_data_dict['dt']
    print('Loaded evaluation data from %s, eval_dt = %.2f' % (eval_data_path, eval_dt))

    # Loading weights from the trained model.
    specific_hyperparams = get_model_hyperparams(args, dataset_name)
    model_registrar = ModelRegistrar(trained_model_dir, args.device)
    model_registrar.load_models(specific_hyperparams['best_iter'])

    # The first STGNode key is used to read the state dimension (presumably
    # shared by all nodes -- TODO confirm).
    random_node = next((key for key in eval_data_dict['input_dict'].keys()
                        if isinstance(key, STGNode)), None)
    if random_node is None:
        raise ValueError('No STGNode entries found in %s' % eval_data_path)

    hyperparams['state_dim'] = eval_data_dict['input_dict'][random_node].shape[2]
    hyperparams['pred_dim'] = len(eval_data_dict['pred_indices'])
    hyperparams['pred_indices'] = eval_data_dict['pred_indices']
    hyperparams['dynamic_edges'] = args.dynamic_edges
    hyperparams['edge_state_combine_method'] = specific_hyperparams['edge_state_combine_method']
    hyperparams['edge_influence_combine_method'] = specific_hyperparams['edge_influence_combine_method']
    hyperparams['nodes_standardization'] = eval_data_dict['nodes_standardization']
    hyperparams['labels_standardization'] = eval_data_dict['labels_standardization']
    hyperparams['edge_radius'] = args.edge_radius

    eval_hyperparams = copy.deepcopy(hyperparams)
    eval_hyperparams['nodes_standardization'] = eval_data_dict["nodes_standardization"]
    eval_hyperparams['labels_standardization'] = eval_data_dict["labels_standardization"]

    kwargs_dict = {'dynamic_edges': hyperparams['dynamic_edges'],
                   'edge_state_combine_method': hyperparams['edge_state_combine_method'],
                   'edge_influence_combine_method': hyperparams['edge_influence_combine_method']}

    print('-------------------------')
    print('| EVALUATION PARAMETERS |')
    print('-------------------------')
    print('| checking: %s' % data_precondition)
    print('| device: %s' % args.device)
    print('| eval_device: %s' % args.eval_device)
    print('| edge_radius: %s' % hyperparams['edge_radius'])
    print('| EE state_combine_method: %s' % hyperparams['edge_state_combine_method'])
    print('| EIE scheme: %s' % hyperparams['edge_influence_combine_method'])
    print('| dynamic_edges: %s' % hyperparams['dynamic_edges'])
    print('| edge_addition_filter: %s' % args.edge_addition_filter)
    print('| edge_removal_filter: %s' % args.edge_removal_filter)
    print('| MHL: %s' % hyperparams['minimum_history_length'])
    print('| PH: %s' % hyperparams['prediction_horizon'])
    print('| # Samples: %s' % args.num_samples)
    print('| # Runs: %s' % args.num_runs)
    print('-------------------------')

    # eval_stg must use the model_registrar populated by load_models() above.
    # With a fresh registrar it would be evaluating randomly-initialized weights.
    eval_stg = SpatioTemporalGraphCVAEModel(None, model_registrar,
                                            eval_hyperparams, kwargs_dict,
                                            None, args.eval_device)
    print('Created evaluation STG model.')

    eval_agg_scene_graph = create_batch_scene_graph(eval_data_dict['input_dict'],
                                                    float(hyperparams['edge_radius']),
                                                    use_old_method=(args.dynamic_edges=='no'))
    print('Created aggregate evaluation scene graph.')

    if args.dynamic_edges == 'yes':
        eval_agg_scene_graph.compute_edge_scaling(args.edge_addition_filter, args.edge_removal_filter)
        eval_data_dict['input_dict']['edge_scaling_mask'] = eval_agg_scene_graph.edge_scaling_mask
        print('Computed edge scaling for the evaluation scene graph.')

    eval_stg.set_scene_graph(eval_agg_scene_graph)
    print('Set the aggregate scene graph.')

    eval_stg.set_annealing_params()

    print('About to begin evaluation computation for %s.' % dataset_name)
    with torch.no_grad():
        eval_inputs, _ = sample_inputs_and_labels(eval_data_dict, device=args.eval_device)

        sgan_preds_list = list()
        sgan_gt_list = list()
        our_preds_list = list()
        our_preds_most_likely_list = list()

        (obs_traj, pred_traj_gt, obs_traj_rel,
         seq_start_end, data_ids, t_predicts) = get_sgan_data_format(eval_inputs, what_to_check=data_precondition)

        num_runs = args.num_runs
        print('num_runs, seq_start_end.shape[0]', args.num_runs, seq_start_end.shape[0])
        if args.num_runs > seq_start_end.shape[0]:
            print('num_runs (%d) > seq_start_end.shape[0] (%d), reducing num_runs to match.' % (num_runs, seq_start_end.shape[0]))
            num_runs = seq_start_end.shape[0]

        # All SGAN samples are drawn once, over every scene; they are sliced
        # per scene/agent in the loop below.
        samples_list = list()
        for _ in range(args.num_samples):
            pred_traj_fake_rel = generator(
                obs_traj, obs_traj_rel, seq_start_end
            )
            pred_traj_fake = relative_to_abs(
                pred_traj_fake_rel, obs_traj[-1]
            )

            samples_list.append(pred_traj_fake)

        random_scene_idxs = np.random.choice(seq_start_end.shape[0],
                                             size=(num_runs,),
                                             replace=False).astype(int)

        # sgan_history[run][seq_agent] holds SGAN's observed trajectory for that
        # agent. It is used below to match SGAN agents to our nodes.
        sgan_history = defaultdict(dict)
        for run in range(num_runs):
            random_scene_idx = random_scene_idxs[run]
            seq_idx_range = seq_start_end[random_scene_idx]

            agent_preds = dict()
            agent_gt = dict()
            for seq_agent in range(seq_idx_range[0], seq_idx_range[1]):
                agent_preds[seq_agent] = torch.stack([x[:, seq_agent] for x in samples_list], dim=0)
                agent_gt[seq_agent] = torch.unsqueeze(pred_traj_gt[:, seq_agent], dim=0)
                sgan_history[run][seq_agent] = obs_traj[:, seq_agent]

            sgan_preds_list.append(agent_preds)
            sgan_gt_list.append(agent_gt)

        print('Done running SGAN')

        sgan_our_agent_map = dict()
        our_sgan_agent_map = dict()
        for run in range(num_runs):
            print('At our run number', run)
            random_scene_idx = random_scene_idxs[run]
            data_id = data_ids[random_scene_idx]
            t_predict = t_predicts[random_scene_idx] - 1

            curr_inputs = {k: v[[data_id]] for k, v in eval_inputs.items()}
            curr_inputs['traj_lengths'] = torch.tensor([t_predict])

            preds_dict_most_likely = eval_stg.predict(curr_inputs, hyperparams['prediction_horizon'], args.num_samples, most_likely=True)
            preds_dict_full = eval_stg.predict(curr_inputs, hyperparams['prediction_horizon'], args.num_samples, most_likely=False)

            our_preds_most_likely_list.append(preds_dict_most_likely)
            our_preds_list.append(preds_dict_full)

            # Match each of our nodes to the SGAN agent whose observed history
            # equals (within 1e-4 in norm) the last 8 rows of our node's first
            # two state columns ending at t_predict.
            for node, value in curr_inputs.items():
                if isinstance(node, STGNode) and np.any(value[0, t_predict]):
                    curr_prev = value[0, t_predict+1-8 : t_predict+1]
                    for seq_agent, sgan_val in sgan_history[run].items():
                        if torch.norm(curr_prev[:, :2] - sgan_val) < 1e-4:
                            sgan_our_agent_map['%d/%d' % (run, seq_agent)] = node
                            our_sgan_agent_map['%d/%s' % (run, str(node))] = '%d/%d' % (run, seq_agent)

        print('Done running Our Method')

        # Pruning values that aren't in either.
        for run in range(num_runs):
            agent_preds = sgan_preds_list[run]
            agent_gt = sgan_gt_list[run]

            new_agent_preds = dict()
            new_agent_gts = dict()
            for agent in agent_preds.keys():
                run_agent_key = '%d/%d' % (run, agent)
                if run_agent_key in sgan_our_agent_map:
                    new_agent_preds[sgan_our_agent_map[run_agent_key]] = agent_preds[agent]
                    new_agent_gts[sgan_our_agent_map[run_agent_key]] = agent_gt[agent]

            sgan_preds_list[run] = new_agent_preds
            sgan_gt_list[run] = new_agent_gts

        for run in range(num_runs):
            agent_preds_ml = our_preds_most_likely_list[run]
            agent_preds_full = our_preds_list[run]

            new_agent_preds = dict()
            new_agent_preds_full = dict()
            for node in [x for x in agent_preds_ml.keys() if x.endswith('/y')]:
                node_key_list = node.split('/')
                node_obj = STGNode(node_key_list[1], node_key_list[0])
                node_obj_key = '%d/%s' % (run, str(node_obj))
                if node_obj_key in our_sgan_agent_map:
                    new_agent_preds[node_obj] = agent_preds_ml[node]
                    new_agent_preds_full[node_obj] = agent_preds_full[node]

            our_preds_most_likely_list[run] = new_agent_preds
            our_preds_list[run] = new_agent_preds_full

        # Guaranteeing the number of agents are the same.
        for run in range(num_runs):
            assert list_compare(our_preds_most_likely_list[run].keys(), sgan_preds_list[run].keys())
            assert list_compare(our_preds_list[run].keys(), sgan_preds_list[run].keys())
            assert list_compare(our_preds_most_likely_list[run].keys(), our_preds_list[run].keys())
            assert list_compare(sgan_preds_list[run].keys(), sgan_gt_list[run].keys())

    return (our_preds_most_likely_list, our_preds_list,
            sgan_preds_list, sgan_gt_list, eval_inputs, eval_data_dict,
            data_ids, t_predicts, random_scene_idxs, num_runs)
Exemplo n.º 7
0
def main():
    """Train a Trajectron model configured by the module-level ``args``.

    Steps:
    - Load the JSON config at ``args.conf`` and overlay the CLI options on it.
    - Load the training environment. The evaluation environment is also
      loaded whenever periodic evaluation or visualization is enabled.
    - Run ``args.train_epochs`` epochs of per-node-type training, with
      optional periodic visualization, evaluation and checkpointing.

    TensorBoard logging, visualization, evaluation and checkpointing are all
    skipped when ``args.debug`` is set.

    Raises:
        ValueError: if ``hyperparams['learning_rate_style']`` is neither
            ``'const'`` nor ``'exp'``.
    """
    # Load hyperparameters from json
    if not os.path.exists(args.conf):
        print('Config json not found!')
    with open(args.conf, 'r', encoding='utf-8') as conf_json:
        hyperparams = json.load(conf_json)

    # Add hyperparams from arguments
    hyperparams['dynamic_edges'] = args.dynamic_edges
    hyperparams['edge_state_combine_method'] = args.edge_state_combine_method
    hyperparams[
        'edge_influence_combine_method'] = args.edge_influence_combine_method
    hyperparams['edge_addition_filter'] = args.edge_addition_filter
    hyperparams['edge_removal_filter'] = args.edge_removal_filter
    hyperparams['batch_size'] = args.batch_size
    hyperparams['k_eval'] = args.k_eval
    hyperparams['offline_scene_graph'] = args.offline_scene_graph
    hyperparams['incl_robot_node'] = args.incl_robot_node
    hyperparams['node_freq_mult_train'] = args.node_freq_mult_train
    hyperparams['node_freq_mult_eval'] = args.node_freq_mult_eval
    hyperparams['scene_freq_mult_train'] = args.scene_freq_mult_train
    hyperparams['scene_freq_mult_eval'] = args.scene_freq_mult_eval
    hyperparams['scene_freq_mult_viz'] = args.scene_freq_mult_viz
    hyperparams['edge_encoding'] = not args.no_edge_encoding
    hyperparams['use_map_encoding'] = args.map_encoding
    hyperparams['augment'] = args.augment
    hyperparams['override_attention_radius'] = args.override_attention_radius

    print('-----------------------')
    print('| TRAINING PARAMETERS |')
    print('-----------------------')
    print('| batch_size: %d' % args.batch_size)
    print('| device: %s' % args.device)
    print('| eval_device: %s' % args.eval_device)
    print('| Offline Scene Graph Calculation: %s' % args.offline_scene_graph)
    print('| EE state_combine_method: %s' % args.edge_state_combine_method)
    print('| EIE scheme: %s' % args.edge_influence_combine_method)
    print('| dynamic_edges: %s' % args.dynamic_edges)
    print('| robot node: %s' % args.incl_robot_node)
    print('| edge_addition_filter: %s' % args.edge_addition_filter)
    print('| edge_removal_filter: %s' % args.edge_removal_filter)
    print('| MHL: %s' % hyperparams['minimum_history_length'])
    print('| PH: %s' % hyperparams['prediction_horizon'])
    print('-----------------------')

    log_writer = None
    model_dir = None
    if not args.debug:
        # Create the log and model directiory if they're not present.
        model_dir = os.path.join(
            args.log_dir,
            'models_' + time.strftime('%d_%b_%Y_%H_%M_%S', time.localtime()) +
            args.log_tag)
        pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

        # Save config to model directory
        with open(os.path.join(model_dir, 'config.json'), 'w') as conf_json:
            json.dump(hyperparams, conf_json)

        log_writer = SummaryWriter(log_dir=model_dir)

    # Load training and evaluation environments and scenes
    train_scenes = []
    train_data_path = os.path.join(args.data_dir, args.train_data_dict)
    with open(train_data_path, 'rb') as f:
        train_env = dill.load(f, encoding='latin1')

    for attention_radius_override in args.override_attention_radius:
        node_type1, node_type2, attention_radius = attention_radius_override.split(
            ' ')
        train_env.attention_radius[(node_type1,
                                    node_type2)] = float(attention_radius)

    if train_env.robot_type is None and hyperparams['incl_robot_node']:
        train_env.robot_type = train_env.NodeType[
            0]  # TODO: Make more general, allow the user to specify?
        for scene in train_env.scenes:
            scene.add_robot_from_nodes(train_env.robot_type)

    train_scenes = train_env.scenes
    train_scenes_sample_probs = train_env.scenes_freq_mult_prop if args.scene_freq_mult_train else None

    train_dataset = EnvironmentDataset(
        train_env,
        hyperparams['state'],
        hyperparams['pred_state'],
        scene_freq_mult=hyperparams['scene_freq_mult_train'],
        node_freq_mult=hyperparams['node_freq_mult_train'],
        hyperparams=hyperparams,
        min_history_timesteps=hyperparams['minimum_history_length'],
        min_future_timesteps=hyperparams['prediction_horizon'],
        return_robot=not args.incl_robot_node)
    train_data_loader = dict()
    for node_type_data_set in train_dataset:
        node_type_dataloader = utils.data.DataLoader(
            node_type_data_set,
            collate_fn=collate,
            # Compare by value: `is` against a string literal tests identity
            # and can be False for an equal 'cpu' string built at runtime.
            pin_memory=args.device != 'cpu',
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.preprocess_workers)
        train_data_loader[node_type_data_set.node_type] = node_type_dataloader

    print(f"Loaded training data from {train_data_path}")

    eval_scenes = []
    eval_scenes_sample_probs = None
    eval_env = None
    # The evaluation environment is needed both for periodic evaluation and
    # for the eval-set visualization, so load it when either is enabled.
    if args.eval_every is not None or args.vis_every is not None:
        eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
        with open(eval_data_path, 'rb') as f:
            eval_env = dill.load(f, encoding='latin1')

        for attention_radius_override in args.override_attention_radius:
            node_type1, node_type2, attention_radius = attention_radius_override.split(
                ' ')
            eval_env.attention_radius[(node_type1,
                                       node_type2)] = float(attention_radius)

        if eval_env.robot_type is None and hyperparams['incl_robot_node']:
            eval_env.robot_type = eval_env.NodeType[
                0]  # TODO: Make more general, allow the user to specify?
            for scene in eval_env.scenes:
                scene.add_robot_from_nodes(eval_env.robot_type)

        eval_scenes = eval_env.scenes
        eval_scenes_sample_probs = eval_env.scenes_freq_mult_prop if args.scene_freq_mult_eval else None

        eval_dataset = EnvironmentDataset(
            eval_env,
            hyperparams['state'],
            hyperparams['pred_state'],
            scene_freq_mult=hyperparams['scene_freq_mult_eval'],
            node_freq_mult=hyperparams['node_freq_mult_eval'],
            hyperparams=hyperparams,
            min_history_timesteps=hyperparams['minimum_history_length'],
            min_future_timesteps=hyperparams['prediction_horizon'],
            return_robot=not args.incl_robot_node)
        eval_data_loader = dict()
        for node_type_data_set in eval_dataset:
            node_type_dataloader = utils.data.DataLoader(
                node_type_data_set,
                collate_fn=collate,
                pin_memory=args.eval_device != 'cpu',
                batch_size=args.eval_batch_size,
                shuffle=True,
                num_workers=args.preprocess_workers)
            eval_data_loader[
                node_type_data_set.node_type] = node_type_dataloader

        print(f"Loaded evaluation data from {eval_data_path}")

    # Offline Calculate Scene Graph
    if hyperparams['offline_scene_graph'] == 'yes':
        print(f"Offline calculating scene graphs")
        for i, scene in enumerate(train_scenes):
            scene.calculate_scene_graph(train_env.attention_radius,
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
            print(f"Created Scene Graph for Training Scene {i}")

        for i, scene in enumerate(eval_scenes):
            scene.calculate_scene_graph(eval_env.attention_radius,
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
            print(f"Created Scene Graph for Evaluation Scene {i}")

    model_registrar = ModelRegistrar(model_dir, args.device)

    trajectron = Trajectron(model_registrar, hyperparams, log_writer,
                            args.device)

    trajectron.set_environment(train_env)
    trajectron.set_annealing_params()
    print('Created Training Model.')

    eval_trajectron = None
    if args.eval_every is not None or args.vis_every is not None:
        # Shares model_registrar with the training model, so it always
        # evaluates the current weights.
        eval_trajectron = Trajectron(model_registrar, hyperparams, log_writer,
                                     args.eval_device)
        eval_trajectron.set_environment(eval_env)
        eval_trajectron.set_annealing_params()
        print('Created Evaluation Model.')

    optimizer = dict()
    lr_scheduler = dict()
    for node_type in train_env.NodeType:
        if node_type not in hyperparams['pred_state']:
            continue
        optimizer[node_type] = optim.Adam([{
            'params':
            model_registrar.get_all_but_name_match('map_encoder').parameters()
        }, {
            'params':
            model_registrar.get_name_match('map_encoder').parameters(),
            'lr':
            0.0008
        }],
                                          lr=hyperparams['learning_rate'])
        # Set Learning Rate
        if hyperparams['learning_rate_style'] == 'const':
            lr_scheduler[node_type] = optim.lr_scheduler.ExponentialLR(
                optimizer[node_type], gamma=1.0)
        elif hyperparams['learning_rate_style'] == 'exp':
            lr_scheduler[node_type] = optim.lr_scheduler.ExponentialLR(
                optimizer[node_type], gamma=hyperparams['learning_decay_rate'])
        else:
            # Fail before training starts instead of with a KeyError on the
            # first scheduler step.
            raise ValueError(
                f"Unknown learning_rate_style: {hyperparams['learning_rate_style']!r}"
                " (expected 'const' or 'exp')")

    #################################
    #           TRAINING            #
    #################################
    curr_iter_node_type = {
        node_type: 0
        for node_type in train_data_loader.keys()
    }
    for epoch in range(1, args.train_epochs + 1):
        model_registrar.to(args.device)
        train_dataset.augment = args.augment
        for node_type, data_loader in train_data_loader.items():
            curr_iter = curr_iter_node_type[node_type]
            pbar = tqdm(data_loader, ncols=80)
            for batch in pbar:
                trajectron.set_curr_iter(curr_iter)
                trajectron.step_annealers(node_type)
                optimizer[node_type].zero_grad()
                train_loss = trajectron.train_loss(batch, node_type)
                pbar.set_description(
                    f"Epoch {epoch}, {node_type} L: {train_loss.item():.2f}")
                train_loss.backward()
                # Clipping gradients.
                if hyperparams['grad_clip'] is not None:
                    nn.utils.clip_grad_value_(model_registrar.parameters(),
                                              hyperparams['grad_clip'])
                optimizer[node_type].step()

                # Stepping forward the learning rate scheduler and annealers.
                lr_scheduler[node_type].step()

                if not args.debug:
                    # Read the learning rate in effect from the optimizer. The
                    # scheduler's get_lr() is not the public accessor for the
                    # current value on newer PyTorch; param_groups works
                    # across versions.
                    log_writer.add_scalar(f"{node_type}/train/learning_rate",
                                          optimizer[node_type].param_groups[0]['lr'],
                                          curr_iter)
                    log_writer.add_scalar(f"{node_type}/train/loss",
                                          train_loss, curr_iter)

                curr_iter += 1
            curr_iter_node_type[node_type] = curr_iter
        train_dataset.augment = False
        if args.eval_every is not None or args.vis_every is not None:
            eval_trajectron.set_curr_iter(epoch)

        #################################
        #        VISUALIZATION          #
        #################################
        if args.vis_every is not None and not args.debug and epoch % args.vis_every == 0 and epoch > 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            with torch.no_grad():
                # Predict random timestep to plot for train data set
                if args.scene_freq_mult_viz:
                    scene = np.random.choice(train_scenes,
                                             p=train_scenes_sample_probs)
                else:
                    scene = np.random.choice(train_scenes)
                timestep = scene.sample_timesteps(1, min_future_timesteps=ph)
                predictions = trajectron.predict(scene,
                                                 timestep,
                                                 ph,
                                                 z_mode=True,
                                                 gmm_mode=True,
                                                 all_z_sep=False,
                                                 full_dist=False)

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(10, 10))
                visualization.visualize_prediction(
                    ax,
                    predictions,
                    scene.dt,
                    max_hl=max_hl,
                    ph=ph,
                    map=scene.map['VISUALIZATION']
                    if scene.map is not None else None)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('train/prediction', fig, epoch)

                model_registrar.to(args.eval_device)
                # Predict random timestep to plot for eval data set
                if args.scene_freq_mult_viz:
                    scene = np.random.choice(eval_scenes,
                                             p=eval_scenes_sample_probs)
                else:
                    scene = np.random.choice(eval_scenes)
                timestep = scene.sample_timesteps(1, min_future_timesteps=ph)
                predictions = eval_trajectron.predict(scene,
                                                      timestep,
                                                      ph,
                                                      num_samples=20,
                                                      min_future_timesteps=ph,
                                                      z_mode=False,
                                                      full_dist=False)

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(10, 10))
                visualization.visualize_prediction(
                    ax,
                    predictions,
                    scene.dt,
                    max_hl=max_hl,
                    ph=ph,
                    map=scene.map['VISUALIZATION']
                    if scene.map is not None else None)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('eval/prediction', fig, epoch)

                # Predict random timestep to plot for eval data set
                predictions = eval_trajectron.predict(scene,
                                                      timestep,
                                                      ph,
                                                      min_future_timesteps=ph,
                                                      z_mode=True,
                                                      gmm_mode=True,
                                                      all_z_sep=True,
                                                      full_dist=False)

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(10, 10))
                visualization.visualize_prediction(
                    ax,
                    predictions,
                    scene.dt,
                    max_hl=max_hl,
                    ph=ph,
                    map=scene.map['VISUALIZATION']
                    if scene.map is not None else None)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('eval/prediction_all_z', fig, epoch)

        #################################
        #           EVALUATION          #
        #################################
        if args.eval_every is not None and not args.debug and epoch % args.eval_every == 0 and epoch > 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            model_registrar.to(args.eval_device)
            with torch.no_grad():
                # Calculate evaluation loss
                for node_type, data_loader in eval_data_loader.items():
                    eval_loss = []
                    print(
                        f"Starting Evaluation @ epoch {epoch} for node type: {node_type}"
                    )
                    pbar = tqdm(data_loader, ncols=80)
                    for batch in pbar:
                        eval_loss_node_type = eval_trajectron.eval_loss(
                            batch, node_type)
                        pbar.set_description(
                            f"Epoch {epoch}, {node_type} L: {eval_loss_node_type.item():.2f}"
                        )
                        eval_loss.append(
                            {node_type: {
                                'nll': [eval_loss_node_type]
                            }})
                        del batch

                    evaluation.log_batch_errors(eval_loss, log_writer,
                                                f"{node_type}/eval_loss",
                                                epoch)

                # Predict batch timesteps for evaluation dataset evaluation
                eval_batch_errors = []
                for scene in tqdm(eval_scenes,
                                  desc='Sample Evaluation',
                                  ncols=80):
                    timesteps = scene.sample_timesteps(args.eval_batch_size)

                    predictions = eval_trajectron.predict(
                        scene,
                        timesteps,
                        ph,
                        num_samples=50,
                        min_future_timesteps=ph,
                        full_dist=False)

                    eval_batch_errors.append(
                        evaluation.compute_batch_statistics(
                            predictions,
                            scene.dt,
                            max_hl=max_hl,
                            ph=ph,
                            node_type_enum=eval_env.NodeType,
                            map=scene.map))

                evaluation.log_batch_errors(eval_batch_errors,
                                            log_writer,
                                            'eval',
                                            epoch,
                                            bar_plot=['kde'],
                                            box_plot=['ade', 'fde'])

                # Predict maximum likelihood batch timesteps for evaluation dataset evaluation
                eval_batch_errors_ml = []
                for scene in tqdm(eval_scenes, desc='MM Evaluation', ncols=80):
                    timesteps = scene.sample_timesteps(scene.timesteps)

                    predictions = eval_trajectron.predict(
                        scene,
                        timesteps,
                        ph,
                        num_samples=1,
                        min_future_timesteps=ph,
                        z_mode=True,
                        gmm_mode=True,
                        full_dist=False)

                    eval_batch_errors_ml.append(
                        evaluation.compute_batch_statistics(
                            predictions,
                            scene.dt,
                            max_hl=max_hl,
                            ph=ph,
                            map=scene.map,
                            node_type_enum=eval_env.NodeType,
                            kde=False))

                evaluation.log_batch_errors(eval_batch_errors_ml, log_writer,
                                            'eval/ml', epoch)

        if args.save_every is not None and args.debug is False and epoch % args.save_every == 0:
            model_registrar.save_models(epoch)
Exemplo n.º 8
0
def main():
    """Train a SpatioTemporalGraphCVAEModel (STG), visualizing and evaluating it periodically.

    All settings come from the module-level ``args`` namespace and from the
    JSON config at ``args.conf``. Command-line values override config
    entries. The effective config, TensorBoard logs and model checkpoints go
    to a new timestamped directory under ``args.log_dir``.

    Evaluation data (``args.eval_data_dict``) is loaded whenever periodic
    evaluation (``args.eval_every``) or visualization (``args.vis_every``) is
    requested, because both use the evaluation scenes.

    Raises:
        FileNotFoundError: if ``args.conf`` does not exist.
        ValueError: if the config's ``learning_rate_style`` is not one of
            'const', 'exp' or 'triangle'.
    """
    # Load hyperparameters from json. Fail fast with the same exception type
    # open() would raise, instead of printing and crashing a line later.
    if not os.path.exists(args.conf):
        raise FileNotFoundError('Config json not found!')
    with open(args.conf, 'r') as conf_json:
        hyperparams = json.load(conf_json)

    # Add hyperparams from arguments
    hyperparams['dynamic_edges'] = args.dynamic_edges
    hyperparams['edge_state_combine_method'] = args.edge_state_combine_method
    hyperparams[
        'edge_influence_combine_method'] = args.edge_influence_combine_method
    hyperparams['edge_radius'] = args.edge_radius
    hyperparams['use_map_encoding'] = args.use_map_encoding
    hyperparams['edge_addition_filter'] = args.edge_addition_filter
    hyperparams['edge_removal_filter'] = args.edge_removal_filter
    hyperparams['batch_size'] = args.batch_size
    hyperparams['k_eval'] = args.k_eval
    hyperparams['offline_scene_graph'] = args.offline_scene_graph
    hyperparams['incl_robot_node'] = args.incl_robot_node

    print('-----------------------')
    print('| TRAINING PARAMETERS |')
    print('-----------------------')
    print('| iterations: %d' % args.num_iters)
    print('| batch_size: %d' % args.batch_size)
    print('| batch_multiplier: %d' % args.batch_multiplier)
    print('| effective batch size: %d (= %d * %d)' %
          (args.batch_size * args.batch_multiplier, args.batch_size,
           args.batch_multiplier))
    print('| device: %s' % args.device)
    print('| eval_device: %s' % args.eval_device)
    print('| Offline Scene Graph Calculation: %s' % args.offline_scene_graph)
    print('| edge_radius: %s' % args.edge_radius)
    print('| EE state_combine_method: %s' % args.edge_state_combine_method)
    print('| EIE scheme: %s' % args.edge_influence_combine_method)
    print('| dynamic_edges: %s' % args.dynamic_edges)
    print('| robot node: %s' % args.incl_robot_node)
    print('| map encoding: %s' % args.use_map_encoding)
    print('| edge_addition_filter: %s' % args.edge_addition_filter)
    print('| edge_removal_filter: %s' % args.edge_removal_filter)
    print('| MHL: %s' % hyperparams['minimum_history_length'])
    print('| PH: %s' % hyperparams['prediction_horizon'])
    print('-----------------------')

    # Create the log and model directory if they're not present.
    model_dir = os.path.join(
        args.log_dir, 'models_' +
        time.strftime('%d_%b_%Y_%H_%M_%S', time.localtime()) + args.log_tag)
    pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

    # Save config to model directory
    with open(os.path.join(model_dir, 'config.json'), 'w') as conf_json:
        json.dump(hyperparams, conf_json)

    log_writer = SummaryWriter(log_dir=model_dir)

    # NOTE(review): pickle.load executes arbitrary code; only load data
    # files from trusted sources.
    train_data_path = os.path.join(args.data_dir, args.train_data_dict)
    with open(train_data_path, 'rb') as f:
        train_env = pickle.load(f, encoding='latin1')
    train_scenes = train_env.scenes
    print('Loaded training data from %s' % (train_data_path, ))

    # The evaluation environment is used by periodic evaluation AND by the
    # eval-set visualizations, so load it when either one is enabled.
    needs_eval_data = args.eval_every is not None or args.vis_every is not None
    eval_env = None
    eval_scenes = []
    if needs_eval_data:
        eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
        with open(eval_data_path, 'rb') as f:
            eval_env = pickle.load(f, encoding='latin1')
        eval_scenes = eval_env.scenes
        print('Loaded evaluation data from %s' % (eval_data_path, ))

    # Calculate Scene Graph
    if hyperparams['offline_scene_graph'] == 'yes':
        print(f"Offline calculating scene graphs")
        for i, scene in enumerate(train_scenes):
            scene.calculate_scene_graph(train_env.attention_radius,
                                        hyperparams['state'],
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
            print(f"Created Scene Graph for Scene {i}")

        for i, scene in enumerate(eval_scenes):
            scene.calculate_scene_graph(eval_env.attention_radius,
                                        hyperparams['state'],
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
            print(f"Created Scene Graph for Scene {i}")

    model_registrar = ModelRegistrar(model_dir, args.device)

    # We use pre trained weights for the map CNN
    if args.use_map_encoding:
        inf_encoder_registrar = os.path.join(
            args.log_dir, 'weight_trans/model_registrar-1499.pt')
        model_dict = torch.load(inf_encoder_registrar,
                                map_location=args.device)

        # Copy only the map-encoder sub-models into this registrar.
        for key in model_dict.keys():
            if 'map_encoder' in key:
                model_registrar.model_dict[key] = model_dict[key]
                assert model_registrar.get_model(key) is model_dict[key]

    stg = SpatioTemporalGraphCVAEModel(model_registrar, hyperparams,
                                       log_writer, args.device)
    stg.set_scene_graph(train_env)
    stg.set_annealing_params()
    print('Created training STG model.')

    # The eval model shares the same registrar (and thus weights) as `stg`;
    # only its scene graph differs.
    eval_stg = None
    if needs_eval_data:
        eval_stg = SpatioTemporalGraphCVAEModel(model_registrar, hyperparams,
                                                log_writer, args.device)
        eval_stg.set_scene_graph(eval_env)
        eval_stg.set_annealing_params()  # TODO Check if necessary

    lr_style = hyperparams['learning_rate_style']
    if lr_style == 'const':
        optimizer = optim.Adam(model_registrar.parameters(),
                               lr=hyperparams['learning_rate'])
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=1.0)
    elif lr_style == 'exp':
        optimizer = optim.Adam(model_registrar.parameters(),
                               lr=hyperparams['learning_rate'])
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=hyperparams['learning_decay_rate'])
    elif lr_style == 'triangle':
        # Base lr of 1.0 so the LambdaLR factor from cyclical_lr is the
        # actual learning rate.
        optimizer = optim.Adam(model_registrar.parameters(), lr=1.0)
        clr = cyclical_lr(100,
                          min_lr=hyperparams['min_learning_rate'],
                          max_lr=hyperparams['learning_rate'],
                          decay=hyperparams['learning_decay_rate'])
        lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, [clr])
    else:
        raise ValueError(f"Unknown learning_rate_style: {lr_style!r}")

    # Number of scenes sampled per iteration; the loss is averaged over
    # these scenes times the batch multiplier.
    scenes_per_iter = 10

    print_training_header(newline_start=True)
    for curr_iter in range(args.num_iters):
        # Necessary because we flip the weights contained between GPU and CPU sometimes.
        model_registrar.to(args.device)

        # Setting the current iterator value for internal logging.
        stg.set_curr_iter(curr_iter)
        if args.vis_every is not None:
            eval_stg.set_curr_iter(curr_iter)

        # Stepping forward the learning rate scheduler and annealers.
        lr_scheduler.step()
        log_writer.add_scalar('train/learning_rate',
                              lr_scheduler.get_lr()[0], curr_iter)
        stg.step_annealers()

        # Zeroing gradients for the upcoming iteration.
        optimizer.zero_grad()
        train_losses = dict()
        for node_type in train_env.NodeType:
            train_losses[node_type] = []
        for scene in np.random.choice(train_scenes, scenes_per_iter):
            for mb_num in range(args.batch_multiplier):
                # Obtaining the batch's training loss.
                timesteps = scene.sample_timesteps(hyperparams['batch_size'])

                # Compute the training loss.
                train_loss_by_type = stg.train_loss(
                    scene, timesteps, max_nodes=hyperparams['batch_size'])
                for node_type, train_loss in train_loss_by_type.items():
                    if train_loss is not None:
                        # Gradients accumulate across scenes/minibatches,
                        # so scale each loss to get their average.
                        train_loss = train_loss / (args.batch_multiplier *
                                                   scenes_per_iter)
                        train_losses[node_type].append(train_loss.item())

                        # Calculating gradients.
                        train_loss.backward()

        # Print training information. Also, no newline here. It's added in at a later line.
        print('{:9} | '.format(curr_iter), end='', flush=True)
        for node_type in train_env.NodeType:
            print('{}:{:10} | '.format(node_type.name[0],
                                       '%.2f' % sum(train_losses[node_type])),
                  end='',
                  flush=True)

        for node_type in train_env.NodeType:
            if len(train_losses[node_type]) > 0:
                log_writer.add_histogram(
                    f"{node_type.name}/train/minibatch_losses",
                    np.asarray(train_losses[node_type]), curr_iter)
                log_writer.add_scalar(f"{node_type.name}/train/loss",
                                      sum(train_losses[node_type]), curr_iter)

        # Clipping gradients.
        if hyperparams['grad_clip'] is not None:
            nn.utils.clip_grad_value_(model_registrar.parameters(),
                                      hyperparams['grad_clip'])

        # Performing a gradient step.
        optimizer.step()

        # (The former `del train_loss` was removed: backward() already frees
        # the graph, and the name is unbound when no node type had a loss.)

        if args.vis_every is not None and (curr_iter +
                                           1) % args.vis_every == 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            with torch.no_grad():
                # Predict random timestep to plot for train data set
                scene = np.random.choice(train_scenes)
                timestep = scene.sample_timesteps(1, min_future_timesteps=ph)
                predictions = stg.predict(scene,
                                          timestep,
                                          ph,
                                          num_samples_z=100,
                                          most_likely_z=False,
                                          all_z=False)

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(5, 5))
                visualization.visualize_prediction(ax,
                                                   predictions,
                                                   scene.dt,
                                                   max_hl=max_hl,
                                                   ph=ph)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('train/prediction', fig, curr_iter)

                # Predict random timestep to plot for eval data set
                scene = np.random.choice(eval_scenes)
                timestep = scene.sample_timesteps(1, min_future_timesteps=ph)
                predictions = eval_stg.predict(scene,
                                               timestep,
                                               ph,
                                               num_samples_z=100,
                                               most_likely_z=False,
                                               all_z=False,
                                               max_nodes=4 *
                                               args.eval_batch_size)

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(5, 5))
                visualization.visualize_prediction(ax,
                                                   predictions,
                                                   scene.dt,
                                                   max_hl=max_hl,
                                                   ph=ph)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('eval/prediction', fig, curr_iter)

                # Plot predicted timestep for random scene in map
                fig, ax = plt.subplots(figsize=(15, 15))
                visualization.visualize_prediction(ax,
                                                   predictions,
                                                   scene.dt,
                                                   max_hl=max_hl,
                                                   ph=ph,
                                                   map=scene.map['PLOT'])
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('eval/prediction_map', fig, curr_iter)

                # Predict random timestep to plot for eval data set
                predictions = eval_stg.predict(scene,
                                               timestep,
                                               ph,
                                               num_samples_gmm=50,
                                               most_likely_z=False,
                                               all_z=True,
                                               max_nodes=4 *
                                               args.eval_batch_size)

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(5, 5))
                visualization.visualize_prediction(ax,
                                                   predictions,
                                                   scene.dt,
                                                   max_hl=max_hl,
                                                   ph=ph)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('eval/prediction_all_z', fig, curr_iter)

        if args.eval_every is not None and (curr_iter +
                                            1) % args.eval_every == 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            with torch.no_grad():
                # Predict batch timesteps for training dataset evaluation
                train_batch_errors = []
                max_scenes = np.min([len(train_scenes), 5])
                for scene in np.random.choice(train_scenes, max_scenes):
                    timesteps = scene.sample_timesteps(args.eval_batch_size)
                    predictions = stg.predict(scene,
                                              timesteps,
                                              ph,
                                              num_samples_z=100,
                                              min_future_timesteps=ph,
                                              max_nodes=4 *
                                              args.eval_batch_size)

                    train_batch_errors.append(
                        evaluation.compute_batch_statistics(
                            predictions,
                            scene.dt,
                            max_hl=max_hl,
                            ph=ph,
                            node_type_enum=train_env.NodeType,
                            map=scene.map))

                evaluation.log_batch_errors(train_batch_errors,
                                            log_writer,
                                            'train',
                                            curr_iter,
                                            bar_plot=['kde'],
                                            box_plot=['ade', 'fde'])

                # Predict batch timesteps for evaluation dataset evaluation
                eval_batch_errors = []
                for scene in eval_scenes:
                    timesteps = scene.sample_timesteps(args.eval_batch_size)

                    predictions = eval_stg.predict(scene,
                                                   timesteps,
                                                   ph,
                                                   num_samples_z=100,
                                                   min_future_timesteps=ph,
                                                   max_nodes=4 *
                                                   args.eval_batch_size)

                    eval_batch_errors.append(
                        evaluation.compute_batch_statistics(
                            predictions,
                            scene.dt,
                            max_hl=max_hl,
                            ph=ph,
                            node_type_enum=eval_env.NodeType,
                            map=scene.map))

                evaluation.log_batch_errors(eval_batch_errors,
                                            log_writer,
                                            'eval',
                                            curr_iter,
                                            bar_plot=['kde'],
                                            box_plot=['ade', 'fde'])

                # Predict maximum likelihood batch timesteps for evaluation dataset evaluation
                eval_batch_errors_ml = []
                for scene in eval_scenes:
                    timesteps = scene.sample_timesteps(scene.timesteps)

                    predictions = eval_stg.predict(scene,
                                                   timesteps,
                                                   ph,
                                                   num_samples_z=1,
                                                   min_future_timesteps=ph,
                                                   most_likely_z=True,
                                                   most_likely_gmm=True)

                    eval_batch_errors_ml.append(
                        evaluation.compute_batch_statistics(
                            predictions,
                            scene.dt,
                            max_hl=max_hl,
                            ph=ph,
                            map=scene.map,
                            node_type_enum=eval_env.NodeType,
                            kde=False))

                evaluation.log_batch_errors(eval_batch_errors_ml, log_writer,
                                            'eval/ml', curr_iter)

                eval_loss = []
                max_scenes = np.min([len(eval_scenes), 25])
                for scene in np.random.choice(eval_scenes, max_scenes):
                    # Sample timesteps from THIS scene; previously the stale
                    # `timesteps` of the last scene above was reused here.
                    scene_timesteps = scene.sample_timesteps(scene.timesteps)
                    eval_loss.append(
                        eval_stg.eval_loss(scene, scene_timesteps))

                evaluation.log_batch_errors(eval_loss, log_writer, 'eval/loss',
                                            curr_iter)

        else:
            print('{:15} | {:10} | {:14}'.format('', '', ''),
                  end='',
                  flush=True)

        # Here's the newline that ends the current training information printing.
        print('')

        if args.save_every is not None and (curr_iter +
                                            1) % args.save_every == 0:
            model_registrar.save_models(curr_iter)
            print_training_header()
Exemplo n.º 9
0
def main():
    """Replay one test trajectory through the online STG model and profile it.

    The model weights are loaded from ``args.trained_model_dir`` at
    iteration 1999. The function uses sample index 11 of the test data dict
    and performs one incremental forward pass per timestep. Each pass is
    timed, and node/edge counts and RAM usage are recorded. Per-step plots
    are optional (``args.plot_online == 'yes'``). Summary performance plots
    are written to ``pred_figs/<dynamic_edges>_dyn_edges``.

    Relies on module-level ``args`` and ``hyperparams``; ``hyperparams`` is
    mutated in place.
    """
    out_dir = 'pred_figs/%s_dyn_edges' % args.dynamic_edges
    pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)

    # NOTE(review): pickle.load runs arbitrary code; only use trusted files.
    with open(os.path.join(args.data_dir, args.test_data_dict), 'rb') as f:
        test_data_dict = pickle.load(f, encoding='latin1')

    # Getting the natural time delta for this dataset.
    eval_dt = test_data_dict['dt']

    # Turn the numpy standardization statistics (nodes first, then labels)
    # into float tensors on the target device, in place.
    for table_name in ('nodes_standardization', 'labels_standardization'):
        table = test_data_dict[table_name]
        for node in table:
            per_node = table[node]
            for stat_key in per_node:
                per_node[stat_key] = torch.from_numpy(
                    per_node[stat_key]).float().to(args.device)

    # robot_node = stg_node.STGNode('Al Horford', 'HomeC')
    # max_speed = 40.76

    robot_node = stg_node.STGNode('0', 'Pedestrian')
    max_speed = 12.422222

    # Initial memory usage
    print('%.2f MBs of RAM initially used.' % (memInUse() * 1000.))

    # Loading weights from the trained model.
    model_registrar = ModelRegistrar(args.trained_model_dir, args.device)
    model_registrar.load_models(1999)

    hyperparams.update({
        'state_dim': test_data_dict['input_dict'][robot_node].shape[2],
        'pred_dim': len(test_data_dict['pred_indices']),
        'pred_indices': test_data_dict['pred_indices'],
        'dynamic_edges': args.dynamic_edges,
        'edge_state_combine_method': args.edge_state_combine_method,
        'edge_influence_combine_method': args.edge_influence_combine_method,
        'nodes_standardization': test_data_dict['nodes_standardization'],
        'labels_standardization': test_data_dict['labels_standardization'],
        'edge_radius': args.edge_radius,
    })

    kwargs_dict = {
        'dynamic_edges': hyperparams['dynamic_edges'],
        'edge_state_combine_method': hyperparams['edge_state_combine_method'],
        'edge_influence_combine_method':
            hyperparams['edge_influence_combine_method'],
        'edge_addition_filter': args.edge_addition_filter,
        'edge_removal_filter': args.edge_removal_filter,
    }

    online_stg = OnlineSpatioTemporalGraphCVAEModel(robot_node,
                                                    model_registrar,
                                                    hyperparams, kwargs_dict,
                                                    args.device)

    # Seed the scene with every STG node's first (x, y) position.
    data_id = 11
    init_scene_dict = {
        str(node): traj_data[data_id, 0, :2]
        for node, traj_data in test_data_dict['input_dict'].items()
        if isinstance(node, stg_node.STGNode)
    }

    online_stg.set_scene_graph(
        Scene(init_scene_dict).get_graph(args.edge_radius))

    graph = online_stg.scene_graph
    perf_dict = {'time': [0],
                 'runtime': [np.nan],
                 'frequency': [np.nan],
                 'nodes': [len(graph.active_nodes)],
                 'edges': [graph.num_edges],
                 'mem_MB': [memInUse() * 1000.]}
    print("At t=0, have %d nodes, %d edges which uses %.2f MBs of RAM." %
          (perf_dict['nodes'][0], perf_dict['edges'][0],
           perf_dict['mem_MB'][0]))

    ph = args.prediction_horizon
    last_step = test_data_dict['input_dict']['traj_lengths'][data_id] - ph
    for t in range(1, last_step + 1):
        robot_future = get_robot_future(robot_node, t, data_id,
                                        test_data_dict, ph)
        new_pos_dict, new_inputs_dict = get_inputs_dict(t, data_id,
                                                        test_data_dict)

        t0 = time.time()
        preds_dict = online_stg.incremental_forward(
            robot_future, new_pos_dict, new_inputs_dict, ph,
            int(args.num_samples / 2))
        t1 = time.time()

        if args.plot_online == 'yes':
            plot_utils.plot_online_prediction(
                preds_dict, new_inputs_dict, online_stg, t, robot_future,
                dt=eval_dt, max_speed=max_speed,
                ylim=(2, 9), xlim=(-6, 17),
                dpi=150, figsize=(2.2 * 4, 4),
                edge_line_width=0.1, line_width=0.3,
                omit_names=True,
                save_at=os.path.join(out_dir, 'online_pred_%d.png' % t))

        graph = online_stg.scene_graph
        perf_dict['time'].append(t)
        perf_dict['runtime'].append(t1 - t0)
        perf_dict['frequency'].append(1. / (t1 - t0))
        perf_dict['nodes'].append(len(graph.active_nodes))
        perf_dict['edges'].append(graph.num_edges)
        perf_dict['mem_MB'].append(memInUse() * 1000.)
        print("t=%d: took %.2f s (= %.2f Hz) and %d nodes, %d edges uses %.2f MBs of RAM." % (
            perf_dict['time'][-1], perf_dict['runtime'][-1],
            perf_dict['frequency'][-1], perf_dict['nodes'][-1],
            perf_dict['edges'][-1], perf_dict['mem_MB'][-1]))

    plot_utils.plot_performance_metrics(perf_dict, out_dir)
Exemplo n.º 10
0
def main():
    """Replay one evaluation scene through OnlineTrajectron, one timestep at a time.

    Steps:
    - Load the config stored with a (hard-coded) trained model under
      ``args.log_dir`` and override entries from command-line ``args``.
    - Build an online environment from scene 0 of the evaluation data.
    - Feed the scene to the model incrementally.
    - At each timestep, print timing, print evaluation statistics against
      ground truth, and save a plot to ``<model_dir>/pred_figs/pred_<t>.pdf``.

    Raises:
        ValueError: if the config json is missing from the model directory.
    """
    model_dir = os.path.join(args.log_dir,
                             'models_14_Jan_2020_00_24_21eth_no_rob')

    # Load hyperparameters from json
    config_file = os.path.join(model_dir, args.conf)
    if not os.path.exists(config_file):
        raise ValueError('Config json not found!')
    with open(config_file, 'r') as conf_json:
        hyperparams = json.load(conf_json)

    # Add hyperparams from arguments
    hyperparams['dynamic_edges'] = args.dynamic_edges
    hyperparams['edge_state_combine_method'] = args.edge_state_combine_method
    hyperparams[
        'edge_influence_combine_method'] = args.edge_influence_combine_method
    hyperparams['edge_addition_filter'] = args.edge_addition_filter
    hyperparams['edge_removal_filter'] = args.edge_removal_filter
    hyperparams['batch_size'] = args.batch_size
    hyperparams['k_eval'] = args.k_eval
    hyperparams['offline_scene_graph'] = args.offline_scene_graph
    hyperparams['incl_robot_node'] = args.incl_robot_node
    hyperparams['scene_batch_size'] = args.scene_batch_size
    hyperparams['node_resample_train'] = args.node_resample_train
    hyperparams['node_resample_eval'] = args.node_resample_eval
    hyperparams['scene_resample_train'] = args.scene_resample_train
    hyperparams['scene_resample_eval'] = args.scene_resample_eval
    hyperparams['scene_resample_viz'] = args.scene_resample_viz
    hyperparams['edge_encoding'] = not args.no_edge_encoding

    # Single source of truth for horizon lengths, shared by inference,
    # evaluation and plotting below.
    max_hl = hyperparams['maximum_history_length']
    ph = hyperparams['prediction_horizon']

    output_save_dir = os.path.join(model_dir, 'pred_figs')
    pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True)

    # NOTE(review): pickle.load runs arbitrary code; only use trusted files.
    eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
    with open(eval_data_path, 'rb') as f:
        eval_env = pickle.load(f, encoding='latin1')

    if eval_env.robot_type is None and hyperparams['incl_robot_node']:
        eval_env.robot_type = eval_env.NodeType[
            0]  # TODO: Make more general, allow the user to specify?
        for scene in eval_env.scenes:
            scene.add_robot_from_nodes(eval_env.robot_type)

    print('Loaded evaluation data from %s' % (eval_data_path, ))

    # Creating a dummy environment with a single scene that contains information about the world.
    # When using this code, feel free to use whichever scene index or initial timestep you wish.
    scene_idx = 0

    # You need to have at least acceleration, so you want 2 timesteps of prior data, e.g. [0, 1],
    # so that you can immediately start incremental inference from the 3rd timestep onwards.
    init_timestep = 1

    eval_scene = eval_env.scenes[scene_idx]
    online_env = create_online_env(eval_env, hyperparams, scene_idx,
                                   init_timestep)

    model_registrar = ModelRegistrar(model_dir, args.eval_device)
    model_registrar.load_models(iter_num=1999)

    trajectron = OnlineTrajectron(model_registrar, hyperparams,
                                  args.eval_device)

    # If you want to see what different robot futures do to the predictions, uncomment this line as well as
    # related "... += adjustment" lines below.
    # adjustment = np.stack([np.arange(13)/float(i*2.0) for i in range(6, 12)], axis=1)

    # Here's how you'd incrementally run the model, e.g. with streaming data.
    trajectron.set_environment(online_env, init_timestep)

    for timestep in range(init_timestep + 1, eval_scene.timesteps):
        pos_dict = eval_scene.get_clipped_pos_dict(timestep,
                                                   hyperparams['state'])

        robot_present_and_future = None
        if eval_scene.robot is not None:
            robot_present_and_future = eval_scene.robot.get(
                np.array([timestep, timestep + ph]),
                hyperparams['state'][eval_scene.robot.type],
                padding=0.0)
            # Duplicate the robot trajectory along a new leading axis (the
            # function stacks two copies).
            robot_present_and_future = np.stack(
                [robot_present_and_future, robot_present_and_future], axis=0)
            # robot_present_and_future += adjustment

        start = time.time()
        # Use the configured horizon (previously hard-coded to 12, which
        # disagreed with the ph used for evaluation and plotting).
        preds = trajectron.incremental_forward(
            pos_dict,
            prediction_horizon=ph,
            num_samples=25,
            robot_present_and_future=robot_present_and_future)
        end = time.time()
        elapsed = end - start
        # Coarse clocks can report 0 s; avoid ZeroDivisionError.
        frequency = 1. / elapsed if elapsed > 0 else float('inf')
        print("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" %
              (timestep, elapsed, frequency, len(
                  trajectron.nodes), trajectron.scene_graph.get_num_edges()))

        # Keep only predictions for nodes that exist in the ground-truth
        # scene so they can be evaluated.
        detailed_preds_dict = {
            node: preds[node]
            for node in eval_scene.nodes if node in preds
        }

        batch_stats = evaluation.compute_batch_statistics(
            {timestep: detailed_preds_dict},
            eval_scene.dt,
            max_hl=max_hl,
            ph=ph,
            node_type_enum=online_env.NodeType,
            prune_ph_to_future=True)

        evaluation.print_batch_errors([batch_stats], 'eval', timestep)

        fig, ax = plt.subplots()
        vis.visualize_prediction(ax, {timestep: preds}, eval_scene.dt,
                                 max_hl, ph)

        if eval_scene.robot is not None:
            robot_for_plotting = eval_scene.robot.get(
                np.array([timestep, timestep + ph]),
                hyperparams['state'][eval_scene.robot.type])
            # robot_for_plotting += adjustment

            # Plotted as (column 1, column 0) on (x, y); matches the
            # circle below.
            ax.plot(robot_for_plotting[1:, 1],
                    robot_for_plotting[1:, 0],
                    color='r',
                    linewidth=1.0,
                    alpha=1.0)

            # Current Node Position
            circle = plt.Circle(
                (robot_for_plotting[0, 1], robot_for_plotting[0, 0]),
                0.3,
                facecolor='r',
                edgecolor='k',
                lw=0.5,
                zorder=3)
            ax.add_artist(circle)

        fig.savefig(os.path.join(output_save_dir, f'pred_{timestep}.pdf'),
                    dpi=300)
        plt.close(fig)
Exemplo n.º 11
0
def main():
    """Run the online STG model step-by-step over one test sequence.

    Loads a pickled test data dict (picked at random from the five splits
    listed below when ``args.test_data_dict`` is unset), restores trained
    weights into an ``OnlineSpatioTemporalGraphCVAEModel`` and feeds it one
    randomly chosen sequence one timestep at a time. Per step it records
    runtime, frequency, active node/edge counts, memory use relative to the
    start, and MSE/FSE. Optionally, it saves a prediction plot per step. At
    the end it plots the collected performance metrics.

    Configuration comes from the module-level ``args``. The module-level
    ``hyperparams`` dict is updated in place.

    Raises:
        ValueError: if the test ``input_dict`` contains no ``STGNode`` key.
    """
    output_save_dir = 'pred_figs/%s_%s_dyn_edges' % (
        'full' if args.full_preds else 'z_best', args.dynamic_edges)
    pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True)

    if args.test_data_dict is None:
        args.test_data_dict = random.choice(
            ['eth', 'hotel', 'univ', 'zara1', 'zara2']) + '_test.pkl'

    # NOTE: pickle can execute arbitrary code; only load trusted data files.
    with open(os.path.join(args.data_dir, args.test_data_dict), 'rb') as f:
        test_data_dict = pickle.load(f, encoding='latin1')

    # Getting the natural time delta for this dataset.
    eval_dt = test_data_dict['dt']

    # Convert the standardization statistics to float tensors on args.device.
    for node in test_data_dict['nodes_standardization']:
        for key in test_data_dict['nodes_standardization'][node]:
            test_data_dict['nodes_standardization'][node][
                key] = torch.from_numpy(test_data_dict['nodes_standardization']
                                        [node][key]).float().to(args.device)

    for node in test_data_dict['labels_standardization']:
        for key in test_data_dict['labels_standardization'][node]:
            test_data_dict['labels_standardization'][node][
                key] = torch.from_numpy(
                    test_data_dict['labels_standardization'][node]
                    [key]).float().to(args.device)

    # Speed limit handed to the error and plotting helpers below
    # (units are not visible here).
    max_speed = 12.422222
    if args.incl_robot_node:
        robot_node = stg_node.STGNode('0', 'Pedestrian')
    else:
        robot_node = None

    # Initial memory usage
    init_mem_usage = memInUse() * 1000.
    print('%.2f MBs of RAM initially used.' % init_mem_usage)

    # Loading weights from the trained model.
    dataset_name = args.test_data_dict.split("_")[0]
    if args.trained_model_dir is None:
        args.trained_model_dir = os.path.join(
            '../sgan-dataset/logs', dataset_name,
            eval_utils.get_our_model_dir(dataset_name))

    # Fill any unset model settings from eval_utils.get_model_hyperparams.
    if args.trained_model_iter is None:
        args.trained_model_iter = eval_utils.get_model_hyperparams(
            args, dataset_name)['best_iter']

    if args.edge_state_combine_method is None:
        args.edge_state_combine_method = eval_utils.get_model_hyperparams(
            args, dataset_name)['edge_state_combine_method']

    if args.edge_influence_combine_method is None:
        args.edge_influence_combine_method = eval_utils.get_model_hyperparams(
            args, dataset_name)['edge_influence_combine_method']

    model_registrar = ModelRegistrar(args.trained_model_dir, args.device)
    model_registrar.load_models(args.trained_model_iter)

    # Any STGNode key will do: it is only used to read array shapes below.
    random_node = next((key for key in test_data_dict['input_dict']
                        if isinstance(key, stg_node.STGNode)), None)
    if random_node is None:
        raise ValueError('No STGNode entries in the input_dict of %s.' %
                         args.test_data_dict)

    hyperparams['state_dim'] = test_data_dict['input_dict'][random_node].shape[
        2]
    hyperparams['pred_dim'] = len(test_data_dict['pred_indices'])
    hyperparams['pred_indices'] = test_data_dict['pred_indices']
    hyperparams['dynamic_edges'] = args.dynamic_edges
    hyperparams['edge_state_combine_method'] = args.edge_state_combine_method
    hyperparams[
        'edge_influence_combine_method'] = args.edge_influence_combine_method
    hyperparams['nodes_standardization'] = test_data_dict[
        'nodes_standardization']
    hyperparams['labels_standardization'] = test_data_dict[
        'labels_standardization']
    hyperparams['edge_radius'] = args.edge_radius

    kwargs_dict = {
        'dynamic_edges':
        hyperparams['dynamic_edges'],
        'edge_state_combine_method':
        hyperparams['edge_state_combine_method'],
        'edge_influence_combine_method':
        hyperparams['edge_influence_combine_method'],
        'edge_addition_filter':
        args.edge_addition_filter,
        'edge_removal_filter':
        args.edge_removal_filter
    }

    online_stg = OnlineSpatioTemporalGraphCVAEModel(robot_node,
                                                    model_registrar,
                                                    hyperparams, kwargs_dict,
                                                    args.device)

    data_id = random.randint(
        0, test_data_dict['input_dict'][random_node].shape[0] - 1)

    print('Looking at the %s sequence, data_id %d' % (dataset_name, data_id))

    # Initial scene: the first two features at timestep 0 of every node.
    init_scene_dict = dict()
    for node, traj_data in test_data_dict['input_dict'].items():
        if isinstance(node, stg_node.STGNode):
            init_scene_dict[str(node)] = traj_data[data_id, 0, :2]

    init_scene_graph = Scene(init_scene_dict).get_graph(args.edge_radius)
    online_stg.set_scene_graph(init_scene_graph)

    perf_dict = {
        'time': [0],
        'runtime': [np.nan],
        'frequency': [np.nan],
        'nodes': [len(online_stg.scene_graph.active_nodes)],
        'edges': [online_stg.scene_graph.num_edges],
        'mem_MB': [memInUse() * 1000. - init_mem_usage],
        'mse': [np.nan],
        'fse': [np.nan]
    }
    print(
        "At t=0, have %d nodes, %d edges which uses %.2f MBs of RAM." %
        (perf_dict['nodes'][0], perf_dict['edges'][0], perf_dict['mem_MB'][0]))

    # Keeps colors constant throughout the visualization.
    color_dict = defaultdict(dict)
    error_info_dict = {'output_limit': max_speed}
    start_idx = 1
    end_idx = test_data_dict['input_dict']['traj_lengths'][
        data_id] - args.prediction_horizon + 1
    # Pre-bind so the post-loop check below is well-defined even when the
    # trajectory is too short and the range is empty.
    curr_timestep = start_idx
    for curr_timestep in range(start_idx, end_idx):
        robot_future = get_robot_future(robot_node, curr_timestep, data_id,
                                        test_data_dict,
                                        args.prediction_horizon)

        new_pos_dict, new_inputs_dict = get_inputs_dict(
            curr_timestep, data_id, test_data_dict)

        start = time.time()
        preds_dict = online_stg.incremental_forward(
            robot_future,
            new_pos_dict,
            new_inputs_dict,
            args.prediction_horizon,
            int(args.num_samples),
            most_likely=(not args.full_preds))
        end = time.time()

        mse_errs, fse_errs = eval_utils.compute_preds_dict_only_agg_errors(
            preds_dict, test_data_dict, data_id, curr_timestep,
            args.prediction_horizon, error_info_dict)

        if mse_errs is None and fse_errs is None:
            print('No agents in the scene, stopping!')
            break

        if args.plot_online == 'yes':
            plot_utils.plot_online_prediction(
                preds_dict,
                test_data_dict,
                data_id,
                args.prediction_horizon,
                new_inputs_dict,
                online_stg,
                curr_timestep,
                robot_future,
                dt=eval_dt,
                max_speed=max_speed,
                color_dict=color_dict,
                ylim=(0, 20),
                xlim=(0, 20),
                dpi=150,
                figsize=(4, 4),
                edge_line_width=0.1,
                line_width=0.5,
                omit_names=True,
                save_at=os.path.join(output_save_dir,
                                     'online_pred_%d.png' % curr_timestep))

        perf_dict['time'].append(curr_timestep)
        perf_dict['runtime'].append(end - start)
        perf_dict['frequency'].append(1. / (end - start))
        perf_dict['nodes'].append(len(online_stg.scene_graph.active_nodes))
        perf_dict['edges'].append(online_stg.scene_graph.num_edges)
        perf_dict['mem_MB'].append(memInUse() * 1000. - init_mem_usage)
        perf_dict['mse'].append(mse_errs)
        perf_dict['fse'].append(fse_errs)
        print(
            "t=%d: took %.2f s (= %.2f Hz) w/ MSE %.2f and FSE %.2f and %d nodes, %d edges uses %.2f MBs of RAM."
            % (perf_dict['time'][-1], perf_dict['runtime'][-1],
               perf_dict['frequency'][-1], torch.mean(perf_dict['mse'][-1]),
               torch.mean(perf_dict['fse'][-1]), perf_dict['nodes'][-1],
               perf_dict['edges'][-1], perf_dict['mem_MB'][-1]))

    # Only plot when the loop got past its first timestep.
    if curr_timestep != start_idx:
        plot_utils.plot_performance_metrics(
            perf_dict, output_save_dir, hyperparams['minimum_history_length'],
            hyperparams['prediction_horizon'])
Exemplo n.º 12
0
def main():
    """Compare inference runtimes of SGAN and our STG model per scene.

    For each of the eth/hotel/univ/zara1/zara2 test splits, this function:
    loads the pretrained SGAN generator and our trained model; picks up to
    ``args.num_runs`` random scenes; times ``args.num_samples`` SGAN
    generator passes on each scene; and times one most-likely and one full
    forward pass of our model on each scene. Runtimes and agent counts are
    accumulated in ``results_dict``, which is written to a CSV after each
    dataset.

    Configuration comes from the module-level ``args``. The module-level
    ``hyperparams`` dict is updated in place.

    Raises:
        ValueError: if an evaluation ``input_dict`` has no ``STGNode`` key.
    """
    results_dict = {
        'data_precondition': list(),
        'dataset': list(),
        'method': list(),
        'runtime': list(),
        'num_samples': list(),
        'num_agents': list()
    }
    data_precondition = 'curr'
    for dataset_name in ['eth', 'hotel', 'univ', 'zara1', 'zara2']:
        print('At %s dataset' % dataset_name)

        ### SGAN LOADING ###
        sgan_model_path = os.path.join(
            args.sgan_models_path, '_'.join([dataset_name, '12', 'model.pt']))

        checkpoint = torch.load(sgan_model_path, map_location='cpu')
        generator = eval_utils.get_generator(checkpoint)
        _args = AttrDict(checkpoint['args'])
        path = get_dset_path(_args.dataset_name, args.sgan_dset_type)
        print('Evaluating', sgan_model_path, 'on', _args.dataset_name,
              args.sgan_dset_type)

        _, sgan_data_loader = data_loader(_args, path)

        ### OUR METHOD LOADING ###
        data_dir = '../sgan-dataset/data'
        eval_data_dict_name = '%s_test.pkl' % dataset_name
        log_dir = '../sgan-dataset/logs/%s' % dataset_name

        trained_model_dir = os.path.join(
            log_dir, eval_utils.get_our_model_dir(dataset_name))
        eval_data_path = os.path.join(data_dir, eval_data_dict_name)
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(eval_data_path, 'rb') as f:
            eval_data_dict = pickle.load(f, encoding='latin1')
        eval_dt = eval_data_dict['dt']
        print('Loaded evaluation data from %s, eval_dt = %.2f' %
              (eval_data_path, eval_dt))

        # Loading weights from the trained model.
        specific_hyperparams = eval_utils.get_model_hyperparams(
            args, dataset_name)
        model_registrar = ModelRegistrar(trained_model_dir, args.device)
        model_registrar.load_models(specific_hyperparams['best_iter'])

        # Any STGNode key will do: it is only used to read the state size.
        random_node = next((key for key in eval_data_dict['input_dict']
                            if isinstance(key, STGNode)), None)
        if random_node is None:
            raise ValueError('No STGNode entries in the input_dict of %s.' %
                             eval_data_path)

        hyperparams['state_dim'] = eval_data_dict['input_dict'][
            random_node].shape[2]
        hyperparams['pred_dim'] = len(eval_data_dict['pred_indices'])
        hyperparams['pred_indices'] = eval_data_dict['pred_indices']
        hyperparams['dynamic_edges'] = args.dynamic_edges
        hyperparams['edge_state_combine_method'] = specific_hyperparams[
            'edge_state_combine_method']
        hyperparams['edge_influence_combine_method'] = specific_hyperparams[
            'edge_influence_combine_method']
        hyperparams['nodes_standardization'] = eval_data_dict[
            'nodes_standardization']
        hyperparams['labels_standardization'] = eval_data_dict[
            'labels_standardization']
        hyperparams['edge_radius'] = args.edge_radius

        eval_hyperparams = copy.deepcopy(hyperparams)
        # Re-point at eval_data_dict's own standardization dicts, so the
        # in-place tensor conversion further below is visible to eval_stg.
        eval_hyperparams['nodes_standardization'] = eval_data_dict[
            "nodes_standardization"]
        eval_hyperparams['labels_standardization'] = eval_data_dict[
            "labels_standardization"]

        kwargs_dict = {
            'dynamic_edges':
            hyperparams['dynamic_edges'],
            'edge_state_combine_method':
            hyperparams['edge_state_combine_method'],
            'edge_influence_combine_method':
            hyperparams['edge_influence_combine_method'],
            'edge_addition_filter':
            args.edge_addition_filter,
            'edge_removal_filter':
            args.edge_removal_filter
        }

        print('-------------------------')
        print('| EVALUATION PARAMETERS |')
        print('-------------------------')
        print('| checking: %s' % data_precondition)
        print('| device: %s' % args.device)
        print('| eval_device: %s' % args.eval_device)
        print('| edge_radius: %s' % hyperparams['edge_radius'])
        print('| EE state_combine_method: %s' %
              hyperparams['edge_state_combine_method'])
        print('| EIE scheme: %s' %
              hyperparams['edge_influence_combine_method'])
        print('| dynamic_edges: %s' % hyperparams['dynamic_edges'])
        print('| edge_addition_filter: %s' % args.edge_addition_filter)
        print('| edge_removal_filter: %s' % args.edge_removal_filter)
        print('| MHL: %s' % hyperparams['minimum_history_length'])
        print('| PH: %s' % hyperparams['prediction_horizon'])
        print('| # Samples: %s' % args.num_samples)
        print('| # Runs: %s' % args.num_runs)
        print('-------------------------')

        eval_stg = OnlineSpatioTemporalGraphCVAEModel(None, model_registrar,
                                                      eval_hyperparams,
                                                      kwargs_dict,
                                                      args.eval_device)
        print('Created evaluation STG model.')

        print('About to begin evaluation computation for %s.' % dataset_name)
        with torch.no_grad():
            eval_inputs, _ = eval_utils.sample_inputs_and_labels(
                eval_data_dict, device=args.eval_device)

        (obs_traj, pred_traj_gt, obs_traj_rel, seq_start_end, data_ids,
         t_predicts) = eval_utils.get_sgan_data_format(
             eval_inputs, what_to_check=data_precondition)

        num_runs = args.num_runs
        print('num_runs, seq_start_end.shape[0]', args.num_runs,
              seq_start_end.shape[0])
        if args.num_runs > seq_start_end.shape[0]:
            print(
                'num_runs (%d) > seq_start_end.shape[0] (%d), reducing num_runs to match.'
                % (num_runs, seq_start_end.shape[0]))
            num_runs = seq_start_end.shape[0]

        random_scene_idxs = np.random.choice(seq_start_end.shape[0],
                                             size=(num_runs, ),
                                             replace=False).astype(int)

        for scene_idx in random_scene_idxs:
            # Each seq_start_end row holds the [start, end) agent range of
            # one scene.
            scene_start = int(seq_start_end[scene_idx][0].item())
            scene_end = int(seq_start_end[scene_idx][1].item())
            num_scene_agents = scene_end - scene_start

            # Time SGAN on this scene's agents only, rather than on every
            # scene in the batch, so that the runtime matches the agent count
            # recorded below and the per-scene timing of our model.
            # NOTE(review): assumes SGAN's (time, agent, xy) layout for
            # obs_traj / obs_traj_rel, as used by relative_to_abs.
            scene_obs_traj = obs_traj[:, scene_start:scene_end]
            scene_obs_traj_rel = obs_traj_rel[:, scene_start:scene_end]
            scene_seq_start_end = torch.tensor([[0, num_scene_agents]],
                                               dtype=torch.long)

            # Inference only: disable autograd, as is done for our model.
            with torch.no_grad():
                overall_tic = time.time()
                for _ in range(args.num_samples):
                    pred_traj_fake_rel = generator(scene_obs_traj,
                                                   scene_obs_traj_rel,
                                                   scene_seq_start_end)
                    # Kept inside the timed region; the result is unused.
                    pred_traj_fake = relative_to_abs(pred_traj_fake_rel,
                                                     scene_obs_traj[-1])
                overall_toc = time.time()

            print('SGAN overall', overall_toc - overall_tic)
            results_dict['data_precondition'].append(data_precondition)
            results_dict['dataset'].append(dataset_name)
            results_dict['method'].append('sgan')
            results_dict['runtime'].append(overall_toc - overall_tic)
            results_dict['num_samples'].append(args.num_samples)
            results_dict['num_agents'].append(num_scene_agents)

        print('Done running SGAN')

        # Convert the standardization statistics to float tensors on
        # args.device, in place.
        for node in eval_data_dict['nodes_standardization']:
            for key in eval_data_dict['nodes_standardization'][node]:
                eval_data_dict['nodes_standardization'][node][
                    key] = torch.from_numpy(
                        eval_data_dict['nodes_standardization'][node]
                        [key]).float().to(args.device)

        for node in eval_data_dict['labels_standardization']:
            for key in eval_data_dict['labels_standardization'][node]:
                eval_data_dict['labels_standardization'][node][
                    key] = torch.from_numpy(
                        eval_data_dict['labels_standardization'][node]
                        [key]).float().to(args.device)

        for run in range(num_runs):
            random_scene_idx = random_scene_idxs[run]
            data_id = data_ids[random_scene_idx]
            t_predict = t_predicts[random_scene_idx] - 1

            # Build the initial scene from the first timestep (up to
            # t_predict) at which any node has a non-zero position.
            init_scene_dict = dict()
            for first_timestep in range(t_predict + 1):
                for node, traj_data in eval_data_dict['input_dict'].items():
                    if isinstance(node, STGNode):
                        init_pos = traj_data[data_id, first_timestep, :2]
                        if np.any(init_pos):
                            init_scene_dict[node] = init_pos

                if len(init_scene_dict) > 0:
                    break

            init_scene_graph = SceneGraph()
            init_scene_graph.create_from_scene_dict(init_scene_dict,
                                                    args.edge_radius)

            curr_inputs = {
                k: v[data_id, first_timestep:t_predict + 1]
                for k, v in eval_data_dict['input_dict'].items()
                if (isinstance(k, STGNode) and (
                    k in init_scene_graph.active_nodes))
            }
            curr_pos_inputs = {k: v[..., :2] for k, v in curr_inputs.items()}

            with torch.no_grad():
                overall_tic = time.time()
                preds_dict_most_likely = eval_stg.forward(
                    init_scene_graph,
                    curr_pos_inputs,
                    curr_inputs,
                    None,
                    hyperparams['prediction_horizon'],
                    args.num_samples,
                    most_likely=True)
                overall_toc = time.time()
                print('Our MLz overall', overall_toc - overall_tic)
                results_dict['data_precondition'].append(data_precondition)
                results_dict['dataset'].append(dataset_name)
                results_dict['method'].append('our_most_likely')
                results_dict['runtime'].append(overall_toc - overall_tic)
                results_dict['num_samples'].append(args.num_samples)
                results_dict['num_agents'].append(len(init_scene_dict))

                overall_tic = time.time()
                preds_dict_full = eval_stg.forward(
                    init_scene_graph,
                    curr_pos_inputs,
                    curr_inputs,
                    None,
                    hyperparams['prediction_horizon'],
                    args.num_samples,
                    most_likely=False)
                overall_toc = time.time()
                print('Our Full overall', overall_toc - overall_tic)
                results_dict['data_precondition'].append(data_precondition)
                results_dict['dataset'].append(dataset_name)
                results_dict['method'].append('our_full')
                results_dict['runtime'].append(overall_toc - overall_tic)
                results_dict['num_samples'].append(args.num_samples)
                results_dict['num_agents'].append(len(init_scene_dict))

        # results_dict accumulates across datasets, so each CSV also contains
        # the rows of every dataset processed before this one.
        pd.DataFrame.from_dict(results_dict).to_csv(
            '../sgan-dataset/plots/data/%s_%s_runtimes.csv' %
            (data_precondition, dataset_name),
            index=False)
Exemplo n.º 13
0
def main():
    """Train the SpatioTemporalGraphCVAEModel and periodically evaluate it.

    Creates a timestamped model directory under ``args.log_dir``, which also
    holds the TensorBoard ``SummaryWriter`` logs. Loads the pickled training
    data dict and, when ``args.eval_every`` is set, the evaluation data dict.
    Builds aggregate scene graphs, then runs ``args.num_iters`` Adam
    iterations with exponential learning-rate decay.

    Every ``args.eval_every`` iterations it logs prediction plots, error
    histograms, box plots and evaluation NLLs. Every ``args.save_every``
    iterations it saves the models.

    Configuration comes from the module-level ``args``. The module-level
    ``hyperparams`` dict is updated in place.

    Raises:
        ValueError: if the training ``input_dict`` has no ``STGNode`` key.

    NOTE(review): ``max_speed`` is not defined in this function, so it must
    be provided at module scope.
    """
    # Create the log and model directory if they're not present.
    model_dir = os.path.join(
        args.log_dir,
        'models_' + time.strftime('%d_%b_%Y_%H_%M_%S', time.localtime()))
    pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

    log_writer = SummaryWriter(log_dir=model_dir)

    train_data_path = os.path.join(args.data_dir, args.train_data_dict)
    # NOTE: pickle can execute arbitrary code; only load trusted data files.
    with open(train_data_path, 'rb') as f:
        train_data_dict = pickle.load(f, encoding='latin1')
    train_dt = train_data_dict['dt']
    print('Loaded training data from %s, train_dt = %.2f' %
          (train_data_path, train_dt))

    if args.eval_every is not None:
        eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
        with open(eval_data_path, 'rb') as f:
            eval_data_dict = pickle.load(f, encoding='latin1')
        eval_dt = eval_data_dict['dt']
        print('Loaded evaluation data from %s, eval_dt = %.2f' %
              (eval_data_path, eval_dt))

    if args.incl_robot_node:
        robot_node = stg_node.STGNode('0', 'Pedestrian')
    else:
        robot_node = None

    # Any STGNode key will do: it is only used to read the state dimension.
    random_node = next((key for key in train_data_dict['input_dict']
                        if isinstance(key, stg_node.STGNode)), None)
    if random_node is None:
        raise ValueError('No STGNode entries in the input_dict of %s.' %
                         train_data_path)

    model_registrar = ModelRegistrar(model_dir, args.device)
    hyperparams['state_dim'] = train_data_dict['input_dict'][
        random_node].shape[2]
    hyperparams['pred_dim'] = len(train_data_dict['pred_indices'])
    hyperparams['pred_indices'] = train_data_dict['pred_indices']
    hyperparams['dynamic_edges'] = args.dynamic_edges
    hyperparams['edge_state_combine_method'] = args.edge_state_combine_method
    hyperparams[
        'edge_influence_combine_method'] = args.edge_influence_combine_method
    hyperparams['nodes_standardization'] = train_data_dict[
        'nodes_standardization']
    hyperparams['labels_standardization'] = train_data_dict[
        'labels_standardization']
    hyperparams['edge_radius'] = args.edge_radius

    if args.eval_every is not None:
        # Same settings, but standardized with the evaluation data's stats.
        eval_hyperparams = copy.deepcopy(hyperparams)
        eval_hyperparams['nodes_standardization'] = eval_data_dict[
            "nodes_standardization"]
        eval_hyperparams['labels_standardization'] = eval_data_dict[
            "labels_standardization"]

    kwargs_dict = {
        'dynamic_edges':
        hyperparams['dynamic_edges'],
        'edge_state_combine_method':
        hyperparams['edge_state_combine_method'],
        'edge_influence_combine_method':
        hyperparams['edge_influence_combine_method']
    }

    stg = SpatioTemporalGraphCVAEModel(robot_node, model_registrar,
                                       hyperparams, kwargs_dict, None,
                                       args.device)
    print('Created training STG model.')

    if args.eval_every is not None:
        # It is important that eval_stg uses the same model_registrar as
        # the stg being trained, otherwise you're just repeatedly evaluating
        # randomly-initialized weights!
        eval_stg = SpatioTemporalGraphCVAEModel(robot_node, model_registrar,
                                                eval_hyperparams, kwargs_dict,
                                                None, args.eval_device)
        print('Created evaluation STG model.')

    # Create the aggregate scene_graph for all the data, allowing
    # for batching, just like the old one. Then, for speed tests
    # we'll show how much faster this method is than keeping the
    # full version. Can show graphs of forward inference time vs problem size
    # with two lines (using aggregate graph, using online-computed graph).
    agg_scene_graph = create_batch_scene_graph(
        train_data_dict['input_dict'],
        float(hyperparams['edge_radius']),
        use_old_method=(args.dynamic_edges == 'no'))
    print('Created aggregate training scene graph.')

    if args.dynamic_edges == 'yes':
        agg_scene_graph.compute_edge_scaling(args.edge_addition_filter,
                                             args.edge_removal_filter)
        train_data_dict['input_dict'][
            'edge_scaling_mask'] = agg_scene_graph.edge_scaling_mask
        print('Computed edge scaling for the training scene graph.')

    stg.set_scene_graph(agg_scene_graph)
    stg.set_annealing_params()

    if args.eval_every is not None:
        eval_agg_scene_graph = create_batch_scene_graph(
            eval_data_dict['input_dict'],
            float(hyperparams['edge_radius']),
            use_old_method=(args.dynamic_edges == 'no'))
        print('Created aggregate evaluation scene graph.')

        if args.dynamic_edges == 'yes':
            eval_agg_scene_graph.compute_edge_scaling(
                args.edge_addition_filter, args.edge_removal_filter)
            eval_data_dict['input_dict'][
                'edge_scaling_mask'] = eval_agg_scene_graph.edge_scaling_mask
            print('Computed edge scaling for the evaluation scene graph.')

        eval_stg.set_scene_graph(eval_agg_scene_graph)
        eval_stg.set_annealing_params()

    # model_registrar.print_model_names()
    optimizer = optim.Adam(model_registrar.parameters(),
                           lr=hyperparams['learning_rate'])
    lr_scheduler = optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=hyperparams['learning_decay_rate'])

    # Keeping colors consistent throughout training.
    color_dict = defaultdict(dict)

    print_training_header(newline_start=True)
    for curr_iter in range(args.num_iters):
        # Necessary because we flip the weights contained between GPU and CPU sometimes.
        model_registrar.to(args.device)

        # Setting the current iterator value for internal logging.
        stg.set_curr_iter(curr_iter)

        # Log the learning rate that this iteration's optimizer step uses.
        # Read it from the optimizer: get_lr() is deprecated outside of
        # scheduler.step() and can be one decay factor off.
        log_writer.add_scalar('dynstg/learning_rate',
                              optimizer.param_groups[0]['lr'], curr_iter)
        stg.step_annealers()

        # Zeroing gradients for the upcoming iteration.
        optimizer.zero_grad()

        train_losses = list()
        for mb_num in range(args.batch_multiplier):
            # Obtaining the batch's training loss.
            train_inputs, train_labels = sample_inputs_and_labels(
                train_data_dict, batch_size=hyperparams['batch_size'])

            # Compute the training loss, scaled so the accumulated gradient
            # averages over the batch_multiplier minibatches.
            train_loss = stg.train_loss(
                train_inputs, train_labels,
                hyperparams['prediction_horizon']) / args.batch_multiplier
            train_losses.append(train_loss.item())

            # Calculating gradients.
            train_loss.backward()

        # Print training information. Also, no newline here. It's added in at a later line.
        iter_train_loss = sum(train_losses)
        print('{:9} | {:10} | '.format(curr_iter, '%.2f' % iter_train_loss),
              end='',
              flush=True)

        log_writer.add_histogram('dynstg/train_minibatch_losses',
                                 np.asarray(train_losses), curr_iter)
        log_writer.add_scalar('dynstg/train_loss', iter_train_loss, curr_iter)

        # Clipping gradients.
        if hyperparams['grad_clip'] is not None:
            nn.utils.clip_grad_value_(model_registrar.parameters(),
                                      hyperparams['grad_clip'])

        # Performing a gradient step.
        optimizer.step()

        # Decay the learning rate *after* the optimizer step (required since
        # PyTorch 1.1; stepping first skips the initial learning rate).
        lr_scheduler.step()

        # Freeing up memory.
        del train_loss

        if args.eval_every is not None and (curr_iter +
                                            1) % args.eval_every == 0:
            with torch.no_grad():
                # First plotting training predictions.
                pred_fig = plot_utils.plot_predictions_during_training(
                    stg,
                    train_inputs,
                    hyperparams['prediction_horizon'],
                    num_samples=100,
                    dt=train_dt,
                    max_speed=max_speed,
                    color_dict=color_dict,
                    most_likely=True)
                log_writer.add_figure('dynstg/train_prediction', pred_fig,
                                      curr_iter)

                train_mse_batch_errors, train_fse_batch_errors = eval_utils.compute_batch_statistics(
                    stg,
                    train_data_dict,
                    hyperparams['minimum_history_length'],
                    hyperparams['prediction_horizon'],
                    num_samples=100,
                    num_runs=100,
                    dt=train_dt,
                    max_speed=max_speed,
                    robot_node=robot_node)
                log_writer.add_histogram('dynstg/train_mse',
                                         train_mse_batch_errors, curr_iter)
                log_writer.add_histogram('dynstg/train_fse',
                                         train_fse_batch_errors, curr_iter)

                mse_boxplot_fig, fse_boxplot_fig = plot_utils.plot_boxplots_during_training(
                    train_mse_batch_errors, train_fse_batch_errors)
                log_writer.add_figure('dynstg/train_mse_boxplot',
                                      mse_boxplot_fig, curr_iter)
                log_writer.add_figure('dynstg/train_fse_boxplot',
                                      fse_boxplot_fig, curr_iter)

                log_writer.add_scalars(
                    'dynstg/train_sq_error', {
                        'mean_mse': torch.mean(train_mse_batch_errors),
                        'mean_fse': torch.mean(train_fse_batch_errors),
                        'median_mse': torch.median(train_mse_batch_errors),
                        'median_fse': torch.median(train_fse_batch_errors)
                    }, curr_iter)

                # Then computing evaluation values and predictions.
                model_registrar.to(args.eval_device)
                eval_stg.set_curr_iter(curr_iter)
                eval_inputs, eval_labels = sample_inputs_and_labels(
                    eval_data_dict,
                    device=args.eval_device,
                    batch_size=args.eval_batch_size)

                (eval_loss_q_is, eval_loss_p,
                 eval_loss_exact) = eval_stg.eval_loss(
                     eval_inputs, eval_labels,
                     hyperparams['prediction_horizon'])
                log_writer.add_scalars(
                    'dynstg/eval', {
                        'nll_q_is': eval_loss_q_is,
                        'nll_p': eval_loss_p,
                        'nll_exact': eval_loss_exact
                    }, curr_iter)

                pred_fig = plot_utils.plot_predictions_during_training(
                    eval_stg,
                    eval_inputs,
                    hyperparams['prediction_horizon'],
                    num_samples=100,
                    dt=eval_dt,
                    max_speed=max_speed,
                    color_dict=color_dict,
                    most_likely=True)
                log_writer.add_figure('dynstg/eval_prediction', pred_fig,
                                      curr_iter)

                eval_mse_batch_errors, eval_fse_batch_errors = eval_utils.compute_batch_statistics(
                    eval_stg,
                    eval_data_dict,
                    hyperparams['minimum_history_length'],
                    hyperparams['prediction_horizon'],
                    num_samples=100,
                    num_runs=100,
                    dt=eval_dt,
                    max_speed=max_speed,
                    robot_node=robot_node)
                log_writer.add_histogram('dynstg/eval_mse',
                                         eval_mse_batch_errors, curr_iter)
                log_writer.add_histogram('dynstg/eval_fse',
                                         eval_fse_batch_errors, curr_iter)

                mse_boxplot_fig, fse_boxplot_fig = plot_utils.plot_boxplots_during_training(
                    eval_mse_batch_errors, eval_fse_batch_errors)
                log_writer.add_figure('dynstg/eval_mse_boxplot',
                                      mse_boxplot_fig, curr_iter)
                log_writer.add_figure('dynstg/eval_fse_boxplot',
                                      fse_boxplot_fig, curr_iter)

                log_writer.add_scalars(
                    'dynstg/eval_sq_error', {
                        'mean_mse': torch.mean(eval_mse_batch_errors),
                        'mean_fse': torch.mean(eval_fse_batch_errors),
                        'median_mse': torch.median(eval_mse_batch_errors),
                        'median_fse': torch.median(eval_fse_batch_errors)
                    }, curr_iter)

                print('{:15} | {:10} | {:14}'.format(
                    '%.2f' % eval_loss_q_is.item(),
                    '%.2f' % eval_loss_p.item(),
                    '%.2f' % eval_loss_exact.item()),
                      end='',
                      flush=True)

                # Freeing up memory.
                del eval_loss_q_is
                del eval_loss_p
                del eval_loss_exact

        else:
            print('{:15} | {:10} | {:14}'.format('', '', ''),
                  end='',
                  flush=True)

        # Here's the newline that ends the current training information printing.
        print('')

        if args.save_every is not None and (curr_iter +
                                            1) % args.save_every == 0:
            model_registrar.save_models(curr_iter)
            print_training_header()
Exemplo n.º 14
0
def _load_environment(data_path, hyperparams, attention_radius_overrides):
    """Load a pickled environment and apply this run's settings to it.

    Args:
        data_path: Path to a ``dill``-pickled environment file.
        hyperparams: Run hyperparameters; reads ``incl_robot_node``,
            ``minimum_history_length`` and ``prediction_horizon``.
        attention_radius_overrides: Iterable of strings of the form
            ``"<node_type1> <node_type2> <radius>"`` (single-space
            separated); each overwrites
            ``env.attention_radius[(node_type1, node_type2)]``.

    Returns:
        The loaded environment with the attention-radius overrides applied
        and, when ``hyperparams['incl_robot_node']`` is set and the
        environment has no robot type yet, a robot node added to every
        scene via ``scene.add_robot_from_nodes``.
    """
    # NOTE(review): dill, like pickle, can execute arbitrary code while
    # loading -- only point this at data files you trust.
    with open(data_path, 'rb') as f:
        env = dill.load(f, encoding='latin1')

    for attention_radius_override in attention_radius_overrides:
        node_type1, node_type2, attention_radius = attention_radius_override.split(
            ' ')
        env.attention_radius[(node_type1,
                              node_type2)] = float(attention_radius)

    if env.robot_type is None and hyperparams['incl_robot_node']:
        # The first declared node type is used as the robot type.
        env.robot_type = env.NodeType[
            0]  # TODO: Make more general, allow the user to specify?
        for scene in env.scenes:
            scene.add_robot_from_nodes(
                env.robot_type,
                hyperparams=hyperparams,
                # presumably: minimum history + current step + horizon.
                min_timesteps=hyperparams['minimum_history_length'] + 1 +
                hyperparams['prediction_horizon'])

    return env


def _log_prediction_figures(model, dataset, hyperparams, log_writer, tag,
                            epoch):
    """Predict one random (scene, timestep) of ``dataset`` and log figures.

    Draws one entry uniformly from ``dataset.dataset.index`` and runs
    ``model.predict`` on it over ``hyperparams['prediction_horizon']``
    steps. It then logs to ``log_writer`` at step ``epoch``:

    * ``{tag}/all_modes``: the predicted trajectories,
    * ``{tag}/{A,B,Q}_mat``: the predicted A, B and Q matrices,
    * ``{tag}/ml_{A,B,Q}_mat``: the most-likely A, B, Q matrices over time.

    This function does not disable autograd itself; ``main`` calls it inside
    ``torch.no_grad()``.
    """
    max_hl = hyperparams['maximum_history_length']
    ph = hyperparams['prediction_horizon']

    index = dataset.dataset.index
    rand_elem = index[random.randrange(len(index))]
    scene, timestep = rand_elem[0], np.array([rand_elem[1]])
    pred_dists, _, As, Bs, Qs, _, state_lengths_in_order = model.predict(
        scene,
        timestep,
        ph,
        min_future_timesteps=ph,
        include_B=hyperparams['include_B'],
        zero_R_rows=hyperparams['zero_R_rows'])
    title = f"{scene.name}-t: {timestep}"

    # Plot predicted timestep for the random scene.
    fig, ax = plt.subplots(figsize=(10, 10))
    visualization.visualize_prediction(
        ax,
        pred_dists,
        scene.dt,
        max_hl=max_hl,
        ph=ph,
        map=scene.map['VISUALIZATION'] if scene.map is not None else None,
        robot_node=scene.robot)
    ax.set_title(title)
    log_writer.add_figure(f'{tag}/all_modes', fig, epoch)

    # Plot the A, B, Q matrices, then the most-likely ones across time.
    timestep_preds = pred_dists[timestep.item()]
    for plot_fn, prefix in ((visualization.visualize_mats, ''),
                            (visualization.visualize_mats_time, 'ml_')):
        figs = plot_fn(As, Bs, Qs, timestep_preds, state_lengths_in_order)
        for idx, mat_fig in enumerate(figs):
            mat_fig.suptitle(title)
            log_writer.add_figure(f'{tag}/{prefix}{"ABQ"[idx]}_mat', mat_fig,
                                  epoch)


def main():
    """Train a MATS model as configured by the module-level ``args``.

    Steps:
      1. Load hyperparameters from the JSON file ``args.conf`` and overlay
         the command-line settings on top of them.
      2. Unless ``args.debug``: create a timestamped model directory under
         ``args.log_dir``, save the merged config there as ``config.json``
         and open a ``SummaryWriter`` on it.
      3. Load the training environment, and the evaluation environment
         whenever evaluation or visualization is enabled, and wrap them in
         data loaders.
      4. Build the training/evaluation ``MATS`` models, an Adam optimizer
         and a learning-rate scheduler.
      5. Train for ``args.train_epochs`` epochs, accumulating gradients over
         up to ``args.batch_multiplier`` batches per optimizer step. Every
         ``args.vis_every`` / ``args.eval_every`` / ``args.save_every``
         epochs, visualize, evaluate, or checkpoint.

    Raises:
        FileNotFoundError: If ``args.conf`` does not exist.
        ValueError: If ``hyperparams['learning_rate_style']`` is neither
            ``'const'`` nor ``'exp'``.
    """
    # Load hyperparameters from json
    if not os.path.exists(args.conf):
        raise FileNotFoundError(f'Config json not found: {args.conf}')
    with open(args.conf, 'r', encoding='utf-8') as conf_json:
        hyperparams = json.load(conf_json)

    # Add hyperparams from arguments (these override the config file).
    hyperparams.update({
        'dynamic_edges': args.dynamic_edges,
        'edge_state_combine_method': args.edge_state_combine_method,
        'edge_influence_combine_method': args.edge_influence_combine_method,
        'edge_addition_filter': args.edge_addition_filter,
        'edge_removal_filter': args.edge_removal_filter,
        'batch_size': args.batch_size,
        'k_eval': args.k_eval,
        'offline_scene_graph': args.offline_scene_graph,
        'incl_robot_node': not args.no_robot_node,
        'node_freq_mult_train': args.node_freq_mult_train,
        'node_freq_mult_eval': args.node_freq_mult_eval,
        'scene_freq_mult_train': args.scene_freq_mult_train,
        'scene_freq_mult_eval': args.scene_freq_mult_eval,
        'scene_freq_mult_viz': args.scene_freq_mult_viz,
        'edge_encoding': not args.no_edge_encoding,
        'use_map_encoding': args.map_encoding,
        'augment': args.augment,
        'override_attention_radius': args.override_attention_radius,
        'include_B': not args.no_B,
        'reg_B': args.reg_B,
        'reg_B_weight': args.reg_B_weight,
        'zero_R_rows': args.zero_R_rows,
        'reg_A_slew': args.reg_A_slew,
        'reg_A_slew_weight': args.reg_A_slew_weight,
    })

    print('-----------------------')
    print('| TRAINING PARAMETERS |')
    print('-----------------------')
    print('| batch_size: %d' % args.batch_size)
    print('| device: %s' % args.device)
    print('| eval_device: %s' % args.eval_device)
    print('| Offline Scene Graph Calculation: %s' % args.offline_scene_graph)
    print('| EE state_combine_method: %s' % args.edge_state_combine_method)
    print('| EIE scheme: %s' % args.edge_influence_combine_method)
    print('| dynamic_edges: %s' % args.dynamic_edges)
    print('| robot node: %s' % (not args.no_robot_node))
    print('| edge_addition_filter: %s' % args.edge_addition_filter)
    print('| edge_removal_filter: %s' % args.edge_removal_filter)
    print('| MHL: %s' % hyperparams['minimum_history_length'])
    print('| PH: %s' % hyperparams['prediction_horizon'])
    print('-----------------------')

    log_writer = None
    model_dir = None
    if not args.debug:
        # Create the log and model directory if they're not present.
        model_dir = os.path.join(
            args.log_dir,
            'models_' + time.strftime('%d_%b_%Y_%H_%M_%S', time.localtime()) +
            args.log_tag)
        pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

        # Save config to model directory
        with open(os.path.join(model_dir, 'config.json'), 'w') as conf_json:
            json.dump(hyperparams, conf_json)

        log_writer = SummaryWriter(log_dir=model_dir)

    # The evaluation environment and model are needed both for evaluation
    # and for visualization (which also plots an evaluation-set prediction).
    need_eval = args.eval_every is not None or args.vis_every is not None

    # Load training and evaluation environments and scenes
    train_data_path = os.path.join(args.data_dir, args.train_data_dict)
    train_env = _load_environment(train_data_path, hyperparams,
                                  args.override_attention_radius)
    train_dataset = EnvironmentDataset(
        train_env,
        hyperparams['state'],
        hyperparams['pred_state'],
        scene_freq_mult=hyperparams['scene_freq_mult_train'],
        node_freq_mult=hyperparams['node_freq_mult_train'],
        hyperparams=hyperparams,
        min_history_timesteps=hyperparams['minimum_history_length'],
        min_future_timesteps=hyperparams['prediction_horizon'])
    train_data_loader = utils.data.DataLoader(
        train_dataset.dataset,
        collate_fn=collate,
        pin_memory=args.device != torch.device('cpu'),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.preprocess_workers)

    print(f"Loaded training data from {train_data_path}")

    eval_env = None
    if need_eval:
        eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
        eval_env = _load_environment(eval_data_path, hyperparams,
                                     args.override_attention_radius)
        eval_dataset = EnvironmentDataset(
            eval_env,
            hyperparams['state'],
            hyperparams['pred_state'],
            scene_freq_mult=hyperparams['scene_freq_mult_eval'],
            node_freq_mult=hyperparams['node_freq_mult_eval'],
            hyperparams=hyperparams,
            min_history_timesteps=hyperparams['minimum_history_length'],
            min_future_timesteps=hyperparams['prediction_horizon'])
        eval_data_loader = utils.data.DataLoader(
            eval_dataset.dataset,
            collate_fn=collate,
            pin_memory=args.eval_device != torch.device('cpu'),
            batch_size=args.eval_batch_size,
            shuffle=True,
            num_workers=args.preprocess_workers)

        print(f"Loaded evaluation data from {eval_data_path}")

    # Offline Calculate Scene Graph
    if hyperparams['offline_scene_graph'] == 'yes':
        print("Offline calculating scene graphs")
        for scene in train_env.scenes:
            scene.calculate_scene_graph(train_env.attention_radius,
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
        print("Created Scene Graphs for Training Scenes")

        if eval_env is not None:
            for scene in eval_env.scenes:
                scene.calculate_scene_graph(
                    eval_env.attention_radius,
                    hyperparams['edge_addition_filter'],
                    hyperparams['edge_removal_filter'])
            print("Created Scene Graphs for Evaluation Scenes")

    model_registrar = ModelRegistrar(model_dir, args.device)

    mats = MATS(model_registrar, hyperparams, log_writer, args.device)
    mats.set_environment(train_env)
    mats.set_annealing_params()
    print('Created Training Model.')

    # The evaluation model shares ``model_registrar`` with the training model.
    eval_mats = None
    if need_eval:
        eval_mats = MATS(model_registrar, hyperparams, log_writer,
                         args.eval_device)
        eval_mats.set_environment(eval_env)
        eval_mats.set_annealing_params()
        print('Created Evaluation Model.')

    # The map encoder gets its own (hard-coded) learning rate; all other
    # parameters use hyperparams['learning_rate'].
    optimizer = optim.Adam(
        [{
            'params':
            model_registrar.get_all_but_name_match('map_encoder').parameters()
        }, {
            'params':
            model_registrar.get_name_match('map_encoder').parameters(),
            'lr': 0.0008
        }],
        lr=hyperparams['learning_rate'])
    # Set Learning Rate
    lr_style = hyperparams['learning_rate_style']
    if lr_style == 'const':
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=1.0)
    elif lr_style == 'exp':
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=hyperparams['learning_decay_rate'])
    else:
        raise ValueError(f"Unknown learning_rate_style {lr_style!r}; "
                         "expected 'const' or 'exp'.")

    max_hl = hyperparams['maximum_history_length']
    ph = hyperparams['prediction_horizon']

    #################################
    #           TRAINING            #
    #################################
    curr_iter = 0
    for epoch in range(1, args.train_epochs + 1):
        model_registrar.to(args.device)
        train_dataset.augment = args.augment
        train_data_iterator = iter(train_data_loader)
        pbar = tqdm(total=len(train_data_loader), ncols=80)
        # Each optimizer step consumes up to ``args.batch_multiplier``
        # batches (gradient accumulation); the last step of an epoch may
        # receive fewer.
        for _ in range(0, len(train_data_loader), args.batch_multiplier):
            mats.set_curr_iter(curr_iter)
            mats.step_annealers()

            # Zeroing gradients.
            optimizer.zero_grad()

            train_losses = []
            for _ in range(args.batch_multiplier):
                try:
                    batch = next(train_data_iterator)
                except StopIteration:
                    break
                # Divided by batch_multiplier so the gradients accumulated
                # over a full group of batches equal their mean.
                train_loss = mats.train_loss(
                    batch,
                    include_B=hyperparams['include_B'],
                    reg_B=hyperparams['reg_B'],
                    zero_R_rows=hyperparams['zero_R_rows']
                ) / args.batch_multiplier
                train_losses.append(train_loss.item())
                train_loss.backward()

            pbar.update(len(train_losses))
            pbar.set_description(f"Epoch {epoch} L: {sum(train_losses):.2f}")

            # Clipping gradients.
            if hyperparams['grad_clip'] is not None:
                nn.utils.clip_grad_value_(model_registrar.parameters(),
                                          hyperparams['grad_clip'])

            # Optimizer step.
            optimizer.step()

            # The scheduler is stepped once per optimizer step, not per epoch.
            lr_scheduler.step()

            if not args.debug:
                log_writer.add_scalar("train/learning_rate",
                                      lr_scheduler.get_last_lr()[0], curr_iter)
                log_writer.add_scalar("train/loss", sum(train_losses),
                                      curr_iter)

            curr_iter += 1
        pbar.close()

        train_dataset.augment = False
        if eval_mats is not None:
            eval_mats.set_curr_iter(epoch)

        #################################
        #        VISUALIZATION          #
        #################################
        if (args.vis_every is not None and not args.debug
                and epoch % args.vis_every == 0 and epoch > 0):
            with torch.no_grad():
                _log_prediction_figures(mats, train_dataset, hyperparams,
                                        log_writer, 'train', epoch)
                model_registrar.to(args.eval_device)
                _log_prediction_figures(eval_mats, eval_dataset, hyperparams,
                                        log_writer, 'eval', epoch)

        #################################
        #           EVALUATION          #
        #################################
        if (args.eval_every is not None and not args.debug
                and epoch % args.eval_every == 0 and epoch > 0):
            model_registrar.to(args.eval_device)
            with torch.no_grad():
                # Calculate evaluation loss
                eval_losses = []
                print(f"Starting Evaluation @ epoch {epoch}")
                pbar = tqdm(eval_data_loader, ncols=80)
                for batch in pbar:
                    eval_loss = eval_mats.eval_loss(
                        batch,
                        include_B=hyperparams['include_B'],
                        zero_R_rows=hyperparams['zero_R_rows'])
                    pbar.set_description(
                        f"Epoch {epoch} L: {eval_loss.item():.2f}")
                    eval_losses.append({'full_state': {'nll': [eval_loss]}})
                    # Release the batch before the loader yields the next one.
                    del batch

                evaluation.log_batch_errors(eval_losses, log_writer, 'eval',
                                            epoch)

                # Predict batch timesteps for evaluation dataset evaluation.
                # Timesteps are sampled with replacement per scene.
                eval_batch_errors = []
                eval_mintopk_errors = []
                for scene, times in tqdm(
                        eval_dataset.dataset.scene_time_dict.items(),
                        desc='Evaluation',
                        ncols=80):
                    timesteps = np.random.choice(times, args.eval_batch_size)

                    pred_dists, *_ = eval_mats.predict(
                        scene,
                        timesteps,
                        ph=ph,
                        min_future_timesteps=ph,
                        include_B=hyperparams['include_B'],
                        zero_R_rows=hyperparams['zero_R_rows'])

                    eval_batch_errors.append(
                        evaluation.compute_batch_statistics(
                            pred_dists,
                            max_hl=max_hl,
                            ph=ph,
                            node_type_enum=eval_env.NodeType,
                            map=None))  # scene.map))

                    eval_mintopk_errors.append(
                        evaluation.compute_mintopk_statistics(
                            pred_dists,
                            max_hl,
                            ph=ph,
                            node_type_enum=eval_env.NodeType))

                evaluation.log_batch_errors(eval_batch_errors, log_writer,
                                            'eval', epoch)

                evaluation.plot_mintopk_curves(eval_mintopk_errors, log_writer,
                                               'eval', epoch)

        if (args.save_every is not None and not args.debug
                and epoch % args.save_every == 0):
            model_registrar.save_models(epoch)