# Example #1
                predictions = eval_stg.predict(
                    scene,
                    timesteps,
                    ph,
                    num_samples=1,
                    min_future_timesteps=8,
                    z_mode=True,
                    gmm_mode=True,
                    full_dist=False)  # This will trigger grid sampling

                batch_error_dict = evaluation.compute_batch_statistics(
                    predictions,
                    scene.dt,
                    max_hl=max_hl,
                    ph=ph,
                    node_type_enum=env.NodeType,
                    map=None,
                    prune_ph_to_future=False,
                    kde=False)

                eval_ade_batch_errors = np.hstack(
                    (eval_ade_batch_errors,
                     batch_error_dict[args.node_type]['ade']))
                eval_fde_batch_errors = np.hstack(
                    (eval_fde_batch_errors,
                     batch_error_dict[args.node_type]['fde']))

            print(np.mean(eval_fde_batch_errors))
            pd.DataFrame({
                'value': eval_ade_batch_errors,
# Example #2
def main():
    """Train a Trajectron model with periodic visualization and evaluation.

    Reads the module-level ``args`` namespace: loads the JSON hyperparameter
    file, merges command-line overrides into it, builds the training (and,
    when ``args.eval_every`` is set, evaluation) environments and
    dataloaders, then runs the epoch-based training loop with TensorBoard
    logging, visualization (``args.vis_every``), evaluation
    (``args.eval_every``), and checkpointing (``args.save_every``).
    """
    # Load hyperparameters from json
    if not os.path.exists(args.conf):
        # NOTE(review): execution continues after this print; the open()
        # below will raise FileNotFoundError — consider exiting here instead.
        print('Config json not found!')
    with open(args.conf, 'r', encoding='utf-8') as conf_json:
        hyperparams = json.load(conf_json)

    # Add hyperparams from arguments (command line overrides the JSON file)
    hyperparams['dynamic_edges'] = args.dynamic_edges
    hyperparams['edge_state_combine_method'] = args.edge_state_combine_method
    hyperparams[
        'edge_influence_combine_method'] = args.edge_influence_combine_method
    hyperparams['edge_addition_filter'] = args.edge_addition_filter
    hyperparams['edge_removal_filter'] = args.edge_removal_filter
    hyperparams['batch_size'] = args.batch_size
    hyperparams['k_eval'] = args.k_eval
    hyperparams['offline_scene_graph'] = args.offline_scene_graph
    hyperparams['incl_robot_node'] = args.incl_robot_node
    hyperparams['node_freq_mult_train'] = args.node_freq_mult_train
    hyperparams['node_freq_mult_eval'] = args.node_freq_mult_eval
    hyperparams['scene_freq_mult_train'] = args.scene_freq_mult_train
    hyperparams['scene_freq_mult_eval'] = args.scene_freq_mult_eval
    hyperparams['scene_freq_mult_viz'] = args.scene_freq_mult_viz
    hyperparams['edge_encoding'] = not args.no_edge_encoding
    hyperparams['use_map_encoding'] = args.map_encoding
    hyperparams['augment'] = args.augment
    hyperparams['override_attention_radius'] = args.override_attention_radius

    print('-----------------------')
    print('| TRAINING PARAMETERS |')
    print('-----------------------')
    print('| batch_size: %d' % args.batch_size)
    print('| device: %s' % args.device)
    print('| eval_device: %s' % args.eval_device)
    print('| Offline Scene Graph Calculation: %s' % args.offline_scene_graph)
    print('| EE state_combine_method: %s' % args.edge_state_combine_method)
    print('| EIE scheme: %s' % args.edge_influence_combine_method)
    print('| dynamic_edges: %s' % args.dynamic_edges)
    print('| robot node: %s' % args.incl_robot_node)
    print('| edge_addition_filter: %s' % args.edge_addition_filter)
    print('| edge_removal_filter: %s' % args.edge_removal_filter)
    print('| MHL: %s' % hyperparams['minimum_history_length'])
    print('| PH: %s' % hyperparams['prediction_horizon'])
    print('-----------------------')

    log_writer = None
    model_dir = None
    if not args.debug:
        # Create the log and model directiory if they're not present.
        model_dir = os.path.join(
            args.log_dir,
            'models_' + time.strftime('%d_%b_%Y_%H_%M_%S', time.localtime()) +
            args.log_tag)
        pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

        # Save config to model directory
        with open(os.path.join(model_dir, 'config.json'), 'w') as conf_json:
            json.dump(hyperparams, conf_json)

        log_writer = SummaryWriter(log_dir=model_dir)

    # Load training and evaluation environments and scenes
    train_data_path = os.path.join(args.data_dir, args.train_data_dict)
    with open(train_data_path, 'rb') as f:
        train_env = dill.load(f, encoding='latin1')

    # Each override is a "node_type1 node_type2 radius" triple.
    for attention_radius_override in args.override_attention_radius:
        node_type1, node_type2, attention_radius = attention_radius_override.split(
            ' ')
        train_env.attention_radius[(node_type1,
                                    node_type2)] = float(attention_radius)

    if train_env.robot_type is None and hyperparams['incl_robot_node']:
        train_env.robot_type = train_env.NodeType[
            0]  # TODO: Make more general, allow the user to specify?
        for scene in train_env.scenes:
            scene.add_robot_from_nodes(train_env.robot_type)

    train_scenes = train_env.scenes
    train_scenes_sample_probs = train_env.scenes_freq_mult_prop if args.scene_freq_mult_train else None

    train_dataset = EnvironmentDataset(
        train_env,
        hyperparams['state'],
        hyperparams['pred_state'],
        scene_freq_mult=hyperparams['scene_freq_mult_train'],
        node_freq_mult=hyperparams['node_freq_mult_train'],
        hyperparams=hyperparams,
        min_history_timesteps=hyperparams['minimum_history_length'],
        min_future_timesteps=hyperparams['prediction_horizon'],
        return_robot=not args.incl_robot_node)
    train_data_loader = dict()
    for node_type_data_set in train_dataset:
        node_type_dataloader = utils.data.DataLoader(
            node_type_data_set,
            collate_fn=collate,
            # Pin host memory only when training on an accelerator.
            # (Fixed: was `args.device is 'cpu'` — identity comparison against
            # a string literal is implementation-dependent; use equality.)
            pin_memory=(args.device != 'cpu'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.preprocess_workers)
        train_data_loader[node_type_data_set.node_type] = node_type_dataloader

    print(f"Loaded training data from {train_data_path}")

    eval_scenes = []
    eval_scenes_sample_probs = None
    if args.eval_every is not None:
        eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
        with open(eval_data_path, 'rb') as f:
            eval_env = dill.load(f, encoding='latin1')

        for attention_radius_override in args.override_attention_radius:
            node_type1, node_type2, attention_radius = attention_radius_override.split(
                ' ')
            eval_env.attention_radius[(node_type1,
                                       node_type2)] = float(attention_radius)

        if eval_env.robot_type is None and hyperparams['incl_robot_node']:
            eval_env.robot_type = eval_env.NodeType[
                0]  # TODO: Make more general, allow the user to specify?
            for scene in eval_env.scenes:
                scene.add_robot_from_nodes(eval_env.robot_type)

        eval_scenes = eval_env.scenes
        eval_scenes_sample_probs = eval_env.scenes_freq_mult_prop if args.scene_freq_mult_eval else None

        eval_dataset = EnvironmentDataset(
            eval_env,
            hyperparams['state'],
            hyperparams['pred_state'],
            scene_freq_mult=hyperparams['scene_freq_mult_eval'],
            node_freq_mult=hyperparams['node_freq_mult_eval'],
            hyperparams=hyperparams,
            min_history_timesteps=hyperparams['minimum_history_length'],
            min_future_timesteps=hyperparams['prediction_horizon'],
            return_robot=not args.incl_robot_node)
        eval_data_loader = dict()
        for node_type_data_set in eval_dataset:
            node_type_dataloader = utils.data.DataLoader(
                node_type_data_set,
                collate_fn=collate,
                # Fixed: was `args.eval_device is 'cpu'` (identity vs equality).
                pin_memory=(args.eval_device != 'cpu'),
                batch_size=args.eval_batch_size,
                shuffle=True,
                num_workers=args.preprocess_workers)
            eval_data_loader[
                node_type_data_set.node_type] = node_type_dataloader

        print(f"Loaded evaluation data from {eval_data_path}")

    # Offline Calculate Scene Graph
    if hyperparams['offline_scene_graph'] == 'yes':
        print(f"Offline calculating scene graphs")
        for i, scene in enumerate(train_scenes):
            scene.calculate_scene_graph(train_env.attention_radius,
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
            print(f"Created Scene Graph for Training Scene {i}")

        for i, scene in enumerate(eval_scenes):
            scene.calculate_scene_graph(eval_env.attention_radius,
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
            print(f"Created Scene Graph for Evaluation Scene {i}")

    model_registrar = ModelRegistrar(model_dir, args.device)

    trajectron = Trajectron(model_registrar, hyperparams, log_writer,
                            args.device)

    trajectron.set_environment(train_env)
    trajectron.set_annealing_params()
    print('Created Training Model.')

    eval_trajectron = None
    if args.eval_every is not None or args.vis_every is not None:
        # NOTE(review): eval_env is only loaded when args.eval_every is set;
        # if only args.vis_every is set this raises NameError — TODO confirm
        # the CLI enforces eval_every whenever vis_every is used.
        eval_trajectron = Trajectron(model_registrar, hyperparams, log_writer,
                                     args.eval_device)
        eval_trajectron.set_environment(eval_env)
        eval_trajectron.set_annealing_params()
    print('Created Evaluation Model.')

    # One optimizer/scheduler per predicted node type; the map encoder gets
    # its own (fixed) learning rate in a separate parameter group.
    optimizer = dict()
    lr_scheduler = dict()
    for node_type in train_env.NodeType:
        if node_type not in hyperparams['pred_state']:
            continue
        optimizer[node_type] = optim.Adam([{
            'params':
            model_registrar.get_all_but_name_match('map_encoder').parameters()
        }, {
            'params':
            model_registrar.get_name_match('map_encoder').parameters(),
            'lr':
            0.0008
        }],
                                          lr=hyperparams['learning_rate'])
        # Set Learning Rate
        if hyperparams['learning_rate_style'] == 'const':
            # gamma=1.0 keeps the learning rate constant.
            lr_scheduler[node_type] = optim.lr_scheduler.ExponentialLR(
                optimizer[node_type], gamma=1.0)
        elif hyperparams['learning_rate_style'] == 'exp':
            lr_scheduler[node_type] = optim.lr_scheduler.ExponentialLR(
                optimizer[node_type], gamma=hyperparams['learning_decay_rate'])

    #################################
    #           TRAINING            #
    #################################
    curr_iter_node_type = {
        node_type: 0
        for node_type in train_data_loader.keys()
    }
    for epoch in range(1, args.train_epochs + 1):
        model_registrar.to(args.device)
        train_dataset.augment = args.augment
        for node_type, data_loader in train_data_loader.items():
            curr_iter = curr_iter_node_type[node_type]
            pbar = tqdm(data_loader, ncols=80)
            for batch in pbar:
                trajectron.set_curr_iter(curr_iter)
                trajectron.step_annealers(node_type)
                optimizer[node_type].zero_grad()
                train_loss = trajectron.train_loss(batch, node_type)
                pbar.set_description(
                    f"Epoch {epoch}, {node_type} L: {train_loss.item():.2f}")
                train_loss.backward()
                # Clipping gradients.
                if hyperparams['grad_clip'] is not None:
                    nn.utils.clip_grad_value_(model_registrar.parameters(),
                                              hyperparams['grad_clip'])
                optimizer[node_type].step()

                # Stepping forward the learning rate scheduler and annealers.
                # NOTE(review): stepped per batch, not per epoch, so the decay
                # rate is applied per iteration; get_lr() is deprecated in
                # newer PyTorch in favor of get_last_lr().
                lr_scheduler[node_type].step()

                if not args.debug:
                    log_writer.add_scalar(f"{node_type}/train/learning_rate",
                                          lr_scheduler[node_type].get_lr()[0],
                                          curr_iter)
                    log_writer.add_scalar(f"{node_type}/train/loss",
                                          train_loss, curr_iter)

                curr_iter += 1
            curr_iter_node_type[node_type] = curr_iter
        # Only augment during the training pass above.
        train_dataset.augment = False
        if args.eval_every is not None or args.vis_every is not None:
            eval_trajectron.set_curr_iter(epoch)

        #################################
        #        VISUALIZATION          #
        #################################
        if args.vis_every is not None and not args.debug and epoch % args.vis_every == 0 and epoch > 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            with torch.no_grad():
                # Predict random timestep to plot for train data set
                if args.scene_freq_mult_viz:
                    scene = np.random.choice(train_scenes,
                                             p=train_scenes_sample_probs)
                else:
                    scene = np.random.choice(train_scenes)
                timestep = scene.sample_timesteps(1, min_future_timesteps=ph)
                predictions = trajectron.predict(scene,
                                                 timestep,
                                                 ph,
                                                 z_mode=True,
                                                 gmm_mode=True,
                                                 all_z_sep=False,
                                                 full_dist=False)

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(10, 10))
                visualization.visualize_prediction(
                    ax,
                    predictions,
                    scene.dt,
                    max_hl=max_hl,
                    ph=ph,
                    map=scene.map['VISUALIZATION']
                    if scene.map is not None else None)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('train/prediction', fig, epoch)

                model_registrar.to(args.eval_device)
                # Predict random timestep to plot for eval data set
                if args.scene_freq_mult_viz:
                    scene = np.random.choice(eval_scenes,
                                             p=eval_scenes_sample_probs)
                else:
                    scene = np.random.choice(eval_scenes)
                timestep = scene.sample_timesteps(1, min_future_timesteps=ph)
                predictions = eval_trajectron.predict(scene,
                                                      timestep,
                                                      ph,
                                                      num_samples=20,
                                                      min_future_timesteps=ph,
                                                      z_mode=False,
                                                      full_dist=False)

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(10, 10))
                visualization.visualize_prediction(
                    ax,
                    predictions,
                    scene.dt,
                    max_hl=max_hl,
                    ph=ph,
                    map=scene.map['VISUALIZATION']
                    if scene.map is not None else None)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('eval/prediction', fig, epoch)

                # Predict random timestep to plot for eval data set
                predictions = eval_trajectron.predict(scene,
                                                      timestep,
                                                      ph,
                                                      min_future_timesteps=ph,
                                                      z_mode=True,
                                                      gmm_mode=True,
                                                      all_z_sep=True,
                                                      full_dist=False)

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(10, 10))
                visualization.visualize_prediction(
                    ax,
                    predictions,
                    scene.dt,
                    max_hl=max_hl,
                    ph=ph,
                    map=scene.map['VISUALIZATION']
                    if scene.map is not None else None)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('eval/prediction_all_z', fig, epoch)

        #################################
        #           EVALUATION          #
        #################################
        if args.eval_every is not None and not args.debug and epoch % args.eval_every == 0 and epoch > 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            model_registrar.to(args.eval_device)
            with torch.no_grad():
                # Calculate evaluation loss
                for node_type, data_loader in eval_data_loader.items():
                    eval_loss = []
                    print(
                        f"Starting Evaluation @ epoch {epoch} for node type: {node_type}"
                    )
                    pbar = tqdm(data_loader, ncols=80)
                    for batch in pbar:
                        eval_loss_node_type = eval_trajectron.eval_loss(
                            batch, node_type)
                        pbar.set_description(
                            f"Epoch {epoch}, {node_type} L: {eval_loss_node_type.item():.2f}"
                        )
                        eval_loss.append(
                            {node_type: {
                                'nll': [eval_loss_node_type]
                            }})
                        del batch

                    evaluation.log_batch_errors(eval_loss, log_writer,
                                                f"{node_type}/eval_loss",
                                                epoch)

                # Predict batch timesteps for evaluation dataset evaluation
                eval_batch_errors = []
                for scene in tqdm(eval_scenes,
                                  desc='Sample Evaluation',
                                  ncols=80):
                    timesteps = scene.sample_timesteps(args.eval_batch_size)

                    predictions = eval_trajectron.predict(
                        scene,
                        timesteps,
                        ph,
                        num_samples=50,
                        min_future_timesteps=ph,
                        full_dist=False)

                    eval_batch_errors.append(
                        evaluation.compute_batch_statistics(
                            predictions,
                            scene.dt,
                            max_hl=max_hl,
                            ph=ph,
                            node_type_enum=eval_env.NodeType,
                            map=scene.map))

                evaluation.log_batch_errors(eval_batch_errors,
                                            log_writer,
                                            'eval',
                                            epoch,
                                            bar_plot=['kde'],
                                            box_plot=['ade', 'fde'])

                # Predict maximum likelihood batch timesteps for evaluation dataset evaluation
                eval_batch_errors_ml = []
                for scene in tqdm(eval_scenes, desc='MM Evaluation', ncols=80):
                    timesteps = scene.sample_timesteps(scene.timesteps)

                    predictions = eval_trajectron.predict(
                        scene,
                        timesteps,
                        ph,
                        num_samples=1,
                        min_future_timesteps=ph,
                        z_mode=True,
                        gmm_mode=True,
                        full_dist=False)

                    eval_batch_errors_ml.append(
                        evaluation.compute_batch_statistics(
                            predictions,
                            scene.dt,
                            max_hl=max_hl,
                            ph=ph,
                            map=scene.map,
                            node_type_enum=eval_env.NodeType,
                            kde=False))

                evaluation.log_batch_errors(eval_batch_errors_ml, log_writer,
                                            'eval/ml', epoch)

        # Fixed: was `args.debug is False` — identity comparison; use `not`.
        if args.save_every is not None and not args.debug and epoch % args.save_every == 0:
            model_registrar.save_models(epoch)
def main():
    # Load hyperparameters from json
    if not os.path.exists(args.conf):
        print('Config json not found!')
    with open(args.conf, 'r') as conf_json:
        hyperparams = json.load(conf_json)

    # Add hyperparams from arguments
    hyperparams['dynamic_edges'] = args.dynamic_edges
    hyperparams['edge_state_combine_method'] = args.edge_state_combine_method
    hyperparams[
        'edge_influence_combine_method'] = args.edge_influence_combine_method
    hyperparams['edge_radius'] = args.edge_radius
    hyperparams['use_map_encoding'] = args.use_map_encoding
    hyperparams['edge_addition_filter'] = args.edge_addition_filter
    hyperparams['edge_removal_filter'] = args.edge_removal_filter
    hyperparams['batch_size'] = args.batch_size
    hyperparams['k_eval'] = args.k_eval
    hyperparams['offline_scene_graph'] = args.offline_scene_graph
    hyperparams['incl_robot_node'] = args.incl_robot_node

    print('-----------------------')
    print('| TRAINING PARAMETERS |')
    print('-----------------------')
    print('| iterations: %d' % args.num_iters)
    print('| batch_size: %d' % args.batch_size)
    print('| batch_multiplier: %d' % args.batch_multiplier)
    print('| effective batch size: %d (= %d * %d)' %
          (args.batch_size * args.batch_multiplier, args.batch_size,
           args.batch_multiplier))
    print('| device: %s' % args.device)
    print('| eval_device: %s' % args.eval_device)
    print('| Offline Scene Graph Calculation: %s' % args.offline_scene_graph)
    print('| edge_radius: %s' % args.edge_radius)
    print('| EE state_combine_method: %s' % args.edge_state_combine_method)
    print('| EIE scheme: %s' % args.edge_influence_combine_method)
    print('| dynamic_edges: %s' % args.dynamic_edges)
    print('| robot node: %s' % args.incl_robot_node)
    print('| map encoding: %s' % args.use_map_encoding)
    print('| edge_addition_filter: %s' % args.edge_addition_filter)
    print('| edge_removal_filter: %s' % args.edge_removal_filter)
    print('| MHL: %s' % hyperparams['minimum_history_length'])
    print('| PH: %s' % hyperparams['prediction_horizon'])
    print('-----------------------')

    # Create the log and model directiory if they're not present.
    model_dir = os.path.join(
        args.log_dir, 'models_' +
        time.strftime('%d_%b_%Y_%H_%M_%S', time.localtime()) + args.log_tag)
    pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

    # Save config to model directory
    with open(os.path.join(model_dir, 'config.json'), 'w') as conf_json:
        json.dump(hyperparams, conf_json)

    log_writer = SummaryWriter(log_dir=model_dir)

    train_scenes = []
    train_data_path = os.path.join(args.data_dir, args.train_data_dict)
    with open(train_data_path, 'rb') as f:
        train_env = pickle.load(f, encoding='latin1')
    train_scenes = train_env.scenes
    print('Loaded training data from %s' % (train_data_path, ))

    eval_scenes = []
    if args.eval_every is not None:
        eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
        with open(eval_data_path, 'rb') as f:
            eval_env = pickle.load(f, encoding='latin1')
        eval_scenes = eval_env.scenes
        print('Loaded evaluation data from %s' % (eval_data_path, ))

    # Calculate Scene Graph
    if hyperparams['offline_scene_graph'] == 'yes':
        print(f"Offline calculating scene graphs")
        for i, scene in enumerate(train_scenes):
            scene.calculate_scene_graph(train_env.attention_radius,
                                        hyperparams['state'],
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
            print(f"Created Scene Graph for Scene {i}")

        for i, scene in enumerate(eval_scenes):
            scene.calculate_scene_graph(eval_env.attention_radius,
                                        hyperparams['state'],
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
            print(f"Created Scene Graph for Scene {i}")

    model_registrar = ModelRegistrar(model_dir, args.device)

    # We use pre trained weights for the map CNN
    if args.use_map_encoding:
        inf_encoder_registrar = os.path.join(
            args.log_dir, 'weight_trans/model_registrar-1499.pt')
        model_dict = torch.load(inf_encoder_registrar,
                                map_location=args.device)

        for key in model_dict.keys():
            if 'map_encoder' in key:
                model_registrar.model_dict[key] = model_dict[key]
                assert model_registrar.get_model(key) is model_dict[key]

    stg = SpatioTemporalGraphCVAEModel(model_registrar, hyperparams,
                                       log_writer, args.device)
    stg.set_scene_graph(train_env)
    stg.set_annealing_params()
    print('Created training STG model.')

    eval_stg = None
    if args.eval_every is not None or args.vis_ervery is not None:
        eval_stg = SpatioTemporalGraphCVAEModel(model_registrar, hyperparams,
                                                log_writer, args.device)
        eval_stg.set_scene_graph(eval_env)
        eval_stg.set_annealing_params()  # TODO Check if necessary
    if hyperparams['learning_rate_style'] == 'const':
        optimizer = optim.Adam(model_registrar.parameters(),
                               lr=hyperparams['learning_rate'])
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=1.0)
    elif hyperparams['learning_rate_style'] == 'exp':
        optimizer = optim.Adam(model_registrar.parameters(),
                               lr=hyperparams['learning_rate'])
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=hyperparams['learning_decay_rate'])
    elif hyperparams['learning_rate_style'] == 'triangle':
        optimizer = optim.Adam(model_registrar.parameters(), lr=1.0)
        clr = cyclical_lr(100,
                          min_lr=hyperparams['min_learning_rate'],
                          max_lr=hyperparams['learning_rate'],
                          decay=hyperparams['learning_decay_rate'])
        lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, [clr])

    print_training_header(newline_start=True)
    for curr_iter in range(args.num_iters):
        # Necessary because we flip the weights contained between GPU and CPU sometimes.
        model_registrar.to(args.device)

        # Setting the current iterator value for internal logging.
        stg.set_curr_iter(curr_iter)
        if args.vis_every is not None:
            eval_stg.set_curr_iter(curr_iter)

        # Stepping forward the learning rate scheduler and annealers.
        lr_scheduler.step()
        log_writer.add_scalar('train/learning_rate',
                              lr_scheduler.get_lr()[0], curr_iter)
        stg.step_annealers()

        # Zeroing gradients for the upcoming iteration.
        optimizer.zero_grad()
        train_losses = dict()
        for node_type in train_env.NodeType:
            train_losses[node_type] = []
        for scene in np.random.choice(train_scenes, 10):
            for mb_num in range(args.batch_multiplier):
                # Obtaining the batch's training loss.
                timesteps = scene.sample_timesteps(hyperparams['batch_size'])

                # Compute the training loss.
                train_loss_by_type = stg.train_loss(
                    scene, timesteps, max_nodes=hyperparams['batch_size'])
                for node_type, train_loss in train_loss_by_type.items():
                    if train_loss is not None:
                        train_loss = train_loss / (args.batch_multiplier * 10)
                        train_losses[node_type].append(train_loss.item())

                        # Calculating gradients.
                        train_loss.backward()

        # Print training information. Also, no newline here. It's added in at a later line.
        print('{:9} | '.format(curr_iter), end='', flush=True)
        for node_type in train_env.NodeType:
            print('{}:{:10} | '.format(node_type.name[0],
                                       '%.2f' % sum(train_losses[node_type])),
                  end='',
                  flush=True)

        for node_type in train_env.NodeType:
            if len(train_losses[node_type]) > 0:
                log_writer.add_histogram(
                    f"{node_type.name}/train/minibatch_losses",
                    np.asarray(train_losses[node_type]), curr_iter)
                log_writer.add_scalar(f"{node_type.name}/train/loss",
                                      sum(train_losses[node_type]), curr_iter)

        # Clipping gradients.
        if hyperparams['grad_clip'] is not None:
            nn.utils.clip_grad_value_(model_registrar.parameters(),
                                      hyperparams['grad_clip'])

        # Performing a gradient step.
        optimizer.step()

        del train_loss  # TODO Necessary?

        if args.vis_every is not None and (curr_iter +
                                           1) % args.vis_every == 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            with torch.no_grad():
                # Predict random timestep to plot for train data set
                scene = np.random.choice(train_scenes)
                timestep = scene.sample_timesteps(1, min_future_timesteps=ph)
                predictions = stg.predict(scene,
                                          timestep,
                                          ph,
                                          num_samples_z=100,
                                          most_likely_z=False,
                                          all_z=False)

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(5, 5))
                visualization.visualize_prediction(ax,
                                                   predictions,
                                                   scene.dt,
                                                   max_hl=max_hl,
                                                   ph=ph)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('train/prediction', fig, curr_iter)

                # Predict random timestep to plot for eval data set
                scene = np.random.choice(eval_scenes)
                timestep = scene.sample_timesteps(1, min_future_timesteps=ph)
                predictions = eval_stg.predict(scene,
                                               timestep,
                                               ph,
                                               num_samples_z=100,
                                               most_likely_z=False,
                                               all_z=False,
                                               max_nodes=4 *
                                               args.eval_batch_size)

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(5, 5))
                visualization.visualize_prediction(ax,
                                                   predictions,
                                                   scene.dt,
                                                   max_hl=max_hl,
                                                   ph=ph)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('eval/prediction', fig, curr_iter)

                # Plot predicted timestep for random scene in map
                fig, ax = plt.subplots(figsize=(15, 15))
                visualization.visualize_prediction(ax,
                                                   predictions,
                                                   scene.dt,
                                                   max_hl=max_hl,
                                                   ph=ph,
                                                   map=scene.map['PLOT'])
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('eval/prediction_map', fig, curr_iter)

                # Predict random timestep to plot for eval data set
                predictions = eval_stg.predict(scene,
                                               timestep,
                                               ph,
                                               num_samples_gmm=50,
                                               most_likely_z=False,
                                               all_z=True,
                                               max_nodes=4 *
                                               args.eval_batch_size)

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(5, 5))
                visualization.visualize_prediction(ax,
                                                   predictions,
                                                   scene.dt,
                                                   max_hl=max_hl,
                                                   ph=ph)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('eval/prediction_all_z', fig, curr_iter)

        if args.eval_every is not None and (curr_iter +
                                            1) % args.eval_every == 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            with torch.no_grad():
                # Predict batch timesteps for training dataset evaluation
                train_batch_errors = []
                max_scenes = np.min([len(train_scenes), 5])
                for scene in np.random.choice(train_scenes, max_scenes):
                    timesteps = scene.sample_timesteps(args.eval_batch_size)
                    predictions = stg.predict(scene,
                                              timesteps,
                                              ph,
                                              num_samples_z=100,
                                              min_future_timesteps=ph,
                                              max_nodes=4 *
                                              args.eval_batch_size)

                    train_batch_errors.append(
                        evaluation.compute_batch_statistics(
                            predictions,
                            scene.dt,
                            max_hl=max_hl,
                            ph=ph,
                            node_type_enum=train_env.NodeType,
                            map=scene.map))

                evaluation.log_batch_errors(train_batch_errors,
                                            log_writer,
                                            'train',
                                            curr_iter,
                                            bar_plot=['kde'],
                                            box_plot=['ade', 'fde'])

                # Predict batch timesteps for evaluation dataset evaluation
                eval_batch_errors = []
                for scene in eval_scenes:
                    timesteps = scene.sample_timesteps(args.eval_batch_size)

                    predictions = eval_stg.predict(scene,
                                                   timesteps,
                                                   ph,
                                                   num_samples_z=100,
                                                   min_future_timesteps=ph,
                                                   max_nodes=4 *
                                                   args.eval_batch_size)

                    eval_batch_errors.append(
                        evaluation.compute_batch_statistics(
                            predictions,
                            scene.dt,
                            max_hl=max_hl,
                            ph=ph,
                            node_type_enum=eval_env.NodeType,
                            map=scene.map))

                evaluation.log_batch_errors(eval_batch_errors,
                                            log_writer,
                                            'eval',
                                            curr_iter,
                                            bar_plot=['kde'],
                                            box_plot=['ade', 'fde'])

                # Predict maximum likelihood batch timesteps for evaluation dataset evaluation
                eval_batch_errors_ml = []
                for scene in eval_scenes:
                    timesteps = scene.sample_timesteps(scene.timesteps)

                    predictions = eval_stg.predict(scene,
                                                   timesteps,
                                                   ph,
                                                   num_samples_z=1,
                                                   min_future_timesteps=ph,
                                                   most_likely_z=True,
                                                   most_likely_gmm=True)

                    eval_batch_errors_ml.append(
                        evaluation.compute_batch_statistics(
                            predictions,
                            scene.dt,
                            max_hl=max_hl,
                            ph=ph,
                            map=scene.map,
                            node_type_enum=eval_env.NodeType,
                            kde=False))

                evaluation.log_batch_errors(eval_batch_errors_ml, log_writer,
                                            'eval/ml', curr_iter)

                eval_loss = []
                max_scenes = np.min([len(eval_scenes), 25])
                for scene in np.random.choice(eval_scenes, max_scenes):
                    eval_loss.append(eval_stg.eval_loss(scene, timesteps))

                evaluation.log_batch_errors(eval_loss, log_writer, 'eval/loss',
                                            curr_iter)

        else:
            print('{:15} | {:10} | {:14}'.format('', '', ''),
                  end='',
                  flush=True)

        # Here's the newline that ends the current training information printing.
        print('')

        if args.save_every is not None and (curr_iter +
                                            1) % args.save_every == 0:
            model_registrar.save_models(curr_iter)
            print_training_header()
Example #4
0
def main():
    """Replay one evaluation scene through an online (incremental) predictor.

    Loads hyperparameters and a pre-trained model from a hard-coded model
    directory, builds an online environment around a single scene, then
    steps through the scene one timestep at a time: each step feeds the
    newest positions to ``OnlineTrajectron.incremental_forward``, prints
    timing and batch error statistics, and saves a per-timestep prediction
    figure as a PDF under ``<model_dir>/pred_figs``.
    """
    # NOTE(review): the model directory (and iter_num=1999 below) are
    # hard-coded to one specific training run.
    model_dir = os.path.join(args.log_dir,
                             'models_14_Jan_2020_00_24_21eth_no_rob')

    # Load hyperparameters from the json config saved alongside the model.
    config_file = os.path.join(model_dir, args.conf)
    if not os.path.exists(config_file):
        raise ValueError('Config json not found!')
    with open(config_file, 'r') as conf_json:
        hyperparams = json.load(conf_json)

    # Command-line arguments override the stored hyperparameters.
    hyperparams['dynamic_edges'] = args.dynamic_edges
    hyperparams['edge_state_combine_method'] = args.edge_state_combine_method
    hyperparams[
        'edge_influence_combine_method'] = args.edge_influence_combine_method
    hyperparams['edge_addition_filter'] = args.edge_addition_filter
    hyperparams['edge_removal_filter'] = args.edge_removal_filter
    hyperparams['batch_size'] = args.batch_size
    hyperparams['k_eval'] = args.k_eval
    hyperparams['offline_scene_graph'] = args.offline_scene_graph
    hyperparams['incl_robot_node'] = args.incl_robot_node
    hyperparams['scene_batch_size'] = args.scene_batch_size
    hyperparams['node_resample_train'] = args.node_resample_train
    hyperparams['node_resample_eval'] = args.node_resample_eval
    hyperparams['scene_resample_train'] = args.scene_resample_train
    hyperparams['scene_resample_eval'] = args.scene_resample_eval
    hyperparams['scene_resample_viz'] = args.scene_resample_viz
    hyperparams['edge_encoding'] = not args.no_edge_encoding

    # Directory that receives the per-timestep prediction figures saved below.
    output_save_dir = os.path.join(model_dir, 'pred_figs')
    pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True)

    # NOTE(review): pickle.load on an external file — only safe if the
    # evaluation dataset is from a trusted source.
    eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
    with open(eval_data_path, 'rb') as f:
        eval_env = pickle.load(f, encoding='latin1')

    # If a robot node is requested but the environment defines none, promote
    # nodes of the first node type to robots in every scene.
    if eval_env.robot_type is None and hyperparams['incl_robot_node']:
        eval_env.robot_type = eval_env.NodeType[
            0]  # TODO: Make more general, allow the user to specify?
        for scene in eval_env.scenes:
            scene.add_robot_from_nodes(eval_env.robot_type)

    print('Loaded evaluation data from %s' % (eval_data_path, ))

    # Creating a dummy environment with a single scene that contains information about the world.
    # When using this code, feel free to use whichever scene index or initial timestep you wish.
    scene_idx = 0

    # You need to have at least acceleration, so you want 2 timesteps of prior data, e.g. [0, 1],
    # so that you can immediately start incremental inference from the 3rd timestep onwards.
    init_timestep = 1

    eval_scene = eval_env.scenes[scene_idx]
    online_env = create_online_env(eval_env, hyperparams, scene_idx,
                                   init_timestep)

    # Restore model weights from a fixed training iteration.
    model_registrar = ModelRegistrar(model_dir, args.eval_device)
    model_registrar.load_models(iter_num=1999)

    trajectron = OnlineTrajectron(model_registrar, hyperparams,
                                  args.eval_device)

    # If you want to see what different robot futures do to the predictions, uncomment this line as well as
    # related "... += adjustment" lines below.
    # adjustment = np.stack([np.arange(13)/float(i*2.0) for i in range(6, 12)], axis=1)

    # Here's how you'd incrementally run the model, e.g. with streaming data.
    trajectron.set_environment(online_env, init_timestep)

    for timestep in range(init_timestep + 1, eval_scene.timesteps):
        # Latest observed positions for all nodes at this timestep.
        pos_dict = eval_scene.get_clipped_pos_dict(timestep,
                                                   hyperparams['state'])

        # If the scene has a robot, hand the model the robot's present state
        # plus its planned future over the prediction horizon.
        robot_present_and_future = None
        if eval_scene.robot is not None:
            robot_present_and_future = eval_scene.robot.get(
                np.array(
                    [timestep, timestep + hyperparams['prediction_horizon']]),
                hyperparams['state'][eval_scene.robot.type],
                padding=0.0)
            # Duplicated along a new leading axis — presumably a batch of
            # candidate robot futures expected by incremental_forward;
            # TODO confirm against its signature.
            robot_present_and_future = np.stack(
                [robot_present_and_future, robot_present_and_future], axis=0)
            # robot_present_and_future += adjustment

        # NOTE(review): prediction_horizon is hard-coded to 12 here rather
        # than hyperparams['prediction_horizon'] — confirm this is intended.
        start = time.time()
        preds = trajectron.incremental_forward(
            pos_dict,
            prediction_horizon=12,
            num_samples=25,
            robot_present_and_future=robot_present_and_future)
        end = time.time()
        print("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" %
              (timestep, end - start, 1. / (end - start), len(
                  trajectron.nodes), trajectron.scene_graph.get_num_edges()))

        # Keep only predictions for nodes that belong to the eval scene.
        detailed_preds_dict = dict()
        for node in eval_scene.nodes:
            if node in preds:
                detailed_preds_dict[node] = preds[node]

        # Compute and print error statistics for this single timestep.
        batch_stats = evaluation.compute_batch_statistics(
            {timestep: detailed_preds_dict},
            eval_scene.dt,
            max_hl=hyperparams['maximum_history_length'],
            ph=hyperparams['prediction_horizon'],
            node_type_enum=online_env.NodeType,
            prune_ph_to_future=True)

        evaluation.print_batch_errors([batch_stats], 'eval', timestep)

        # Visualize the raw predictions for this timestep.
        fig, ax = plt.subplots()
        vis.visualize_prediction(ax, {timestep: preds}, eval_scene.dt,
                                 hyperparams['maximum_history_length'],
                                 hyperparams['prediction_horizon'])

        # Overlay the robot's future trajectory and current position, if any.
        if eval_scene.robot is not None:
            robot_for_plotting = eval_scene.robot.get(
                np.array(
                    [timestep, timestep + hyperparams['prediction_horizon']]),
                hyperparams['state'][eval_scene.robot.type])
            # robot_for_plotting += adjustment

            # Plotted as (x=column 1, y=column 0) — axes appear swapped to
            # match the scene's plot orientation; TODO confirm column order.
            ax.plot(robot_for_plotting[1:, 1],
                    robot_for_plotting[1:, 0],
                    color='r',
                    linewidth=1.0,
                    alpha=1.0)

            # Current Node Position
            circle = plt.Circle(
                (robot_for_plotting[0, 1], robot_for_plotting[0, 0]),
                0.3,
                facecolor='r',
                edgecolor='k',
                lw=0.5,
                zorder=3)
            ax.add_artist(circle)

        # One PDF per timestep; close the figure to avoid accumulating
        # open matplotlib figures across the loop.
        fig.savefig(os.path.join(output_save_dir, f'pred_{timestep}.pdf'),
                    dpi=300)
        plt.close(fig)
            for i, scene in enumerate(tqdm(scenes)):
                for timestep in range(scene.timesteps):
                    predictions = eval_stg.predict(scene,
                                                   np.array([timestep]),
                                                   ph,
                                                   num_samples_z=2000,
                                                   most_likely_z=False,
                                                   min_future_timesteps=8)

                    if not predictions:
                        continue

                    eval_error_dict = evaluation.compute_batch_statistics(
                        predictions,
                        scene.dt,
                        node_type_enum=env.NodeType,
                        max_hl=max_hl,
                        ph=ph,
                        map=scene.map[node_type.name],
                        obs=True)

                    eval_ade_batch_errors = np.hstack(
                        (eval_ade_batch_errors,
                         eval_error_dict[node_type]['ade']))
                    eval_fde_batch_errors = np.hstack(
                        (eval_fde_batch_errors,
                         eval_error_dict[node_type]['fde']))
                    eval_kde_nll = np.hstack(
                        (eval_kde_nll, eval_error_dict[node_type]['kde']))
                    eval_obs_viols = np.hstack(
                        (eval_obs_viols,
                         eval_error_dict[node_type]['obs_viols']))
Example #6
0
def main():
    # Load hyperparameters from json
    if not os.path.exists(args.conf):
        print('Config json not found!')
    with open(args.conf, 'r', encoding='utf-8') as conf_json:
        hyperparams = json.load(conf_json)

    # Add hyperparams from arguments
    hyperparams['dynamic_edges'] = args.dynamic_edges
    hyperparams['edge_state_combine_method'] = args.edge_state_combine_method
    hyperparams[
        'edge_influence_combine_method'] = args.edge_influence_combine_method
    hyperparams['edge_addition_filter'] = args.edge_addition_filter
    hyperparams['edge_removal_filter'] = args.edge_removal_filter
    hyperparams['batch_size'] = args.batch_size
    hyperparams['k_eval'] = args.k_eval
    hyperparams['offline_scene_graph'] = args.offline_scene_graph
    hyperparams['incl_robot_node'] = not args.no_robot_node
    hyperparams['node_freq_mult_train'] = args.node_freq_mult_train
    hyperparams['node_freq_mult_eval'] = args.node_freq_mult_eval
    hyperparams['scene_freq_mult_train'] = args.scene_freq_mult_train
    hyperparams['scene_freq_mult_eval'] = args.scene_freq_mult_eval
    hyperparams['scene_freq_mult_viz'] = args.scene_freq_mult_viz
    hyperparams['edge_encoding'] = not args.no_edge_encoding
    hyperparams['use_map_encoding'] = args.map_encoding
    hyperparams['augment'] = args.augment
    hyperparams['override_attention_radius'] = args.override_attention_radius
    hyperparams['include_B'] = not args.no_B
    hyperparams['reg_B'] = args.reg_B
    hyperparams['reg_B_weight'] = args.reg_B_weight
    hyperparams['zero_R_rows'] = args.zero_R_rows
    hyperparams['reg_A_slew'] = args.reg_A_slew
    hyperparams['reg_A_slew_weight'] = args.reg_A_slew_weight

    print('-----------------------')
    print('| TRAINING PARAMETERS |')
    print('-----------------------')
    print('| batch_size: %d' % args.batch_size)
    print('| device: %s' % args.device)
    print('| eval_device: %s' % args.eval_device)
    print('| Offline Scene Graph Calculation: %s' % args.offline_scene_graph)
    print('| EE state_combine_method: %s' % args.edge_state_combine_method)
    print('| EIE scheme: %s' % args.edge_influence_combine_method)
    print('| dynamic_edges: %s' % args.dynamic_edges)
    print('| robot node: %s' % (not args.no_robot_node))
    print('| edge_addition_filter: %s' % args.edge_addition_filter)
    print('| edge_removal_filter: %s' % args.edge_removal_filter)
    print('| MHL: %s' % hyperparams['minimum_history_length'])
    print('| PH: %s' % hyperparams['prediction_horizon'])
    print('-----------------------')

    log_writer = None
    model_dir = None
    if not args.debug:
        # Create the log and model directiory if they're not present.
        model_dir = os.path.join(
            args.log_dir,
            'models_' + time.strftime('%d_%b_%Y_%H_%M_%S', time.localtime()) +
            args.log_tag)
        pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

        # Save config to model directory
        with open(os.path.join(model_dir, 'config.json'), 'w') as conf_json:
            json.dump(hyperparams, conf_json)

        log_writer = SummaryWriter(log_dir=model_dir)

    # Load training and evaluation environments and scenes
    train_data_path = os.path.join(args.data_dir, args.train_data_dict)
    with open(train_data_path, 'rb') as f:
        train_env = dill.load(f, encoding='latin1')

    for attention_radius_override in args.override_attention_radius:
        node_type1, node_type2, attention_radius = attention_radius_override.split(
            ' ')
        train_env.attention_radius[(node_type1,
                                    node_type2)] = float(attention_radius)

    if train_env.robot_type is None and hyperparams['incl_robot_node']:
        train_env.robot_type = train_env.NodeType[
            0]  # TODO: Make more general, allow the user to specify?
        for scene in train_env.scenes:
            scene.add_robot_from_nodes(
                train_env.robot_type,
                hyperparams=hyperparams,
                min_timesteps=hyperparams['minimum_history_length'] + 1 +
                hyperparams['prediction_horizon'])

    train_scenes = train_env.scenes
    train_dataset = EnvironmentDataset(
        train_env,
        hyperparams['state'],
        hyperparams['pred_state'],
        scene_freq_mult=hyperparams['scene_freq_mult_train'],
        node_freq_mult=hyperparams['node_freq_mult_train'],
        hyperparams=hyperparams,
        min_history_timesteps=hyperparams['minimum_history_length'],
        min_future_timesteps=hyperparams['prediction_horizon'])
    train_data_loader = utils.data.DataLoader(
        train_dataset.dataset,
        collate_fn=collate,
        pin_memory=False if args.device == torch.device('cpu') else True,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.preprocess_workers)

    print(f"Loaded training data from {train_data_path}")

    eval_scenes = []
    if args.eval_every is not None:
        eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
        with open(eval_data_path, 'rb') as f:
            eval_env = dill.load(f, encoding='latin1')

        for attention_radius_override in args.override_attention_radius:
            node_type1, node_type2, attention_radius = attention_radius_override.split(
                ' ')
            eval_env.attention_radius[(node_type1,
                                       node_type2)] = float(attention_radius)

        if eval_env.robot_type is None and hyperparams['incl_robot_node']:
            eval_env.robot_type = eval_env.NodeType[
                0]  # TODO: Make more general, allow the user to specify?
            for scene in eval_env.scenes:
                scene.add_robot_from_nodes(
                    eval_env.robot_type,
                    hyperparams=hyperparams,
                    min_timesteps=hyperparams['minimum_history_length'] + 1 +
                    hyperparams['prediction_horizon'])

        eval_scenes = eval_env.scenes
        eval_dataset = EnvironmentDataset(
            eval_env,
            hyperparams['state'],
            hyperparams['pred_state'],
            scene_freq_mult=hyperparams['scene_freq_mult_eval'],
            node_freq_mult=hyperparams['node_freq_mult_eval'],
            hyperparams=hyperparams,
            min_history_timesteps=hyperparams['minimum_history_length'],
            min_future_timesteps=hyperparams['prediction_horizon'])
        eval_data_loader = utils.data.DataLoader(
            eval_dataset.dataset,
            collate_fn=collate,
            pin_memory=False
            if args.eval_device == torch.device('cpu') else True,
            batch_size=args.eval_batch_size,
            shuffle=True,
            num_workers=args.preprocess_workers)

        print(f"Loaded evaluation data from {eval_data_path}")

    # Offline Calculate Scene Graph
    if hyperparams['offline_scene_graph'] == 'yes':
        print(f"Offline calculating scene graphs")
        for i, scene in enumerate(train_scenes):
            scene.calculate_scene_graph(train_env.attention_radius,
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
        print(f"Created Scene Graphs for Training Scenes")

        for i, scene in enumerate(eval_scenes):
            scene.calculate_scene_graph(eval_env.attention_radius,
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
        print(f"Created Scene Graphs for Evaluation Scenes")

    model_registrar = ModelRegistrar(model_dir, args.device)

    mats = MATS(model_registrar, hyperparams, log_writer, args.device)

    mats.set_environment(train_env)
    mats.set_annealing_params()
    print('Created Training Model.')

    eval_mats = None
    if args.eval_every is not None or args.vis_every is not None:
        eval_mats = MATS(model_registrar, hyperparams, log_writer,
                         args.eval_device)
        eval_mats.set_environment(eval_env)
        eval_mats.set_annealing_params()
    print('Created Evaluation Model.')

    optimizer = optim.Adam(
        [{
            'params':
            model_registrar.get_all_but_name_match('map_encoder').parameters()
        }, {
            'params':
            model_registrar.get_name_match('map_encoder').parameters(),
            'lr': 0.0008
        }],
        lr=hyperparams['learning_rate'])
    # Set Learning Rate
    if hyperparams['learning_rate_style'] == 'const':
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=1.0)
    elif hyperparams['learning_rate_style'] == 'exp':
        lr_scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=hyperparams['learning_decay_rate'])

    #################################
    #           TRAINING            #
    #################################
    curr_iter = 0
    for epoch in range(1, args.train_epochs + 1):
        model_registrar.to(args.device)
        train_dataset.augment = args.augment
        train_data_iterator = iter(train_data_loader)
        pbar = tqdm(total=len(train_data_loader), ncols=80)
        for _ in range(0, len(train_data_loader), args.batch_multiplier):
            mats.set_curr_iter(curr_iter)
            mats.step_annealers()

            # Zeroing gradients.
            optimizer.zero_grad()

            train_losses = list()
            for mb_num in range(args.batch_multiplier):
                try:
                    train_loss = mats.train_loss(
                        next(train_data_iterator),
                        include_B=hyperparams['include_B'],
                        reg_B=hyperparams['reg_B'],
                        zero_R_rows=hyperparams['zero_R_rows']
                    ) / args.batch_multiplier
                    train_losses.append(train_loss.item())
                    train_loss.backward()
                except StopIteration:
                    break

            pbar.update(args.batch_multiplier)
            pbar.set_description(f"Epoch {epoch} L: {sum(train_losses):.2f}")

            # Clipping gradients.
            if hyperparams['grad_clip'] is not None:
                nn.utils.clip_grad_value_(model_registrar.parameters(),
                                          hyperparams['grad_clip'])

            # Optimizer step.
            optimizer.step()

            # Stepping forward the learning rate scheduler and annealers.
            lr_scheduler.step()

            if not args.debug:
                log_writer.add_scalar(f"train/learning_rate",
                                      lr_scheduler.get_last_lr()[0], curr_iter)
                log_writer.add_scalar(f"train/loss", sum(train_losses),
                                      curr_iter)

            curr_iter += 1

        train_dataset.augment = False
        if args.eval_every is not None or args.vis_every is not None:
            eval_mats.set_curr_iter(epoch)

        #################################
        #        VISUALIZATION          #
        #################################
        if args.vis_every is not None and not args.debug and epoch % args.vis_every == 0 and epoch > 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            with torch.no_grad():
                index = train_dataset.dataset.index
                rand_elem = index[random.randrange(len(index))]
                scene, timestep = rand_elem[0], np.array([rand_elem[1]])
                pred_dists, non_rob_rows, As, Bs, Qs, affine_terms, state_lengths_in_order = mats.predict(
                    scene,
                    timestep,
                    ph,
                    min_future_timesteps=ph,
                    include_B=hyperparams['include_B'],
                    zero_R_rows=hyperparams['zero_R_rows'])

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(10, 10))
                visualization.visualize_prediction(
                    ax,
                    pred_dists,
                    scene.dt,
                    max_hl=max_hl,
                    ph=ph,
                    map=scene.map['VISUALIZATION']
                    if scene.map is not None else None,
                    robot_node=scene.robot)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('train/all_modes', fig, epoch)

                # Plot A, B, Q matrices.
                figs = visualization.visualize_mats(
                    As, Bs, Qs, pred_dists[timestep.item()],
                    state_lengths_in_order)
                for idx, fig in enumerate(figs):
                    fig.suptitle(f"{scene.name}-t: {timestep}")
                    log_writer.add_figure(f'train/{"ABQ"[idx]}_mat', fig,
                                          epoch)

                # Plot most-likely A, B, Q matrices across time.
                figs = visualization.visualize_mats_time(
                    As, Bs, Qs, pred_dists[timestep.item()],
                    state_lengths_in_order)
                for idx, fig in enumerate(figs):
                    fig.suptitle(f"{scene.name}-t: {timestep}")
                    log_writer.add_figure(f'train/ml_{"ABQ"[idx]}_mat', fig,
                                          epoch)

                model_registrar.to(args.eval_device)

                # Predict random timestep to plot for eval data set
                index = eval_dataset.dataset.index
                rand_elem = index[random.randrange(len(index))]
                scene, timestep = rand_elem[0], np.array([rand_elem[1]])

                pred_dists, non_rob_rows, As, Bs, Qs, affine_terms, state_lengths_in_order = eval_mats.predict(
                    scene,
                    timestep,
                    ph,
                    min_future_timesteps=ph,
                    include_B=hyperparams['include_B'],
                    zero_R_rows=hyperparams['zero_R_rows'])

                # Plot predicted timestep for random scene
                fig, ax = plt.subplots(figsize=(10, 10))
                visualization.visualize_prediction(
                    ax,
                    pred_dists,
                    scene.dt,
                    max_hl=max_hl,
                    ph=ph,
                    map=scene.map['VISUALIZATION']
                    if scene.map is not None else None,
                    robot_node=scene.robot)
                ax.set_title(f"{scene.name}-t: {timestep}")
                log_writer.add_figure('eval/all_modes', fig, epoch)

                # Plot A, B, Q matrices.
                figs = visualization.visualize_mats(
                    As, Bs, Qs, pred_dists[timestep.item()],
                    state_lengths_in_order)
                for idx, fig in enumerate(figs):
                    fig.suptitle(f"{scene.name}-t: {timestep}")
                    log_writer.add_figure(f'eval/{"ABQ"[idx]}_mat', fig, epoch)

                # Plot most-likely A, B, Q matrices across time.
                figs = visualization.visualize_mats_time(
                    As, Bs, Qs, pred_dists[timestep.item()],
                    state_lengths_in_order)
                for idx, fig in enumerate(figs):
                    fig.suptitle(f"{scene.name}-t: {timestep}")
                    log_writer.add_figure(f'eval/ml_{"ABQ"[idx]}_mat', fig,
                                          epoch)

        #################################
        #           EVALUATION          #
        #################################
        # Run evaluation every `eval_every` epochs, skipping epoch 0 and
        # debug runs entirely.
        if args.eval_every is not None and not args.debug and epoch % args.eval_every == 0 and epoch > 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            # Move the model to the (possibly different) evaluation device
            # before running inference.
            model_registrar.to(args.eval_device)
            with torch.no_grad():
                # Calculate evaluation loss over the full eval data loader.
                eval_losses = []
                print(f"Starting Evaluation @ epoch {epoch}")
                pbar = tqdm(eval_data_loader, ncols=80)
                for batch in pbar:
                    eval_loss = eval_mats.eval_loss(
                        batch,
                        include_B=hyperparams['include_B'],
                        zero_R_rows=hyperparams['zero_R_rows'])
                    pbar.set_description(
                        f"Epoch {epoch} L: {eval_loss.item():.2f}")
                    # Wrap each scalar loss in the nested-dict shape that
                    # evaluation.log_batch_errors expects.
                    eval_losses.append({'full_state': {'nll': [eval_loss]}})
                    del batch

                evaluation.log_batch_errors(eval_losses, log_writer, 'eval',
                                            epoch)

                # Predict batch timesteps for evaluation dataset evaluation:
                # for each scene, sample `eval_batch_size` timesteps at random
                # and collect prediction-error statistics.
                eval_batch_errors = []
                eval_mintopk_errors = []
                for scene, times in tqdm(
                        eval_dataset.dataset.scene_time_dict.items(),
                        desc='Evaluation',
                        ncols=80):
                    timesteps = np.random.choice(times, args.eval_batch_size)

                    pred_dists, non_rob_rows, As, Bs, Qs, affine_terms, _ = eval_mats.predict(
                        scene,
                        timesteps,
                        ph=ph,
                        min_future_timesteps=ph,
                        include_B=hyperparams['include_B'],
                        zero_R_rows=hyperparams['zero_R_rows'])

                    # NOTE(review): unlike the compute_batch_statistics call
                    # elsewhere in this file (which passes scene.dt as the
                    # second positional argument), scene.dt is not passed
                    # here — verify this matches the signature used by this
                    # evaluation module.
                    eval_batch_errors.append(
                        evaluation.compute_batch_statistics(
                            pred_dists,
                            max_hl=max_hl,
                            ph=ph,
                            node_type_enum=eval_env.NodeType,
                            map=None))  # scene.map))

                    # Min-over-top-k error statistics for the same predictions.
                    eval_mintopk_errors.append(
                        evaluation.compute_mintopk_statistics(
                            pred_dists,
                            max_hl,
                            ph=ph,
                            node_type_enum=eval_env.NodeType))

                evaluation.log_batch_errors(eval_batch_errors, log_writer,
                                            'eval', epoch)

                evaluation.plot_mintopk_curves(eval_mintopk_errors, log_writer,
                                               'eval', epoch)

        # Checkpoint the model every `save_every` epochs (skipped in debug
        # runs). NOTE(review): `args.debug is False` is an identity check,
        # whereas the evaluation gate above uses `not args.debug` — the two
        # are equivalent only if args.debug is a genuine bool; confirm.
        if args.save_every is not None and args.debug is False and epoch % args.save_every == 0:
            model_registrar.save_models(epoch)