def run():
    print("Summarizing all waypoint locations")
    only_process_test_sites = True
    write_all_wifi_data = False

    data_folder = utils.get_data_folder()
    summary_path = data_folder / "file_summary.csv"
    combined_waypoints_path = data_folder / "train_waypoints_timed.csv"
    if combined_waypoints_path.is_file():
        return
    # combined_train_wifi_times_path = data_folder / "train_wifi_times.csv"
    # combined_test_wifi_times_path = data_folder / "test_wifi_times.csv"
    stratified_holdout_path = data_folder / 'holdout_ids.csv'
    combined_all_wifi_folder = data_folder / 'train'
    df = pd.read_csv(summary_path)
    holdout = pd.read_csv(stratified_holdout_path)

    # Loop over all trajectories: load each reshaped pickle and summarize its
    # waypoint locations and wifi observation times
    all_waypoints = []
    all_train_wifi_times = []
    all_test_wifi_times = []
    all_wifi_data = []
    for i in range(df.shape[0]):
        # if i < 26900:
        #   continue

        print(f"Trajectory {i+1} of {df.shape[0]}")
        if (not only_process_test_sites
                or df.test_site[i]) and df.num_wifi[i] > 0:
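            # Load the reshaped trajectory pickle produced by an earlier
            # preprocessing step (the raw file path with a "_reshaped.pickle"
            # suffix)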
            pickle_path = data_folder / (
                str(Path(df.ext_path[i]).with_suffix("")) + "_reshaped.pickle")
            with open(pickle_path, "rb") as f:
                trajectory = pickle.load(f)

            if df['mode'][i] != 'test':
                waypoints = trajectory['waypoint']
                num_waypoints = waypoints.shape[0]

                # Add meta columns
                for c in ['site_id', 'mode', 'fn', 'text_level']:
                    waypoints[c] = df[c][i]

                # Add whether it is a train or validation trajectory
                waypoints['mode'] = holdout['mode'][
                    holdout.fn == df.fn[i]].values[0]

                # Add the waypoint type
                waypoint_types = np.repeat('middle', num_waypoints)
                waypoint_types[0] = 'first'
                waypoint_types[num_waypoints - 1] = 'last'
                waypoints['type'] = waypoint_types
                waypoints['id'] = np.arange(num_waypoints)
                waypoints['num_waypoints'] = num_waypoints

            # Add the most recent wifi times that are closest to the waypoint
            # timestamps
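            # (t1_wifi groups the records of one wifi scan; t2_wifi is the
            # per-record last-seen timestamp, so the max t2 per t1 group
            # approximates when the scan's data was last refreshed)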
            wifi_t1_times = np.unique(trajectory['wifi'].t1_wifi)
            assert np.all(np.diff(wifi_t1_times) > 0)
            wifi_last_t2_times = trajectory['wifi'].groupby(
                't1_wifi')['t2_wifi'].aggregate("max").values
            num_wifi_obs = trajectory['wifi'].groupby(
                't1_wifi')['t1_wifi'].aggregate("count").values
            # Sanity checks on the aggregated wifi times
            assert wifi_t1_times.size == wifi_last_t2_times.size
            assert np.sum(np.diff(wifi_last_t2_times) < -1) <= 1
            assert np.all(wifi_last_t2_times < wifi_t1_times) or (
                df['mode'][i] == 'test')

            if df['mode'][i] != 'test':
                waypoint_wifi_times = np.zeros(num_waypoints, dtype=np.int64)
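                # Map each waypoint to the most recent wifi scan whose last t2
                # time does not exceed the waypoint timestamp; waypoints that
                # precede all scans are clipped to the first scan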
                for j in range(num_waypoints):
                    wifi_id = max(
                        0, (wifi_last_t2_times <= waypoints.time[j]).sum() - 1)
                    waypoint_wifi_times[j] = wifi_last_t2_times[wifi_id]
                waypoints['last_wifi_t2_time'] = waypoint_wifi_times
                waypoints['trajectory_wifi_time'] = waypoint_wifi_times - (
                    wifi_t1_times[0])
                waypoint_times = waypoints.time.values
                waypoints['trajectory_waypoint_time'] = waypoint_times - (
                    waypoint_times[0])
                waypoints['first_waypoint_time'] = waypoint_times[0]

                # Reorder the columns
                cols = waypoints.columns.tolist()
                reordered_cols = cols[:1] + cols[4:] + cols[1:4]
                waypoints = waypoints[reordered_cols]

                all_waypoints.append(waypoints)

            if write_all_wifi_data:
                wifi_data = trajectory['wifi'].copy()
                for c in ['site_id', 'mode', 'fn', 'level']:
                    wifi_data[c] = df[c][i]
                cols = wifi_data.columns.tolist()
                reordered_cols = cols[6:] + cols[:6]
                wifi_data = wifi_data[reordered_cols]
                if 'wifi_waypoints' in trajectory:
                    wifi_wp = trajectory['wifi_waypoints']
                    wifi_wp.sort_values(["t1_wifi", "t2_wifi"],
                                        ascending=[True, False],
                                        inplace=True)
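                    # For every scan (t1), keep the interpolated waypoint
                    # position of the record with the most recent t2 time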
                    wifi_wp_map = wifi_wp.groupby(
                        't1_wifi').first().reset_index()[[
                            't1_wifi', 'waypoint_interp_x', 'waypoint_interp_y'
                        ]]

                    wifi_data = wifi_data.merge(wifi_wp_map, on='t1_wifi')
                else:
                    wifi_data['waypoint_interp_x'] = np.nan
                    wifi_data['waypoint_interp_y'] = np.nan
                all_wifi_data.append(wifi_data)

            wifi_times = pd.DataFrame({
                'site_id': df['site_id'][i],
                'mode': df['mode'][i],
                'fn': df['fn'][i],
                'level': df['level'][i],
                'wifi_t1_times': wifi_t1_times,
                'wifi_last_t2_times': wifi_last_t2_times,
                'trajectory_index': np.arange(wifi_last_t2_times.size),
                'num_wifi_obs': num_wifi_obs,
            })

            if df['mode'][i] == 'test':
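                # Test trajectories have no labeled waypoints, so anchor them
                # in time by their first wifi t2 time instead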
                wifi_times['first_last_t2_time'] = wifi_last_t2_times[0]
                all_test_wifi_times.append(wifi_times)
            else:
                wifi_times['first_waypoint_time'] = waypoint_times[0]
                all_train_wifi_times.append(wifi_times)

    # Write the combined waypoints to disk
    combined_waypoints = pd.concat(all_waypoints)
    combined_waypoints.sort_values(["site_id", "first_waypoint_time", "time"],
                                   inplace=True)
    combined_waypoints.to_csv(combined_waypoints_path, index=False)

    # # Write the combined wifi times to disk
    # combined_train_wifi_times = pd.concat(all_train_wifi_times)
    # combined_train_wifi_times.sort_values(
    #   ["site_id", "first_waypoint_time", "wifi_t1_times"], inplace=True)
    # combined_train_wifi_times.to_csv(combined_train_wifi_times_path, index=False)

    # combined_test_wifi_times = pd.concat(all_test_wifi_times)
    # combined_test_wifi_times.sort_values(
    #   ["site_id", "first_last_t2_time", "wifi_t1_times"], inplace=True)
    # combined_test_wifi_times.to_csv(combined_test_wifi_times_path, index=False)

    # Write the raw wifi data to disk
    if write_all_wifi_data:
        test_floors = utils.get_test_floors(data_folder)
        combined_all_wifi = pd.concat(all_wifi_data)
        combined_all_wifi.sort_values(["site_id", "fn", "mode"], inplace=True)
        all_levels = [
            level if m != 'test' else test_floors[fn]
            for (level, m, fn) in zip(combined_all_wifi.level,
                                      combined_all_wifi['mode'],
                                      combined_all_wifi.fn)
        ]
        combined_all_wifi['level'] = np.array(all_levels)
        sites = np.sort(np.unique(combined_all_wifi.site_id.values))
        for site_id, site in enumerate(sites):
            print(f"Site {site_id+1} of {len(sites)}")
            combined_all_wifi_site = combined_all_wifi[
                combined_all_wifi.site_id.values == site]

            # Write one combined wifi file per level (test levels were mapped
            # above from the reference test floors)
            levels = np.sort(np.unique(combined_all_wifi_site.level.values))
            for level in levels:
                combined_all_wifi_floor = combined_all_wifi_site[
                    combined_all_wifi_site.level.values == level]
                combined_all_wifi_floor = combined_all_wifi_floor.sort_values(
                    ["mode", "t1_wifi"])
                text_level = df.text_level[
                    df.fn == combined_all_wifi_floor.fn.values[-1]].values[0]

                combined_all_wifi_path = (combined_all_wifi_folder / site /
                                          text_level / 'all_wifi.csv')
                combined_all_wifi_floor.to_csv(combined_all_wifi_path,
                                               index=False)
Example #2

    return res.x


data_folder = utils.get_data_folder()
summary_path = data_folder / 'file_summary.csv'
stratified_holdout_path = data_folder / 'holdout_ids.csv'
model_folder = data_folder.parent / 'Models' / models_group_name
if ('df' not in locals() or 'holdout_df' not in locals()
        or 'test_waypoint_times' not in locals()
        or 'test_floors' not in locals()):
    df = pd.read_csv(summary_path)
    holdout_df = pd.read_csv(stratified_holdout_path)
    test_waypoint_times = utils.get_test_waypoint_times(data_folder)
    test_floors = utils.get_test_floors(data_folder)

aggregate_scores = np.zeros((len(utils.TEST_SITES), 2))
test_preds = {}
for analysis_site_id, analysis_site in enumerate(utils.TEST_SITES):
    site_model_folder = model_folder / analysis_site
    Path(site_model_folder).mkdir(parents=True, exist_ok=True)

    site_df = df[(df.site_id == analysis_site) & (df.num_wifi > 0)]
    if mode != 'test':
        site_df = site_df[site_df['mode'] != 'test']
        valid_paths = holdout_df.ext_path[(holdout_df['mode'] == 'valid') & (
            holdout_df.site_id == analysis_site)].tolist()
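        # site_df is a slice of df; suppress the chained assignment warning
        # since overwriting its 'mode' column in place is intentional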
        with pd.option_context('mode.chained_assignment', None):
            site_df['mode'] = site_df['ext_path'].apply(
                lambda x: 'valid' if (x in valid_paths) else 'train')
Example #3
def run(mode):
    print("Processing time leak (edge trajectories)")
    debug_site = [None, 0][0]
    use_multiprocessing = False
    test_preds_source = 'test - 2021-05-15 051944.csv'
    test_override_floors = False

    data_folder = utils.get_data_folder()
    test_override_ext = '_floor_override' if (mode == 'test'
                                              and test_override_floors) else ''
    save_path = data_folder / (mode + '_edge_positions_v3' +
                               test_override_ext + '.csv')
    if save_path.is_file():
        return

    summary_path = data_folder / 'file_summary.csv'
    test_preds_path = data_folder / 'submissions' / test_preds_source
    stratified_holdout_path = data_folder / 'holdout_ids.csv'
    device_id_path = data_folder / 'device_ids.pickle'
    ordered_device_time_path = data_folder / 'inferred_device_ids.csv'
    with open(device_id_path, 'rb') as f:
        device_ids = pickle.load(f)
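    # fn -> (trajectory edge, floor, x, y) for leaked public/private test
    # positions (field interpretation inferred from the values below)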
    public_private_test_leaks = {
        'ff141af01177f34e9caa7a12': ('start', 3, 203.11885, 97.310814),
        'f973ee415265be4addc457b1': ('start', -1, 20.062187, 99.66188),
        '23b4c8eb4b41d75946285461': ('end', 2, 60.205635, 102.28055),
        '5582270fcaee1f580de9006f': ('end', 0, 97.8957, 28.9133),
        'b51a662297b90657f0b03b44': ('start', 1, 112.39258, 233.72379),
    }
    df = pd.read_csv(summary_path)
    holdout_df = pd.read_csv(stratified_holdout_path)
    test_floors = utils.get_test_floors(
        data_folder, debug_test_floor_override=test_override_floors)
    test_preds = pd.read_csv(test_preds_path)
    test_preds = utils.override_test_floor_errors(
        test_preds, debug_test_floor_override=test_override_floors)
    test_preds['fn'] = [
        spt.split('_')[1] for spt in test_preds.site_path_timestamp
    ]
    test_preds['timestamp'] = [
        int(spt.split('_')[2]) for spt in test_preds.site_path_timestamp
    ]
    for test_fn in test_floors:
        assert test_preds.floor[
            test_preds.fn == test_fn].values[0] == test_floors[test_fn]

    device_time_path = pd.read_csv(ordered_device_time_path)
    device_time_path['time'] = device_time_path['start_time']
    test_rows = np.where(device_time_path['mode'].values == "test")[0]
    device_time_path.loc[
        test_rows,
        'time'] = device_time_path['first_last_wifi_time'].values[test_rows]
    device_time_path.sort_values(['device_id', 'time'], inplace=True)
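    # Sorting by (device_id, time) puts trajectories recorded with the same
    # device in chronological order, so neighboring rows are candidates for
    # sharing start/end positions (the time leak exploited below)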

    sites = df.iloc[df.test_site.values].groupby(
        'site_id').size().reset_index()
    if debug_site is not None:
        sites = sites.iloc[debug_site:(debug_site + 1)]
    sites = sites.site_id.values

    if use_multiprocessing:
        with mp.Pool(processes=mp.cpu_count() - 1) as pool:
            results = [
                pool.apply_async(extract_floor_start_end,
                                 args=(data_folder, s, df, holdout_df,
                                       test_preds, device_time_path, mode,
                                       device_ids, public_private_test_leaks))
                for s in sites
            ]
            all_outputs = [p.get() for p in results]
    else:
        all_outputs = []
        for site_id, analysis_site in enumerate(sites):
            print(f"Processing site {site_id+1} of {len(sites)}")
            all_outputs.append(
                extract_floor_start_end(data_folder, analysis_site, df,
                                        holdout_df, test_preds,
                                        device_time_path, mode, device_ids,
                                        public_private_test_leaks))

    # Save the combined results
    combined = pd.concat(all_outputs)
    combined.to_csv(save_path, index=False)
Example #4
def run(mode="test", consider_multiprocessing=True, overwrite_output=False):
    print("Non-parametric WiFi model")
    models_group_name = 'non_parametric_wifi'
    overwrite_models = True
    recompute_grouped_data = False
    # config = {
    #   'min_train_points': 10, # Ignore bssid with few observations
    #   'min_train_fns': 1, # Ignore bssid with few trajectories
    #   'delay_decay_penalty_exp_base': 0.62, # Base for bssid weight decay as a f of delay to compute the shared bssid fraction
    #   'inv_fn_count_penalty_exp': 0.1, # Exponent to give more weight to rare bssids to compute the shared bssid fraction
    #   'non_shared_penalty_start': 1.0, # Threshold below which the shared wifi fraction gets penalized in the distance calculation
    #   'non_shared_penalty_exponent': 2.2, # Exponent to penalize the non shared wifi fraction
    #   'non_shared_penalty_constant': 75, # Multiplicative constant to penalize the non shared wifi fraction
    #   'delay_decay_exp_base': 0.925, # Base for shared bssid weight decay as a f of delay
    #   'inv_fn_count_distance_exp': 0.1, # Exponent to give more weight to rare bssids to compute the weighted mean distance
    #   'unique_model_frequencies': False, # Discard bssid's with changing freqs
    #   'time_range_max_strength': 3, # Group wifi observations before and after each observation and retain the max strength
    #   'limit_train_near_waypoints': not True, # Similar to "snap to grid" - You likely want to set this to False eventually to get more granular predictions
    #   }
    config = {
        'min_train_points': 5,  # Ignore bssids with few observations
        'min_train_fns': 1,  # Ignore bssids with few trajectories
        # Base for bssid weight decay as a function of delay, used to compute
        # the shared bssid fraction
        'delay_decay_penalty_exp_base': 0.8,
        # Exponent to give more weight to rare bssids when computing the
        # shared bssid fraction
        'inv_fn_count_penalty_exp': 0.0,
        # Threshold below which the shared wifi fraction gets penalized in the
        # distance calculation
        'non_shared_penalty_start': 1.0,
        # Exponent to penalize the non-shared wifi fraction
        'non_shared_penalty_exponent': 2.0,
        # Multiplicative constant to penalize the non-shared wifi fraction
        'non_shared_penalty_constant': 50,
        # Base for shared bssid weight decay as a function of delay
        'delay_decay_exp_base': 0.92,
        # Exponent to give more weight to rare bssids when computing the
        # weighted mean distance
        'inv_fn_count_distance_exp': 0.2,
        # Discard bssids with changing frequencies
        'unique_model_frequencies': False,
        # Group wifi observations before and after each observation and retain
        # the max strength
        'time_range_max_strength': 1e-5,
        # Similar to "snap to grid" - you likely want to keep this False to
        # get more granular predictions
        'limit_train_near_waypoints': False,
    }
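    # Illustrative sketch (an assumption, not the actual implementation in
    # non_parametric_wifi_utils): the delay decay parameters above suggest
    # exponential weights of the form
    #   weight = config['delay_decay_exp_base'] ** delay_in_seconds
    # e.g. 0.92 ** 5 ≈ 0.66 for a bssid observed 5s from the reference time,
    # so stale observations contribute less to the weighted mean distance.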

    debug_floor = [None, 16][0]
    debug_fn = [None, '5dd374df44333f00067aa198'][0]
    store_all_wifi_predictions = False
    store_full_wifi_predictions = not config[
        'limit_train_near_waypoints']  # Required for the current combined optimization
    only_public_test_preds = False
    reference_submission_ext = 'non_parametric_wifi - valid - 2021-03-30 091444.csv'
    bogus_test_floors_to_train_all_test_models = False
    test_override_floors = False

    data_folder = utils.get_data_folder()
    summary_path = data_folder / 'file_summary.csv'
    stratified_holdout_path = data_folder / 'holdout_ids.csv'
    leaderboard_types_path = data_folder / 'leaderboard_type.csv'
    preds_folder = (data_folder.parent / 'Models' / models_group_name /
                    'predictions')
    pathlib.Path(preds_folder).mkdir(parents=True, exist_ok=True)
    if store_full_wifi_predictions:
        file_ext = models_group_name + ' - ' + mode + ' - full distances.pickle'
        full_predictions_path = preds_folder / file_ext

        if full_predictions_path.is_file() and (not overwrite_output):
            return

    reference_submission_path = data_folder / reference_submission_ext
    df = pd.read_csv(summary_path)
    holdout_df = pd.read_csv(stratified_holdout_path)
    test_waypoint_times = utils.get_test_waypoint_times(data_folder)
    test_floors = utils.get_test_floors(
        data_folder, debug_test_floor_override=test_override_floors)
    leaderboard_types = pd.read_csv(leaderboard_types_path)
    test_type_mapping = {
        fn: t
        for (fn, t) in zip(leaderboard_types.fn, leaderboard_types['type'])
    }
    reference_submission = pd.read_csv(reference_submission_path)

    assert store_full_wifi_predictions == (
        not config['limit_train_near_waypoints'])

    if bogus_test_floors_to_train_all_test_models and mode == 'test':
        print(
            "WARNING: bogus shuffling of test floors to train all floor models"
        )
        test_floors = utils.get_test_floors(data_folder)
        site_floors = df.iloc[df.test_site.values].groupby(
            ['site_id', 'text_level']).size().reset_index()
        site_floors['level'] = [
            utils.TEST_FLOOR_MAPPING[t] for t in site_floors.text_level
        ]
        site_floors['num_test_counts'] = 0
        first_floor_fns = {s: [] for s in np.unique(site_floors.site_id)}
        repeated_floor_fns = {s: [] for s in np.unique(site_floors.site_id)}
        for fn in test_floors:
            site = df.site_id[df.fn == fn].values[0]
            increment_row = np.where((site_floors.site_id == site) & (
                site_floors.level == test_floors[fn]))[0][0]
            site_floors.loc[increment_row, 'num_test_counts'] += 1
            if site_floors.num_test_counts.values[increment_row] > 1:
                repeated_floor_fns[site].append(fn)
            else:
                first_floor_fns[site].append(fn)

        non_visited_floor_ids = np.where(site_floors.num_test_counts == 0)[0]
        for i, non_visited_id in enumerate(non_visited_floor_ids):
            site = site_floors.site_id.values[non_visited_id]
            if repeated_floor_fns[site]:
                override_fn = repeated_floor_fns[site].pop()
            else:
                override_fn = first_floor_fns[site].pop()
            test_floors[override_fn] = site_floors.level.values[non_visited_id]

        # Verify that now all floors contain at least one test fn
        site_floors['num_test_counts'] = 0
        for fn in test_floors:
            site = df.site_id[df.fn == fn].values[0]
            increment_row = np.where((site_floors.site_id == site) & (
                site_floors.level == test_floors[fn]))[0][0]
            site_floors.loc[increment_row, 'num_test_counts'] += 1

    if debug_fn is not None:
        debug_fn_row = np.where(df.fn.values == debug_fn)[0][0]
        debug_fn_site = df.site_id.values[debug_fn_row]
        debug_fn_level = df.text_level.values[debug_fn_row]
        site_floors = df.iloc[df.test_site.values].groupby(
            ['site_id', 'text_level']).size().reset_index()
        debug_floor = np.where((site_floors.site_id == debug_fn_site) & (
            site_floors.text_level == debug_fn_level))[0][0]

    use_multiprocessing = consider_multiprocessing and (debug_fn is None) and (
        debug_floor is None)
    all_outputs = non_parametric_wifi_utils.multiple_floors_train_predict(
        config, df, debug_floor, reference_submission, use_multiprocessing,
        models_group_name, mode, holdout_df, test_floors,
        recompute_grouped_data, overwrite_models, test_type_mapping,
        only_public_test_preds, test_waypoint_times,
        store_all_wifi_predictions, store_full_wifi_predictions, debug_fn)

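    # Flatten the per-floor outputs: a dict of test predictions, flat lists of
    # validation and wifi predictions, and the merged full distance dicts
    # (ChainMap keeps the first occurrence of duplicate keys)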
    test_preds = {
        k: v for d in [o[0] for o in all_outputs] for k, v in d.items()
    }
    valid_preds = [p for preds in [o[1] for o in all_outputs] for p in preds]
    all_wifi_predictions = [
        p for preds in [o[2] for o in all_outputs] for p in preds
    ]
    full_wifi_predictions = dict(
        ChainMap(*[o[3] for o in all_outputs if o[3]]))

    Path(preds_folder).mkdir(parents=True, exist_ok=True)
    if store_full_wifi_predictions:
        with open(full_predictions_path, 'wb') as handle:
            pickle.dump(full_wifi_predictions,
                        handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
    if mode == 'test':
        submission = utils.convert_to_submission(data_folder, test_preds)
        submission_ext = models_group_name + ' - test.csv'
        submission.to_csv(preds_folder / submission_ext, index=False)
    elif debug_floor is None:
        preds_df = pd.DataFrame(valid_preds)
        print(f"Mean validation error: {preds_df.error.values.mean():.2f}")
        preds_path = preds_folder / (models_group_name + ' - valid.csv')
        preds_df.to_csv(preds_path, index=False)

        if store_all_wifi_predictions:
            all_wifi_preds_df = pd.DataFrame(all_wifi_predictions)
            all_wifi_preds_df.sort_values(["site", "fn", "time"], inplace=True)
            preds_path = preds_folder / (models_group_name +
                                         ' - all wifi validation.csv')
            all_wifi_preds_df.to_csv(preds_path, index=False)

        holdout_unweighted = np.sqrt(preds_df.squared_error.values).mean()
        print(f"Holdout unweighted aggregate loss: {holdout_unweighted:.2f}")
Example #5
    def fit(self, mode, train, valid, predict, model_folder,
            override_model_ext):
        train_model = override_model_ext is None

        if self.num_independent_segments == 1:
            num_outputs = 1 if self.distance_model else 2
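            # Distance models regress a scalar segment length; otherwise the
            # network regresses an (x, y) position offset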
            self.nn = DistanceZNN(
                ser_limit=self.config['ser_limit'],
                n_device_ids=self.config['device_map_count'][1],
                n_units_cnn1=self.config['n_units_cnn1'],
                n_units_rnn=self.config['n_units_rnn'],
                n_units_cnn2=self.config['n_units_cnn2'],
                n_inputs=len(self.config['sensor_cols']),
                num_recurrent_layers=self.config['num_recurrent_layers'],
                bidirectional_rnn=self.config['bidirectional_rnn'],
                num_outputs=num_outputs,
            )
        else:
            self.nn = RelativeMovementZNN(
                ser_limit=self.config['ser_limit'],
                n_units=self.config['n_units'],
                n_inputs=len(self.config['sensor_cols']),
                num_recurrent_layers=self.config['num_recurrent_layers'],
                bidirectional_rnn=self.config['bidirectional_rnn'],
            )
        self.nn.to(self.config['device'])

        # Train phase
        if not train_model:
            model_str = override_model_ext
            override_model_path = model_folder / (override_model_ext + '.pt')
            override_model_state_dict = torch.load(override_model_path)
            self.nn.load_state_dict(override_model_state_dict)
        else:
            record_time = str(datetime.datetime.now())[:19]
            model_str = mode + ' - ' + record_time
            model_save_path = model_folder / (model_str + '.pt')

            train_loader, _ = self.get_data_loader(
                'train', train, self.config['train_samples_per_epoch'])
            if valid is None:
                valid_loader = None
            else:
                valid_loader, _ = self.get_data_loader(
                    'valid', valid, self.config['valid_samples_per_epoch'])

            optimizer = torch.optim.Adam(self.nn.parameters(),
                                         lr=self.config['lr'])
            if self.config.get('scheduler', None) is not None:
                self.scheduler = self.config.get('scheduler')(optimizer)
            met_hist = []
            best_valid = float('inf')
            best_train = float('inf')
            for epoch in range(self.config['n_epochs']):
                print(f"Epoch {epoch+1} of {self.config['n_epochs']}")
                start_time = time.time()
                self.nn.train()
                avg_train_loss = 0
                all_preds = []
                all_y = []

                for batch_id, d in enumerate(train_loader):
                    # print(batch_id)
                    optimizer.zero_grad()

                    for k in d.keys():
                        if k in ['long_keys']:
                            d[k] = d[k].to(self.config['device']).long()
                        else:
                            d[k] = d[k].to(self.config['device']).float()
                    preds = self.nn(d)
                    all_preds.append(preds.detach().cpu().numpy())
                    batch_y = d['y']
                    all_y.append(batch_y.cpu().numpy())

                    if self.distance_model:
                        loss = nn.MSELoss()(preds, batch_y)
                    else:
                        # loss = torch.dist(preds, batch_y, 2)
                        loss = nn.L1Loss()(preds, batch_y)
                    avg_train_loss += loss.detach().cpu() / len(train_loader)

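                    # The backward pass is skipped in the very first epoch, so
                    # epoch 0 effectively records an untrained baseline
                    # (optimizer.step() with zero/absent gradients is a no-op)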
                    if epoch > 0:
                        loss.backward()
                    optimizer.step()

                self.nn.eval()
                train_preds = np.concatenate(all_preds)
                train_y = np.concatenate(all_y)
                error = train_y - train_preds
                if self.distance_model:
                    train_loss = np.sqrt((error[:, 0]**2).mean())
                    train_mae = np.abs(error[:, 0]).mean()
                else:
                    distance_errors = np.sqrt((error**2).sum(1))
                    train_loss = distance_errors.mean()
                    train_mae = np.abs(error).mean()

                if self.config.get('scheduler', None) is not None:
                    self.scheduler.step()

                train_elapsed = time.time() - start_time

                if valid is not None:
                    all_preds = []
                    all_y = []
                    for batch_id, d in enumerate(valid_loader):
                        # print(batch_id)
                        for k in d.keys():
                            if k in ['long_keys']:
                                d[k] = d[k].to(self.config['device']).long()
                            else:
                                d[k] = d[k].to(self.config['device']).float()
                        preds = self.nn(d).detach().cpu().numpy()
                        all_preds.append(preds)
                        batch_y = d['y']
                        all_y.append(batch_y.cpu().numpy())

                    val_preds = np.concatenate(all_preds)
                    val_y = np.concatenate(all_y)
                    error = val_y - val_preds
                    if self.distance_model:
                        val_loss = np.sqrt((error[:, 0]**2).mean())
                        val_mae = np.abs(error[:, 0]).mean()
                    else:
                        distance_errors = np.sqrt((error**2).sum(1))
                        val_loss = distance_errors.mean()
                        val_mae = np.abs(error).mean()
                    met_hist.append(val_loss)
                    if val_loss < best_valid:
                        best_valid = val_loss
                        torch.save(self.nn.state_dict(),
                                   model_save_path,
                                   _use_new_zipfile_serialization=False)
                    elapsed = time.time() - start_time
                    if self.num_independent_segments == 1:
                        print(f"{epoch:3}: {train_loss:8.4f} {val_loss:8.4f}"
                              f" {val_mae:8.4f} {train_elapsed:8.2f}s"
                              f" {elapsed:8.2f}s")
                    else:
                        print(f"{epoch:3}: {train_loss:8.4f} {val_loss:8.4f}"
                              f" {train_elapsed:8.2f}s {elapsed:8.2f}s")
                    self.metric_history = met_hist
                else:
                    print(f"{epoch:3}: {train_loss:8.4f} {train_mae:8.4f}\
 {train_elapsed:8.2f}s")
                    if train_loss < best_train:
                        best_train = train_loss
                        torch.save(self.nn.state_dict(),
                                   model_save_path,
                                   _use_new_zipfile_serialization=False)

            del train_loader
            del valid_loader
            gc.collect()

        # Predict phase
        if predict is not None:
            waypoint_counts = np.array(
                [predict[1][k]['num_waypoints'] for k in predict[1]])
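            # A trajectory with n waypoints spans n-1 segments; the
            # two-segment model also needs a successor segment, leaving n-2
            # samples per trajectory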
            if self.num_independent_segments == 1:
                num_predict = (waypoint_counts - 1).sum()
            else:
                num_predict = (waypoint_counts - 2).sum()
            predict_loader, predict_sub_trajectory_keys = self.get_data_loader(
                'test', predict[1], num_predict, fixed_order=True)

            all_preds = []
            all_y = []
            for batch_id, d in enumerate(predict_loader):
                # print(batch_id)
                for k in d.keys():
                    if k in ['long_keys']:
                        d[k] = d[k].to(self.config['device']).long()
                    else:
                        d[k] = d[k].to(self.config['device']).float()
                preds = self.nn(d).detach().cpu().numpy()
                all_preds.append(preds)
                batch_y = d['y']
                all_y.append(batch_y.cpu().numpy())

            predict_preds = np.concatenate(all_preds)
            predict_y = np.concatenate(all_y)
            fns, sub_trajectory_ids = zip(*predict_sub_trajectory_keys)
            fns = np.array(fns)
            sub_trajectory_ids = np.array(sub_trajectory_ids)

            if self.num_independent_segments == 2:
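                # The last sub-trajectory of each fn has no successor segment;
                # drop it and shift the remaining ids so each prediction
                # aligns with the second segment of its pair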
                last_fn_ids = np.concatenate(
                    [fns[:-1] != fns[1:],
                     np.array([True])])
                fns = fns[~last_fn_ids]
                sub_trajectory_ids = sub_trajectory_ids[~last_fn_ids]
                sub_trajectory_ids += 1

            sites = []
            floors = []
            text_levels = []
            num_waypoints = []
            test_floors = utils.get_test_floors(utils.get_data_folder())
            for fn in fns:
                target_row = np.where(self.df.fn == fn)[0][0]
                sites.append(self.df.site_id.values[target_row])
                floors.append(
                    test_floors.get(fn, self.df.level.values[target_row]))
                text_levels.append(self.df.text_level.values[target_row])
                if self.df['mode'].values[target_row] == 'train':
                    num_waypoints.append(
                        self.df.num_train_waypoints.values[target_row])
                else:
                    num_waypoints.append(
                        self.df.num_test_waypoints.values[target_row])

            predictions = pd.DataFrame({
                'site': sites,
                'floor': floors,
                'text_level': text_levels,
                'fn': fns,
                'sub_trajectory_id': sub_trajectory_ids,
                'num_waypoints': num_waypoints,
            })

            if self.distance_model:
                fns = predictions.fn[predictions.sub_trajectory_id == 0].values
                predictions['fraction_time_covered'] = np.concatenate(
                    [predict[1][fn]['fractions_time_covered'] for fn in fns])
                predictions['prediction'] = predict_preds[:, 0]
                predictions['actual'] = predict_y[:, 0]
            else:
                predictions['prediction_x'] = predict_preds[:, 0]
                predictions['prediction_y'] = predict_preds[:, 1]
                predictions['actual_x'] = predict_y[:, 0]
                predictions['actual_y'] = predict_y[:, 1]
            predictions_path = model_folder / 'predictions' / (
                predict[0] + ' - ' + model_str + '.csv')
            predictions.to_csv(predictions_path, index=False)