Beispiel #1
0
def test_stick_transition():
    """Exercise the sticky transition: sample from a very sticky HMM, then refit
    fresh HMMs with different kappa values and print their transition matrices
    before and after fitting."""
    K, D, M = 3, 2, 0

    def make_sticky_hmm(kappa):
        # All models here share shape, Gaussian observations and alpha=1; only kappa varies.
        return HMM(K=K, D=D, M=M, transition='sticky', observation='gaussian',
                   transition_kwargs=dict(alpha=1, kappa=kappa))

    # Reference model (kappa=100) used to generate the samples.
    model = make_sticky_hmm(100)
    print("alpha = {}, kappa = {}".format(model.transition.alpha, model.transition.kappa))

    toy_data = torch.tensor([[2, 3], [4, 5], [6, 7]], dtype=torch.float64)
    toy_log_prob = model.log_likelihood(toy_data)
    print("log_prob", toy_log_prob)

    _, samples_x = model.sample(10)

    print("sample log_prob", model.log_probability(samples_x))

    print("model transition\n", model.transition.stationary_transition_matrix)

    # Refit fresh models of increasing stickiness on the same samples.
    for idx, kappa in enumerate((0.1, 20, 100), start=1):
        candidate = make_sticky_hmm(kappa)
        header = "model{} transition\n".format(idx)

        print("Before fit")
        print(header, candidate.transition.stationary_transition_matrix)

        fit_losses, _ = candidate.fit(samples_x)
        print("After fit")
        print(header, candidate.transition.stationary_transition_matrix)
def main(job_name, downsample_n, filter_traj, load_model, load_model_dir, load_opt_dir,
         transition, sticky_alpha, sticky_kappa, acc_factor, k, x_grids, y_grids, n_x, n_y,
         train_model, pbar_update_interval, video_clips, torch_seed, np_seed,
         list_of_num_iters, list_of_lr, sample_t, quiver_scale):
    """Fit (or load) a GridTransformation HMM on the trajectory data and save results.

    Comma-separated string arguments (``video_clips``, ``list_of_num_iters``,
    ``list_of_lr`` and, when given, ``x_grids`` / ``y_grids``) are parsed here.
    Results go to a time-stamped directory under ``<repo>/rslts/<job_name>``;
    when ``train_model`` is true, one ``checkpoint_<i>`` sub-directory is
    written per (num_iters, lr) training stage.

    :raises ValueError: if ``job_name`` is missing, the iteration and
        learning-rate lists differ in length, or any learning rate exceeds 1.
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")
    K = k
    sample_T = sample_t
    video_clip_start, video_clip_end = [int(x) for x in video_clips.split(",")]
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    # Explicit raises (rather than assert) so validation survives ``python -O``.
    if len(list_of_num_iters) != len(list_of_lr):
        raise ValueError("Length of list_of_num_iters must match length of list_of_lr.")
    if any(lr > 1 for lr in list_of_lr):
        raise ValueError("Learning rate should not be larger than 1!")

    repo = git.Repo('.', search_parent_directories=True)
    repo_dir = repo.working_tree_dir  # root of the git working tree

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    ########################## data ########################
    data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
    trajs = joblib.load(data_dir)

    # Each clip index selects a block of 36000 rows of the full trajectory.
    traj = trajs[36000*video_clip_start:36000*video_clip_end]
    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)

    data = torch.tensor(traj, dtype=torch.float64)

    ######################### model ####################

    # model
    D = 4
    M = 0
    Df = 4

    if load_model:
        print("Loading the model from ", load_model_dir)
        model = joblib.load(load_model_dir)
        tran = model.observation.transformation

        # Hyper-parameters are taken from the loaded model, not the CLI args.
        K = model.K

        n_x = len(tran.x_grids) - 1
        n_y = len(tran.y_grids) - 1

        acc_factor = tran.transformations_a[0].acc_factor
        # NOTE(review): x_grids / y_grids keep their CLI values on this path,
        # so exp_params may not reflect the loaded model's grids.

    else:
        print("Creating the model...")
        bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX],
                           [ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])

        # grids: n_x / n_y evenly spaced cells over the arena, or explicit edges
        if x_grids is None:
            x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
            x_grids = [ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)]
        else:
            x_grids = [float(x) for x in x_grids.split(",")]
            n_x = len(x_grids) - 1

        if y_grids is None:
            y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
            y_grids = [ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)]
        else:
            y_grids = [float(x) for x in y_grids.split(",")]
            n_y = len(y_grids) - 1

        if acc_factor is None:
            acc_factor = downsample_n * 10

        tran = GridTransformation(K=K, D=D, x_grids=x_grids, y_grids=y_grids, unit_transformation="direction",
                                  Df=Df, feature_vec_func=f_corner_vec_func, acc_factor=acc_factor)
        obs = ARTruncatedNormalObservation(K=K, D=D, M=M, lags=1, bounds=bounds, transformation=tran)

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        else:
            transition_kwargs = None
        model = HMM(K=K, D=D, M=M, transition=transition, observation=obs, transition_kwargs=transition_kwargs)
        # Initialise every state's initial mean to the first observation.
        model.observation.mus_init = data[0] * torch.ones(K, D, dtype=torch.float64)

    # save experiment params (seeds included so the run can be reproduced)
    exp_params = {"job_name":   job_name,
                  'downsample_n': downsample_n,
                  "filter_traj": filter_traj,
                  "load_model": load_model,
                  "load_model_dir": load_model_dir,
                  "load_opt_dir": load_opt_dir,
                  "transition": transition,
                  "sticky_alpha": sticky_alpha,
                  "sticky_kappa": sticky_kappa,
                  "acc_factor": acc_factor,
                  "K": K,
                  "n_x": n_x,
                  "n_y": n_y,
                  "x_grids": x_grids,
                  "y_grids": y_grids,
                  "train_model": train_model,
                  "pbar_update_interval": pbar_update_interval,
                  "list_of_num_iters": list_of_num_iters,
                  "list_of_lr": list_of_lr,
                  "video_clip_start": video_clip_start,
                  "video_clip_end": video_clip_end,
                  "torch_seed": torch_seed,
                  "np_seed": np_seed,
                  "sample_T": sample_T,
                  "quiver_scale": quiver_scale}

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/" + job_name)
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving to rlst_dir: ", rslt_dir)
    with open(rslt_dir+"/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    # compute memories: masks and corner features for every transition step
    masks_a, masks_b = tran.get_masks(data[:-1])
    feature_vecs_a = f_corner_vec_func(data[:-1, 0:2])
    feature_vecs_b = f_corner_vec_func(data[:-1, 2:4])

    m_kwargs_a = dict(feature_vecs=feature_vecs_a)
    m_kwargs_b = dict(feature_vecs=feature_vecs_b)

    ##################### training ############################
    if train_model:
        print("start training")
        # Treat both "" and None as "no optimizer to resume from".
        opt = joblib.load(load_opt_dir) if load_opt_dir else None
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters, list_of_lr)):
            losses, opt = model.fit(data, optimizer=opt, method='adam', num_iters=num_iters, lr=lr,
                                    masks=(masks_a, masks_b), pbar_update_interval=pbar_update_interval,
                                    memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)

            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))
            # save model and opt
            joblib.dump(model, checkpoint_dir+"/model")
            joblib.dump(opt, checkpoint_dir+"/optimizer")
            # save rest
            rslt_saving(checkpoint_dir, model, Df, data, masks_a, masks_b, m_kwargs_a, m_kwargs_b, sample_T,
                        train_model, losses, quiver_scale)

    else:
        # only save the results
        rslt_saving(rslt_dir, model, Df, data, masks_a, masks_b, m_kwargs_a, m_kwargs_b, sample_T,
                    False, [], quiver_scale)

    print("Finish running!")
Beispiel #3
0
# Pre-compute the momentum and interaction features for every transition step
# once; each model call below receives the same two tensors as keyword args.
momentum_vecs = MomentumInteractionTransformation._compute_momentum_vecs(data[:-1], lags=momentum_lags)
interaction_vecs = MomentumInteractionTransformation._compute_interaction_vecs(data[:-1])
feature_kwargs = dict(momentum_vecs=momentum_vecs, interaction_vecs=interaction_vecs)

out = model.log_likelihood(data, **feature_kwargs)
print(out)

##################### training ############################

num_iters = 10
losses, opt = model.fit(data, num_iters=num_iters, lr=0.001, **feature_kwargs)

##################### sampling ############################
print("start sampling")
sample_z, sample_x = model.sample(30)

#################### inference ###########################
print("inferiring most likely states...")
z = model.most_likely_states(data, **feature_kwargs)

# k-step prediction (currently disabled):
#print("k step prediction")
#x_predict = k_step_prediction_for_coupled_momentum_model(model, z, data, momentum_vecs=momentum_vecs, features=features)
Beispiel #4
0
def main(job_name, cuda_num, downsample_n, filter_traj, gp_version, load_model,
         load_model_dir, load_opt_dir, transition, sticky_alpha, sticky_kappa,
         acc_factor, k, x_grids, y_grids, n_x, n_y, rs_factor, rs, train_rs,
         train_model, pbar_update_interval, video_clips, held_out_proportion,
         torch_seed, np_seed, list_of_num_iters, ckpts_not_to_save, list_of_lr,
         list_of_k_steps, sample_t, quiver_scale):
    """Fit (or load) a GPGridTransformation HMM on trajectory data and save results.

    The trajectory is split chronologically: the first ``1 - held_out_proportion``
    fraction is used for training, the remainder for validation. Results go to
    a time-stamped directory under ``<repo>/rslts/gpgrid/<job_name>``; when
    ``train_model`` is true, one ``checkpoint_<i>`` sub-directory is written per
    (num_iters, lr) stage, and full result saving is skipped for the stage
    indices listed in ``ckpts_not_to_save``.

    :raises ValueError: if ``job_name`` is missing, the iteration and
        learning-rate lists differ in length, any learning rate exceeds 1, or
        ``held_out_proportion`` lies outside [0, 0.4].
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")

    cuda_num = int(cuda_num)
    device = torch.device(
        "cuda:{}".format(cuda_num) if torch.cuda.is_available() else "cpu")
    print("Using device {} \n\n".format(device))

    K = k
    sample_T = sample_t
    rs_factor = np.array([float(x) for x in rs_factor.split(",")])
    # An all-zero rs_factor is passed on as None (no explicit factor).
    if np.all(rs_factor == 0):
        rs_factor = None
    rs = float(rs)
    video_clip_start, video_clip_end = [
        float(x) for x in video_clips.split(",")
    ]
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    list_of_k_steps = [int(x) for x in list_of_k_steps.split(",")]
    # Explicit raises (rather than assert) so validation survives ``python -O``.
    if len(list_of_num_iters) != len(list_of_lr):
        raise ValueError(
            "Length of list_of_num_iters must match length of list_of_lr.")
    if any(lr > 1 for lr in list_of_lr):
        raise ValueError("Learning rate should not be larger than 1!")
    if not 0 <= held_out_proportion <= 0.4:
        raise ValueError(
            "held_out_proportion should be between 0 and 0.4 (inclusive), but is {}"
            .format(held_out_proportion))

    ckpts_not_to_save = ([int(x) for x in ckpts_not_to_save.split(',')]
                         if ckpts_not_to_save else [])

    repo = git.Repo('.', search_parent_directories=True)
    repo_dir = repo.working_tree_dir  # root of the git working tree

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    ########################## data ########################
    data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
    trajs = joblib.load(data_dir)

    # Clip bounds may be fractional; each unit corresponds to 36000 rows.
    traj = trajs[int(36000 * video_clip_start):int(36000 * video_clip_end)]
    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)

    data = torch.tensor(traj, dtype=torch.float64, device=device)
    # Chronological train / validation split.
    # NOTE(review): with held_out_proportion == 0, valid_data is empty; confirm
    # the memory helpers and model.fit accept an empty validation set.
    T = data.shape[0]
    split_idx = int(T * (1 - held_out_proportion))
    training_data = data[:split_idx]
    valid_data = data[split_idx:]

    ######################### model ####################

    # model
    D = 4
    M = 0
    Df = 4

    if load_model:
        print("Loading the model from ", load_model_dir)
        model = joblib.load(load_model_dir)
        tran = model.observation.transformation

        # Hyper-parameters are taken from the loaded model, not the CLI args.
        K = model.K

        n_x = len(tran.x_grids) - 1
        n_y = len(tran.y_grids) - 1

        acc_factor = tran.acc_factor

    else:
        print("Creating the model...")
        bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX],
                           [ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])

        # grids: n_x / n_y evenly spaced cells over the arena, or explicit edges
        if x_grids is None:
            x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
            x_grids = np.array(
                [ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)])
        else:
            x_grids = np.array([float(x) for x in x_grids.split(",")])
            n_x = len(x_grids) - 1

        if y_grids is None:
            y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
            y_grids = np.array(
                [ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)])
        else:
            y_grids = np.array([float(x) for x in y_grids.split(",")])
            n_y = len(y_grids) - 1

        if acc_factor is None:
            acc_factor = downsample_n * 10

        # NOTE(review): the parsed ``rs`` argument is recorded in exp_params
        # but rs=None is passed here; confirm this is intended.
        tran = GPGridTransformation(K=K,
                                    D=D,
                                    x_grids=x_grids,
                                    y_grids=y_grids,
                                    Df=Df,
                                    feature_vec_func=f_corner_vec_func,
                                    acc_factor=acc_factor,
                                    rs_factor=rs_factor,
                                    rs=None,
                                    train_rs=train_rs,
                                    device=device,
                                    version=gp_version)
        obs = ARTruncatedNormalObservation(K=K,
                                           D=D,
                                           M=M,
                                           lags=1,
                                           bounds=bounds,
                                           transformation=tran,
                                           device=device)

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        else:
            transition_kwargs = None
        model = HMM(K=K,
                    D=D,
                    M=M,
                    transition=transition,
                    observation=obs,
                    transition_kwargs=transition_kwargs,
                    device=device)
        # Initialise every state's initial mean to the first training observation.
        model.observation.mus_init = training_data[0] * torch.ones(
            K, D, dtype=torch.float64, device=device)

    # save experiment params
    exp_params = {
        "job_name": job_name,
        'downsample_n': downsample_n,
        "filter_traj": filter_traj,
        "load_model": load_model,
        "gp_version": gp_version,
        "load_model_dir": load_model_dir,
        "load_opt_dir": load_opt_dir,
        "transition": transition,
        "sticky_alpha": sticky_alpha,
        "sticky_kappa": sticky_kappa,
        "acc_factor": acc_factor,
        "K": K,
        "x_grids": x_grids,
        "y_grids": y_grids,
        "n_x": n_x,
        "n_y": n_y,
        "rs_factor": get_np(tran.rs_factor),
        "rs": rs,
        "train_rs": train_rs,
        "train_model": train_model,
        "pbar_update_interval": pbar_update_interval,
        "video_clip_start": video_clip_start,
        "video_clip_end": video_clip_end,
        "held_out_proportion": held_out_proportion,
        "torch_seed": torch_seed,
        "np_seed": np_seed,
        "list_of_num_iters": list_of_num_iters,
        "list_of_lr": list_of_lr,
        "list_of_k_steps": list_of_k_steps,
        "sample_T": sample_T,
        "quiver_scale": quiver_scale
    }

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/gpgrid/" + job_name)
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving to rlst_dir: ", rslt_dir)
    with open(rslt_dir + "/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    # compute memory
    print("Computing memory...")

    def get_memory_kwargs(data, train_rs):
        # Per-step quantities for both animals (columns 0:2 and 2:4) that the
        # model consumes as keyword arguments. With train_rs the squared
        # distances to nearby grid points are kept; otherwise GP coefficients
        # are precomputed.
        feature_vecs_a = f_corner_vec_func(data[:-1, 0:2])
        feature_vecs_b = f_corner_vec_func(data[:-1, 2:4])

        gpt_idx_a, grid_idx_a = tran.get_gpt_idx_and_grid_idx_for_batch(
            data[:-1, 0:2])
        gpt_idx_b, grid_idx_b = tran.get_gpt_idx_and_grid_idx_for_batch(
            data[:-1, 2:4])

        if train_rs:
            nearby_gpts_a = tran.gridpoints[gpt_idx_a]
            dist_sq_a = (data[:-1, None, 0:2] - nearby_gpts_a)**2

            nearby_gpts_b = tran.gridpoints[gpt_idx_b]
            dist_sq_b = (data[:-1, None, 2:4] - nearby_gpts_b)**2

            return dict(feature_vecs_a=feature_vecs_a,
                        feature_vecs_b=feature_vecs_b,
                        gpt_idx_a=gpt_idx_a,
                        gpt_idx_b=gpt_idx_b,
                        grid_idx_a=grid_idx_a,
                        grid_idx_b=grid_idx_b,
                        dist_sq_a=dist_sq_a,
                        dist_sq_b=dist_sq_b)

        coeff_a = tran.get_gp_coefficients(data[:-1, 0:2], 0, gpt_idx_a,
                                           grid_idx_a)
        coeff_b = tran.get_gp_coefficients(data[:-1, 2:4], 0, gpt_idx_b,
                                           grid_idx_b)

        return dict(feature_vecs_a=feature_vecs_a,
                    feature_vecs_b=feature_vecs_b,
                    gpt_idx_a=gpt_idx_a,
                    gpt_idx_b=gpt_idx_b,
                    grid_idx_a=grid_idx_a,
                    grid_idx_b=grid_idx_b,
                    coeff_a=coeff_a,
                    coeff_b=coeff_b)

    def _save_loss_plot(values, title, path):
        # Write one loss curve to ``path`` and close the figure to free memory.
        plt.figure()
        plt.plot(values)
        plt.title(title)
        plt.savefig(path)
        plt.close()

    memory_kwargs = get_memory_kwargs(training_data, train_rs)
    valid_data_memory_kwargs = get_memory_kwargs(valid_data, train_rs)

    # Evaluated once before training; the value itself is not used further.
    model.log_likelihood(training_data, **memory_kwargs)

    # Arguments shared by every rslt_saving call below.
    saving_kwargs = dict(model=model,
                         data=training_data,
                         memory_kwargs=memory_kwargs,
                         list_of_k_steps=list_of_k_steps,
                         sample_T=sample_T,
                         quiver_scale=quiver_scale,
                         valid_data=valid_data,
                         valid_data_memory_kwargs=valid_data_memory_kwargs,
                         device=device)

    ##################### training ############################
    if train_model:
        print("start training")
        # Treat both "" and None as "no optimizer to resume from".
        opt = joblib.load(load_opt_dir) if load_opt_dir else None
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters,
                                                list_of_lr)):
            training_losses, opt, valid_losses = model.fit(
                training_data,
                optimizer=opt,
                method='adam',
                num_iters=num_iters,
                lr=lr,
                pbar_update_interval=pbar_update_interval,
                valid_data=valid_data,
                valid_data_memory_kwargs=valid_data_memory_kwargs,
                **memory_kwargs)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)

            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))
            # save model and opt
            joblib.dump(model, checkpoint_dir + "/model")
            joblib.dump(opt, checkpoint_dir + "/optimizer")

            # save losses
            losses = dict(training_loss=training_losses,
                          valid_loss=valid_losses)
            joblib.dump(losses, checkpoint_dir + "/losses")

            _save_loss_plot(training_losses, "training loss",
                            checkpoint_dir + "/training_losses.jpg")
            _save_loss_plot(valid_losses, "validation loss",
                            checkpoint_dir + "/valid_losses.jpg")

            # save rest
            if i in ckpts_not_to_save:
                print("ckpt {}: skip!\n".format(i))
                continue
            with torch.no_grad():
                rslt_saving(rslt_dir=checkpoint_dir, **saving_kwargs)

    else:
        # only save the results; no gradients are needed, matching the
        # training branch above
        with torch.no_grad():
            rslt_saving(rslt_dir=rslt_dir, **saving_kwargs)

    print("Finish running!")
Beispiel #5
0
import matplotlib.pyplot as plt

torch.manual_seed(0)
np.random.seed(0)

# test fitting: sample from an HMM with AR-Gaussian observations driven by a
# LinearTransformation, then fit a fresh HMM with 'gaussian' observations (same
# lags) to those samples and plot its predictions against the samples.

K = 3
D = 2
lags = 5

trans1 = LinearTransformation(K=K, D=D, lags=lags)
obs1 = ARGaussianObservation(K=K, D=D, transformation=trans1)
model1 = HMM(K=K, D=D, observation=obs1)

T = 100
sample_z, sample_x = model1.sample(T)

model2 = HMM(K=K, D=D, observation='gaussian', observation_kwargs=dict(lags=lags))

lls, opt = model2.fit(sample_x, num_iters=2000, lr=0.001)

z_infer = model2.most_likely_states(sample_x)

x_predict = k_step_prediction(model2, z_infer, sample_x)

# Compare the first dimension of the prediction with the ground-truth samples.
plt.figure()
plt.plot(x_predict[:, 0], label='prediction')
plt.plot(sample_x[:, 0], label='truth')
# The labels above are only displayed once a legend is drawn.
plt.legend()
plt.show()
Beispiel #6
0
def test_model():
    """Smoke-test an HMM whose observation uses a UniLSTMTransformation.

    On random 4-D data, runs log-likelihood, a short Adam fit (plotting the
    losses), most-likely-state inference, 0- and 10-step prediction, and
    sampling.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    T = 100
    D = 4

    data = np.random.randn(T, D)
    data = torch.tensor(data, dtype=torch.float64)

    # Truncation bounds: columns 0/2 share the x range and columns 1/3 share
    # the y range observed in the data, each widened by 1 on both sides.
    xmax = max(np.max(data[:, 0].numpy()), np.max(data[:, 2].numpy()))
    xmin = min(np.min(data[:, 0].numpy()), np.min(data[:, 2].numpy()))
    ymax = max(np.max(data[:, 1].numpy()), np.max(data[:, 3].numpy()))
    ymin = min(np.min(data[:, 1].numpy()), np.min(data[:, 3].numpy()))
    bounds = np.array([[xmin - 1, xmax + 1], [ymin - 1, ymax + 1],
                       [xmin - 1, xmax + 1], [ymin - 1, ymax + 1]])

    def toy_feature_vec_func(s):
        """Direction features of ``s`` relative to the corners (0,0), (0,8), (10,0), (10,8).

        :param s: self, (T, 2)
        :return: features, (T, Df, 2)
        """
        corners = torch.tensor([[0, 0], [0, 8], [10, 0], [10, 8]],
                               dtype=torch.float64)
        return feature_direction_vec(s, corners)

    K = 3

    Df = 4
    lags = 1

    tran = UniLSTMTransformation(K=K,
                                 D=D,
                                 Df=Df,
                                 feature_vec_func=toy_feature_vec_func,
                                 lags=lags,
                                 dh=10)

    # observation
    obs = ARTruncatedNormalObservation(K=K,
                                       D=D,
                                       lags=lags,
                                       bounds=bounds,
                                       transformation=tran)

    # model
    model = HMM(K=K, D=D, observation=obs)

    print("calculating log likelihood")
    feature_vecs_a = toy_feature_vec_func(data[:-1, 0:2])
    feature_vecs_b = toy_feature_vec_func(data[:-1, 2:4])
    feature_vecs = (feature_vecs_a, feature_vecs_b)
    packed_data = get_packed_data(data[:-1], lags=lags)

    model.log_likelihood(data,
                         feature_vecs=feature_vecs,
                         packed_data=packed_data)

    # fit
    losses, _ = model.fit(data,
                          optimizer=None,
                          method="adam",
                          num_iters=50,
                          feature_vecs=feature_vecs,
                          packed_data=packed_data)

    plt.figure()
    plt.plot(losses)
    plt.show()

    # most-likely-z
    print("Most likely z...")
    z = model.most_likely_states(data,
                                 feature_vecs=feature_vecs,
                                 packed_data=packed_data)

    # prediction on at most the last 1000 time steps
    # NOTE(review): when T > 1000, z and feature_vecs still cover the full
    # sequence while data_to_predict is only its tail -- confirm alignment.
    data_to_predict = data if data.shape[0] <= 1000 else data[-1000:]

    print("0 step prediction")
    x_predict = k_step_prediction_for_lstm_model(model,
                                                 z,
                                                 data_to_predict,
                                                 feature_vecs=feature_vecs)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()),
                            axis=0)

    print("10 step prediction")
    x_predict_2 = k_step_prediction(model, z, data_to_predict, k=10)
    x_predict_2_err = np.mean(np.abs(x_predict_2 -
                                     data_to_predict[10:].numpy()),
                              axis=0)

    # samples
    print("sampling...")
    sample_T = 5
    sample_z, sample_x = model.sample(sample_T)
def test_model():
    """Smoke-test an HMM whose observation uses a LinearGridTransformation.

    On a 5-step toy trajectory over a 2x2 grid, precomputes grid-point and
    feature memories, runs a short Adam fit (plotting the losses),
    most-likely-state inference, 0- and 2-step prediction, and sampling.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    x_grids = np.array([0.0, 5.0, 10.0])
    y_grids = np.array([0.0, 4.0, 8.0])

    data = np.array([[1.0, 1.0, 1.0, 6.0], [3.0, 6.0, 8.0, 6.0],
                     [4.0, 7.0, 8.0, 5.0], [6.0, 7.0, 5.0, 6.0],
                     [8.0, 2.0, 6.0, 1.0]])
    data = torch.tensor(data, dtype=torch.float64)

    def toy_feature_vec_func(s):
        """Direction features of ``s`` relative to the corners (0,0), (0,8), (10,0), (10,8).

        :param s: self, (T, 2)
        :return: features, (T, Df, 2)
        """
        corners = torch.tensor([[0, 0], [0, 8], [10, 0], [10, 8]],
                               dtype=torch.float64)
        return feature_direction_vec(s, corners)

    K = 3
    D = 4
    M = 0

    Df = 4

    bounds = np.array([[0.0, 10.0], [0.0, 8.0], [0.0, 10.0], [0.0, 8.0]])
    tran = LinearGridTransformation(K=K,
                                    D=D,
                                    x_grids=x_grids,
                                    y_grids=y_grids,
                                    Df=Df,
                                    feature_vec_func=toy_feature_vec_func)
    obs = ARTruncatedNormalObservation(K=K,
                                       D=D,
                                       M=M,
                                       lags=1,
                                       bounds=bounds,
                                       transformation=tran)

    model = HMM(K=K, D=D, M=M, transition="stationary", observation=obs)
    # Initialise every state's initial mean to the first observation.
    model.observation.mus_init = data[0] * torch.ones(
        K, D, dtype=torch.float64)

    # calculate memory for both position pairs (columns 0:2 and 2:4)
    gridpoints_idx_a = tran.get_gridpoints_idx_for_batch(data[:-1, 0:2])
    gridpoints_idx_b = tran.get_gridpoints_idx_for_batch(data[:-1, 2:4])
    gridpoints_a = tran.get_gridpoints_for_batch(gridpoints_idx_a)
    gridpoints_b = tran.get_gridpoints_for_batch(gridpoints_idx_b)
    feature_vecs_a = toy_feature_vec_func(data[:-1, 0:2])
    feature_vecs_b = toy_feature_vec_func(data[:-1, 2:4])

    gridpoints_idx = (gridpoints_idx_a, gridpoints_idx_b)
    gridpoints = (gridpoints_a, gridpoints_b)
    feature_vecs = (feature_vecs_a, feature_vecs_b)

    # fit
    losses, opt = model.fit(data,
                            optimizer=None,
                            method='adam',
                            num_iters=100,
                            lr=0.01,
                            pbar_update_interval=10,
                            gridpoints=gridpoints,
                            gridpoints_idx=gridpoints_idx,
                            feature_vecs=feature_vecs)

    plt.figure()
    plt.plot(losses)
    plt.show()

    # most-likely-z
    print("Most likely z...")
    z = model.most_likely_states(data,
                                 gridpoints_idx=gridpoints_idx,
                                 feature_vecs=feature_vecs)

    # prediction on at most the last 1000 time steps
    print("0 step prediction")
    data_to_predict = data if data.shape[0] <= 1000 else data[-1000:]
    x_predict = k_step_prediction_for_lineargrid_model(
        model,
        z,
        data_to_predict,
        gridpoints_idx=gridpoints_idx,
        feature_vecs=feature_vecs)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()),
                            axis=0)

    print("2 step prediction")
    x_predict_2 = k_step_prediction(model, z, data_to_predict, k=2)
    x_predict_2_err = np.mean(np.abs(x_predict_2 -
                                     data_to_predict[2:].numpy()),
                              axis=0)

    # samples
    sample_T = 5
    sample_z, sample_x = model.sample(sample_T)
Beispiel #8
0
def test_model():
    """Smoke test for GPObservation plugged into an HMM.

    Uses a 5-step, 4-column trajectory on a single-cell grid
    ([0, 10] x [0, 8]); columns 0:2 and 2:4 are handled as two separate 2-D
    points. The test asserts two hard-coded squared-distance tables, asserts
    that ``log_prob`` gives the same value with and without precomputed
    memory kwargs, and then runs ``fit``, ``most_likely_states``, 0-/2-step
    prediction and ``sample``. Those later steps are only required to run;
    their outputs are not checked.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    T = 5  # not used below; equals the number of rows in `data`
    x_grids = np.array([0.0, 10.0])
    y_grids = np.array([0.0, 8.0])

    # (min, max) bounds for each of the 4 data columns.
    bounds = np.array([[0.0, 10.0], [0.0, 8.0], [0.0, 10.0], [0.0, 8.0]])
    data = np.array([[1.0, 1.0, 1.0, 6.0], [3.0, 6.0, 8.0, 6.0],
                     [4.0, 7.0, 8.0, 5.0], [6.0, 7.0, 5.0, 6.0],
                     [8.0, 2.0, 6.0, 1.0]])
    data = torch.tensor(data, dtype=torch.float64)

    K = 3
    D = 4
    M = 0

    obs = GPObservation(K=K,
                        D=D,
                        x_grids=x_grids,
                        y_grids=y_grids,
                        bounds=bounds,
                        train_rs=True)

    # Expected squared distances between inducing points. From the values,
    # the inducing points are the grid corners (0,0), (0,8), (10,0), (10,8),
    # and each corner appears twice in a row ("doubled" layout).
    correct_kerneldist_gg = torch.tensor(
        [[0., 0., 64., 64., 100., 100., 164., 164.],
         [0., 0., 64., 64., 100., 100., 164., 164.],
         [64., 64., 0., 0., 164., 164., 100., 100.],
         [64., 64., 0., 0., 164., 164., 100., 100.],
         [100., 100., 164., 164., 0., 0., 64., 64.],
         [100., 100., 164., 164., 0., 0., 64., 64.],
         [164., 164., 100., 100., 64., 64., 0., 0.],
         [164., 164., 100., 100., 64., 64., 0., 0.]],
        dtype=torch.float64)
    assert torch.all(torch.eq(correct_kerneldist_gg,
                              obs.kernel_distsq_gg)), obs.kernel_distsq_gg

    # Reference value, computed without any precomputed memory.
    log_prob_nocache = obs.log_prob(data)
    print("log_prob_nocache = {}".format(log_prob_nocache))

    # Precompute input-to-inducing-point squared distances for both 2-D
    # points of every transition source (all rows but the last).
    kernel_distsq_xg_a = kernel_distsq_doubled(data[:-1, 0:2],
                                               obs.inducing_points)
    kernel_distsq_xg_b = kernel_distsq_doubled(data[:-1, 2:4],
                                               obs.inducing_points)

    # Expected distances from the 4 source points of the first 2-D point to
    # the inducing corners. Both rows and columns are doubled, hence 8x8.
    # Only the "a" table is checked; kernel_distsq_xg_b is not asserted.
    correct_kernel_distsq_xg_a = torch.tensor(
        [[2., 2., 50., 50., 82., 82., 130., 130.],
         [2., 2., 50., 50., 82., 82., 130., 130.],
         [45., 45., 13., 13., 85., 85., 53., 53.],
         [45., 45., 13., 13., 85., 85., 53., 53.],
         [65., 65., 17., 17., 85., 85., 37., 37.],
         [65., 65., 17., 17., 85., 85., 37., 37.],
         [85., 85., 37., 37., 65., 65., 17., 17.],
         [85., 85., 37., 37., 65., 65., 17., 17.]],
        dtype=torch.float64)
    assert torch.all(torch.eq(correct_kernel_distsq_xg_a,
                              kernel_distsq_xg_a)), kernel_distsq_xg_a

    memory_kwargs = dict(kernel_distsq_xg_a=kernel_distsq_xg_a,
                         kernel_distsq_xg_b=kernel_distsq_xg_b)

    log_prob = obs.log_prob(data, **memory_kwargs)
    print("log_prob = {}".format(log_prob))

    # Using the distance cache must not change the result (exact equality).
    assert torch.all(torch.eq(log_prob_nocache, log_prob))

    # Second level of caching: the GP quantities returned by get_gp_cache
    # (for point a with index 0, point b with index 1).
    Sigma_a, A_a = obs.get_gp_cache(data[:-1, 0:2], 0, **memory_kwargs)
    Sigma_b, A_b = obs.get_gp_cache(data[:-1, 2:4], 1, **memory_kwargs)
    memory_kwargs_2 = dict(Sigma_a=Sigma_a, A_a=A_a, Sigma_b=Sigma_b, A_b=A_b)

    print("calculating log prob 2...")
    log_prob2 = obs.log_prob(data, **memory_kwargs_2)
    assert torch.all(torch.eq(log_prob, log_prob2))

    model = HMM(K=K, D=D, M=M, transition="stationary", observation=obs)
    # Start every state's initial mean at the first observation.
    model.observation.mus_init = data[0] * torch.ones(
        K, D, dtype=torch.float64)

    # fit
    losses, opt = model.fit(data,
                            optimizer=None,
                            method='adam',
                            num_iters=100,
                            lr=0.01,
                            pbar_update_interval=10,
                            **memory_kwargs)

    # NOTE(review): plt.show() may block under an interactive matplotlib
    # backend, which is unusual inside a test.
    plt.figure()
    plt.plot(losses)
    plt.show()

    # most-likely-z
    print("Most likely z...")
    z = model.most_likely_states(data, **memory_kwargs)

    # prediction
    print("0 step prediction")
    # Predict on at most the last 1000 steps; here data has 5 rows, so all of it is used.
    if data.shape[0] <= 1000:
        data_to_predict = data
    else:
        data_to_predict = data[-1000:]
    x_predict = k_step_prediction_for_gpmodel(model, z, data_to_predict,
                                              **memory_kwargs)
    x_predict_err = np.mean(np.abs(x_predict - data_to_predict.numpy()),
                            axis=0)

    print("2 step prediction")
    x_predict_2 = k_step_prediction(model, z, data_to_predict, k=2)
    # 2-step predictions are compared against the data from index 2 onward.
    x_predict_2_err = np.mean(np.abs(x_predict_2 -
                                     data_to_predict[2:].numpy()),
                              axis=0)

    # samples
    sample_T = 5
    sample_z, sample_x = model.sample(sample_T)
# ---------------------------------------------------------------------------
# Script fragment: HMM with an ARTruncatedNormalObservation, exercised with
# and without precomputed per-animal masks / memory.
# Relies on K, D, M, T, data, bounds, tran, momentum_lags, masks_a, masks_b,
# m_kwargs_a and m_kwargs_b being defined earlier in the script (not visible
# in this chunk).
# ---------------------------------------------------------------------------

# observation
obs = ARTruncatedNormalObservation(K=K, D=D, M=M, lags=momentum_lags, bounds=bounds, transformation=tran)

# model
model = HMM(K=K, D=D, M=M, observation=obs)

# The log-likelihood computed with precomputed masks/memory must equal the
# one computed from scratch.
log_prob = model.log_likelihood(data, masks=(masks_a, masks_b),
                                memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)
log_prob_2 = model.log_likelihood(data)
assert torch.eq(log_prob, log_prob_2)
print(log_prob)

# training
print("start training...")
num_iters = 10
losses, opt = model.fit(data, num_iters=num_iters, lr=0.001, masks=(masks_a, masks_b),
                        memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)

# sampling
print("start sampling")
sample_z, sample_x = model.sample(T)

# inference
print("inferring most likely states...")
z = model.most_likely_states(data, masks=(masks_a, masks_b),
                             memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)

print("0 step prediction")
x_predict = k_step_prediction_for_grid_model(model, z, data, memory_kwargs_a=m_kwargs_a, memory_kwargs_b=m_kwargs_b)

print("k step prediction")
# Pass k by keyword, as the other k_step_prediction calls in this file do.
x_predict_10 = k_step_prediction(model, z, data, k=10)
Beispiel #10
0
# ---------------------------------------------------------------------------
# Script fragment: HMM with a momentum-feature transformation.
# Relies on data, momentum_lags, feature_func_single and model being defined
# earlier in the script (not visible in this chunk).
# ---------------------------------------------------------------------------

# Precompute momentum vectors and features once from the transition sources
# (all rows but the last) and reuse them for every call below.
momentum_vecs = MomentumFeatureTransformation._compute_momentum_vecs(
    data[:-1], lags=momentum_lags)
features = MomentumFeatureTransformation._compute_features(
    feature_funcs=feature_func_single, inputs=data[:-1])

out = model.log_likelihood(data,
                           momentum_vecs=momentum_vecs,
                           features=features)
print(out)

##################### training ############################

num_iters = 10
losses, opt = model.fit(data,
                        num_iters=num_iters,
                        lr=0.001,
                        momentum_vecs=momentum_vecs,
                        features=features)

##################### sampling ############################
print("start sampling")
sample_z, sample_x = model.sample(30)

#################### inference ###########################
print("inferring most likely states...")
z = model.most_likely_states(data,
                             momentum_vecs=momentum_vecs,
                             features=features)

print("k step prediction")
x_predict = k_step_prediction_for_momentum_feature_model(
Beispiel #11
0
def main(job_name, cuda_num, data_type, downsample_n, filter_traj, lg_version, use_log_prior, no_boundary_prior,
         add_log_diagonal_prior, log_prior_sigma_sq,
         load_model, load_model_dir, load_opt_dir, animal, reset_prior_info,
         transition, sticky_alpha, sticky_kappa, acc_factor, k, x_grids, y_grids, n_x, n_y,
         train_model, pbar_update_interval, prop_start_end, video_clips, held_out_proportion, torch_seed, np_seed,
         list_of_num_iters, ckpts_not_to_save, list_of_lr, list_of_k_steps, sample_t, quiver_scale):
    """Run a linear-grid AR truncated-normal HMM experiment.

    Steps:
    1. Load a trajectory and split it into a training prefix and a held-out
       suffix of proportion ``held_out_proportion``.
    2. Build a new model, or load one with ``joblib``, optionally resetting its
       prior settings.
    3. Write the experiment parameters as JSON under a result directory
       derived from ``rslts/lineargrid/<animal>/<job_name>`` via
       ``addDateTime``.
    4. Either train in stages (one ``checkpoint_<i>`` directory per
       ``(num_iters, lr)`` pair) or only save results via ``rslt_saving``.

    String arguments parsed as comma-separated lists:
    - ``video_clips`` and ``prop_start_end``: two floats each.
    - ``list_of_num_iters``, ``list_of_lr``, ``list_of_k_steps`` and
      ``ckpts_not_to_save``; the last two may be empty.

    ``x_grids`` / ``y_grids`` are either None, meaning a uniform grid of
    ``n_x`` / ``n_y`` cells over the arena, or comma-separated grid lines.

    Raises:
        ValueError: if ``job_name`` is None, any learning rate exceeds 1, or
            ``data_type`` is unsupported.
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")
    assert animal in ['both', 'virgin', 'mother'], animal

    cuda_num = int(cuda_num)
    device = torch.device("cuda:{}".format(cuda_num) if torch.cuda.is_available() else "cpu")
    print("Using device {} \n\n".format(device))

    K = k
    sample_T = sample_t
    video_clip_start, video_clip_end = [float(x) for x in video_clips.split(",")]
    start, end = [float(x) for x in prop_start_end.split(",")]
    log_prior_sigma_sq = float(log_prior_sigma_sq)
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    # TODO: test if that works
    list_of_k_steps = [int(x) for x in list_of_k_steps.split(",")] if list_of_k_steps else []
    assert len(list_of_num_iters) == len(list_of_lr), "Length of list_of_num_iters must match length of list_of_lr."
    for lr in list_of_lr:
        if lr > 1:
            raise ValueError("Learning rate should not be larger than 1!")

    ckpts_not_to_save = [int(x) for x in ckpts_not_to_save.split(',')] if ckpts_not_to_save else []

    repo = git.Repo('.', search_parent_directories=True)
    repo_dir = repo.working_tree_dir  # root of the SocialBehavior repository

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    ########################## data ########################
    if data_type == 'full':
        data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
        traj = joblib.load(data_dir)
        traj = traj[int(36000 * video_clip_start):int(36000 * video_clip_end)]
    elif data_type == 'selected_010_virgin':
        assert animal == 'virgin', "animal much be 'virgin', but got {}.".format(animal)
        data_dir = repo_dir + '/SocialBehaviorptc/data/traj_010_virgin_selected'
        traj = joblib.load(data_dir)
        T = len(traj)
        traj = traj[int(T*start): int(T*end)]
    else:
        raise ValueError("unsupported data type: {}".format(data_type))

    # Keep only the requested animal's (x, y) columns.
    if animal == 'virgin':
        traj = traj[:, 0:2]
    elif animal == 'mother':
        traj = traj[:, 2:4]

    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)

    data = torch.tensor(traj, dtype=torch.float64, device=device)
    assert 0 <= held_out_proportion <= 0.4, \
        "held_out-portion should be between 0 and 0.4 (inclusive), but is {}".format(held_out_proportion)
    T = data.shape[0]
    # Renamed from `breakpoint`, which shadowed the builtin.
    split_idx = int(T*(1-held_out_proportion))
    training_data = data[:split_idx]
    valid_data = data[split_idx:]

    # NOTE(review): this overrides the `sample_t` argument with the training
    # length, so `sample_t` has no effect. Kept for backward compatibility;
    # confirm this is intended.
    sample_T = training_data.shape[0]

    ######################### model ####################

    # model
    D = data.shape[1]
    # The original `assert D == 2 or 4` was always true.
    assert D in (2, 4), D
    M = 0
    Df = 4

    if load_model:
        print("Loading the model from ", load_model_dir)
        if reset_prior_info:

            pretrained_model = joblib.load(load_model_dir)
            pretrained_transition = pretrained_model.transition
            pretrained_observation = pretrained_model.observation
            pretrained_tran = pretrained_model.observation.transformation

            # set prior info
            pretrained_tran.use_log_prior = use_log_prior
            pretrained_tran.no_boundary_prior = no_boundary_prior
            pretrained_tran.add_log_diagonal_prior = add_log_diagonal_prior
            pretrained_tran.log_prior_sigma_sq = torch.tensor(log_prior_sigma_sq, dtype=torch.float64, device=device)

            acc_factor = pretrained_tran.acc_factor

            K = pretrained_model.K

            obs = ARTruncatedNormalObservation(K=K, D=D, M=0, obs=pretrained_observation, device=device)
            tran = obs.transformation

            # NOTE(review): unlike the fresh-model branch below, 'grid'
            # transitions get no x_grids/y_grids kwargs here. Confirm this is
            # intended.
            if transition == 'sticky':
                transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
            else:
                transition_kwargs = None
            model = HMM(K=K, D=D, M=M, pi0=get_np(pretrained_model.pi0), Pi=get_np(pretrained_transition.Pi),
                        transition=transition, observation=obs, transition_kwargs=transition_kwargs,
                        device=device)
            model.observation.mus_init = training_data[0] * torch.ones(K, D, dtype=torch.float64, device=device)
        else:
            model = joblib.load(load_model_dir)
            tran = model.observation.transformation

            K = model.K

            n_x = len(tran.x_grids) - 1
            n_y = len(tran.y_grids) - 1
            # The original tested `model` (the HMM), which is never a
            # LinearGridTransformation, so these settings were never restored.
            if isinstance(tran, LinearGridTransformation):
                acc_factor = tran.acc_factor
                use_log_prior = tran.use_log_prior
                no_boundary_prior = tran.no_boundary_prior
                add_log_diagonal_prior = tran.add_log_diagonal_prior
                log_prior_sigma_sq = get_np(tran.log_prior_sigma_sq)

    else:
        if D == 4:
            bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX],
                               [ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])
        else:
            bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])

        # grids: a uniform split of the arena, or explicit comma-separated lines
        if x_grids is None:
            x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
            x_grids = np.array([ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)])
        else:
            x_grids = np.array([float(x) for x in x_grids.split(",")])
            n_x = len(x_grids) - 1

        if y_grids is None:
            y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
            y_grids = np.array([ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)])
        else:
            y_grids = np.array([float(x) for x in y_grids.split(",")])
            n_y = len(y_grids) - 1

        if acc_factor is None:
            acc_factor = downsample_n * 10

        print("Creating the model...")
        if animal == 'both':
            tran = LinearGridTransformation(K=K, D=D, x_grids=x_grids, y_grids=y_grids,
                                            Df=Df, feature_vec_func=f_corner_vec_func, acc_factor=acc_factor,
                                            use_log_prior=use_log_prior, no_boundary_prior=no_boundary_prior,
                                            add_log_diagonal_prior=add_log_diagonal_prior,
                                            log_prior_sigma_sq=log_prior_sigma_sq, device=device, version=lg_version)
        else:
            tran = SingleLinearGridTransformation(K=K, D=D, x_grids=x_grids, y_grids=y_grids, device=device)
        obs = ARTruncatedNormalObservation(K=K, D=D, M=M, lags=1, bounds=bounds, transformation=tran, device=device)

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        elif transition == 'grid':
            transition_kwargs = dict(x_grids=x_grids, y_grids=y_grids)
        else:
            transition_kwargs = None
        model = HMM(K=K, D=D, M=M, transition=transition, observation=obs, transition_kwargs=transition_kwargs,
                    device=device)
        model.observation.mus_init = training_data[0] * torch.ones(K, D, dtype=torch.float64, device=device)

    # save experiment params
    exp_params = {"job_name":   job_name,
                  'downsample_n': downsample_n,
                  'data_type': data_type,
                  "filter_traj": filter_traj,
                  "lg_version": lg_version,
                  "use_log_prior": use_log_prior,
                  "add_log_diagonal_prior": add_log_diagonal_prior,
                  "no_boundary_prior": no_boundary_prior,
                  "log_prior_sigma_sq": log_prior_sigma_sq,
                  "load_model": load_model,
                  "load_model_dir": load_model_dir,
                  "animal": animal,
                  "reset_prior_info": reset_prior_info,
                  "load_opt_dir": load_opt_dir,
                  "transition": transition,
                  "sticky_alpha": sticky_alpha,
                  "sticky_kappa": sticky_kappa,
                  "acc_factor": acc_factor,
                  "K": K,
                  "x_grids": x_grids,
                  "y_grids": y_grids,
                  "n_x": n_x,
                  "n_y": n_y,
                  "train_model": train_model,
                  "pbar_update_interval": pbar_update_interval,
                  "video_clip_start": video_clip_start,
                  "video_clip_end": video_clip_end,
                  "start_percentage": start,
                  "end_percentage": end,
                  "held_out_proportion": held_out_proportion,
                  "torch_seed": torch_seed,
                  "np_seed": np_seed,
                  "list_of_num_iters": list_of_num_iters,
                  "ckpts_not_to_save": ckpts_not_to_save,
                  "list_of_lr": list_of_lr,
                  "list_of_k_steps": list_of_k_steps,
                  "sample_T": sample_T,
                  "quiver_scale": quiver_scale}

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/lineargrid/{}/{}".format(animal, job_name))
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving to rlst_dir: ", rslt_dir)
    with open(rslt_dir+"/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    # compute memory: precompute quantities reused by every likelihood call
    if transition == "grid":
        print("Computing transition memory...")
        joint_grid_idx = model.transition.get_grid_idx(training_data[:-1])
        transition_memory_kwargs = dict(joint_grid_idx=joint_grid_idx)
        valid_joint_grid_idx = model.transition.get_grid_idx(valid_data[:-1])
        valid_data_transition_memory_kwargs = dict(joint_grid_idx=valid_joint_grid_idx)
    else:
        transition_memory_kwargs = None
        valid_data_transition_memory_kwargs = None

    print("Computing observation memory...")
    if animal == 'both':
        def get_memory_kwargs(inputs):
            """Return grid-point indices, grid-point coordinates and corner
            feature vectors for both animals, or {} for empty input."""
            if inputs is None or inputs.shape[0] == 0:
                return {}

            gridpoints_idx_a = tran.get_gridpoints_idx_for_batch(inputs[:, 0:2])  # (T-1, n_gps, 4)
            gridpoints_idx_b = tran.get_gridpoints_idx_for_batch(inputs[:, 2:4])  # (T-1, n_gps, 4)
            gridpoints_a = tran.get_gridpoints_for_batch(gridpoints_idx_a)  # (T-1, d, 2)
            gridpoints_b = tran.get_gridpoints_for_batch(gridpoints_idx_b)  # (T-1, d, 2)
            feature_vecs_a = f_corner_vec_func(inputs[:, 0:2])  # (T, Df, 2)
            feature_vecs_b = f_corner_vec_func(inputs[:, 2:4])  # (T, Df, 2)

            return dict(gridpoints_idx=(gridpoints_idx_a, gridpoints_idx_b),
                        gridpoints=(gridpoints_a, gridpoints_b),
                        feature_vecs=(feature_vecs_a, feature_vecs_b))
    else:
        def get_memory_kwargs(inputs):
            """Return grid-point indices and interpolation coefficients for a
            single animal, or {} for empty input."""
            if inputs is None or inputs.shape[0] == 0:
                return {}

            gridpoints_idx = tran.get_gridpoints_idx_for_batch(inputs)
            gridpoints = tran.gridpoints[gridpoints_idx]
            coeffs = tran.get_lp_coefficients(inputs, gridpoints[:, 0], gridpoints[:, 3], device=device)

            return dict(gridpoints_idx=gridpoints_idx, coeffs=coeffs)

    # Transition sources are every row but the last.
    memory_kwargs = get_memory_kwargs(training_data[:-1])
    valid_data_memory_kwargs = get_memory_kwargs(valid_data[:-1])

    log_prob = model.log_likelihood(training_data, transition_memory_kwargs=transition_memory_kwargs, **memory_kwargs)
    print("log_prob = {}".format(log_prob))

    ##################### training ############################
    if train_model:
        print("start training")
        # Truthiness check: treats both "" and None as "no optimizer to load".
        if load_opt_dir:
            opt = joblib.load(load_opt_dir)
        else:
            opt = None
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters, list_of_lr)):
            training_losses, opt, valid_losses, _ = \
                model.fit(training_data, optimizer=opt, method='adam', num_iters=num_iters, lr=lr,
                          pbar_update_interval=pbar_update_interval, valid_data=valid_data,
                          transition_memory_kwargs=transition_memory_kwargs,
                          valid_data_transition_memory_kwargs=valid_data_transition_memory_kwargs,
                          valid_data_memory_kwargs=valid_data_memory_kwargs, **memory_kwargs)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)

            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))
            # save model and opt
            joblib.dump(model, checkpoint_dir+"/model")
            joblib.dump(opt, checkpoint_dir+"/optimizer")

            # save losses
            losses = dict(training_loss=training_losses, valid_loss=valid_losses)
            joblib.dump(losses, checkpoint_dir+"/losses")

            plt.figure()
            plt.plot(training_losses)
            plt.title("training loss")
            plt.savefig(checkpoint_dir + "/training_losses.jpg")
            plt.close()

            plt.figure()
            plt.plot(valid_losses)
            plt.title("validation loss")
            plt.savefig(checkpoint_dir + "/valid_losses.jpg")
            plt.close()
            # save rest (skipped for checkpoints listed in ckpts_not_to_save)
            if i in ckpts_not_to_save:
                print("ckpt {}: skip!\n".format(i))
                continue
            with torch.no_grad():
                rslt_saving(rslt_dir=checkpoint_dir, model=model, data=training_data, animal=animal,
                            memory_kwargs=memory_kwargs,
                            list_of_k_steps=list_of_k_steps, sample_T=sample_T,
                            quiver_scale=quiver_scale, valid_data=valid_data,
                            valid_data_memory_kwargs=valid_data_memory_kwargs, device=device)

    else:
        # only save the results
        rslt_saving(rslt_dir=rslt_dir, model=model, data=training_data, animal=animal,
                    memory_kwargs=memory_kwargs,
                    list_of_k_steps=list_of_k_steps, sample_T=sample_T,
                    quiver_scale=quiver_scale, valid_data=valid_data,
                    valid_data_memory_kwargs=valid_data_memory_kwargs, device=device)

    print("Finish running!")
# ---------------------------------------------------------------------------
# Script fragment: HMM with a plain TruncatedNormalObservation on 4-D data.
# Relies on K, T, data and bounds being defined earlier in the script (not
# visible in this chunk).
# ---------------------------------------------------------------------------
D = 4
M = 0

obs = TruncatedNormalObservation(K=K, D=D, M=M, bounds=bounds)

# model
model = HMM(K=K, D=D, M=M, observation=obs)

# log like
log_prob = model.log_probability(data)
print("log probability = ", log_prob)

# training
print("start training...")
num_iters = 10
losses, opt = model.fit(data, num_iters=num_iters, lr=0.001)

# sampling (same length for the plain and the with_noise samples)
sample_T = T
print("start sampling")
sample_z, sample_x = model.sample(sample_T)

print("start sampling based on with_noise")
sample_z2, sample_x2 = model.sample(sample_T, with_noise=True)

# inference
print("inferring most likely states...")
z = model.most_likely_states(data)

print("0 step prediction")
x_predict = k_step_prediction(model, z, data)
Beispiel #13
0
def main(job_name, downsample_n, mouse, filter_traj, load_model,
         load_model_dir, load_opt_dir, transition, sticky_alpha, sticky_kappa,
         k, x_grids, y_grids, n_x, n_y, not_train_mu, initialize_mu,
         train_model, pbar_update_interval, video_clips, torch_seed, np_seed,
         list_of_num_iters, list_of_lr, sample_t, quiver_scale):
    """Run a dynamic-location HMM experiment.

    Steps:
    1. Load the full trajectory, take the clip given by ``video_clips`` (two
       ints, scaled by 36000 rows each), downsample and optionally filter it.
    2. Keep both animals (D=4) or a single one (D=2).
    3. Build a new HMM with a ``DynamicLocationObservation``, or load one with
       ``joblib``.
    4. Write the experiment parameters as JSON under a result directory
       derived from ``rslts/dynamic_loc/<job_name>`` via ``addDateTime``.
    5. Optionally train in stages, one ``checkpoint_<i>`` directory per
       ``(num_iters, lr)`` pair.

    When ``initialize_mu`` is set, the state locations start at the grid-cell
    centres; for ``mouse == 'both'`` they are all pairs of centres.
    ``list_of_num_iters`` and ``list_of_lr`` are comma-separated strings.

    Raises:
        ValueError: if ``job_name`` is None, any learning rate exceeds 1, or
            ``mouse`` is not one of 'both', 'virgin', 'mother'.
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")
    K = k
    sample_T = sample_t
    video_clip_start, video_clip_end = [int(x) for x in video_clips.split(",")]
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    assert len(list_of_num_iters) == len(
        list_of_lr
    ), "Length of list_of_num_iters must match length of list_of_lr."
    for lr in list_of_lr:
        if lr > 1:
            raise ValueError("Learning rate should not be larger than 1!")

    repo = git.Repo('.', search_parent_directories=True)
    repo_dir = repo.working_tree_dir  # root of the SocialBehavior repository

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    # specify grids for sanity checks and plotting
    if x_grids is None:
        x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
        x_grids = np.array(
            [ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)])
    else:
        x_grids = np.array([float(x) for x in x_grids.split(",")])
        n_x = len(x_grids) - 1

    if y_grids is None:
        y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
        y_grids = np.array(
            [ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)])
    else:
        y_grids = np.array([float(x) for x in y_grids.split(",")])
        n_y = len(y_grids) - 1

    ########################## data ########################
    data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
    trajs = joblib.load(data_dir)

    traj = trajs[36000 * video_clip_start:36000 * video_clip_end]
    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)

    if mouse == "both":
        data = torch.tensor(traj, dtype=torch.float64)
        D = 4
    elif mouse == "virgin":
        data = torch.tensor(traj[:, 0:2], dtype=torch.float64)
        D = 2
    elif mouse == "mother":
        data = torch.tensor(traj[:, 2:4], dtype=torch.float64)
        D = 2
    else:
        raise ValueError(
            "mouse must be chosen from 'both', 'virgin', 'mother'.")

    ######################### model ####################

    # model
    M = 0

    if load_model:
        print("Loading the model from ", load_model_dir)
        model = joblib.load(load_model_dir)

        K = model.K

        # NOTE(review): the fresh-model branch reads
        # model.observation.transformation.mus_loc instead. Confirm that
        # `observation.mus` exists on loaded models.
        cluster_locations = get_np(model.observation.mus)

    else:
        print("Creating the model...")
        if D == 4:
            bounds = np.array([[ARENA_XMIN, ARENA_XMAX],
                               [ARENA_YMIN, ARENA_YMAX],
                               [ARENA_XMIN, ARENA_XMAX],
                               [ARENA_YMIN, ARENA_YMAX]])
        else:
            bounds = np.array([[ARENA_XMIN, ARENA_XMAX],
                               [ARENA_YMIN, ARENA_YMAX]])

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        else:
            transition_kwargs = None

        obs = DynamicLocationObservation(K=K, D=D, M=M, bounds=bounds)
        model = HMM(K=K,
                    D=D,
                    M=M,
                    transition=transition,
                    observation=obs,
                    transition_kwargs=transition_kwargs)

        if initialize_mu:
            # One state per grid cell (or per pair of cells for two animals).
            if mouse == 'both':
                assert K == (n_x * n_y)**2, "K should be equal to (n_x*n_y)**2"
            else:
                assert K == n_x * n_y, "K should be equal to n_x * n_y"

            cluster_locations = np.array([((x_grids[i] + x_grids[i + 1]) / 2,
                                           (y_grids[j] + y_grids[j + 1]) / 2)
                                          for i in range(n_x)
                                          for j in range(n_y)])
            if mouse == "both":
                cluster_locations = np.array([
                    np.concatenate((loc_a, loc_b))
                    for loc_a in cluster_locations
                    for loc_b in cluster_locations
                ])

            assert cluster_locations.shape == (K, D)
            train_mu = not not_train_mu
            model.observation.transformation.mus_loc = \
                torch.tensor(cluster_locations, dtype=torch.float64, requires_grad=train_mu)
        else:
            cluster_locations = get_np(
                model.observation.transformation.mus_loc)

    # save experiment params (seeds etc. included so runs can be reproduced)
    exp_params = {
        "job_name": job_name,
        'downsample_n': downsample_n,
        "filter_traj": filter_traj,
        "mouse": mouse,
        "load_model": load_model,
        "load_model_dir": load_model_dir,
        "load_opt_dir": load_opt_dir,
        "transition": transition,
        "sticky_alpha": sticky_alpha,
        "sticky_kappa": sticky_kappa,
        "K": K,
        "n_x": n_x,
        "n_y": n_y,
        "x_grids": x_grids,
        "y_grids": y_grids,
        "cluster_locations": cluster_locations,
        "not_train_mu": not_train_mu,
        "initialize_mu": initialize_mu,
        "train_model": train_model,
        "pbar_update_interval": pbar_update_interval,
        "list_of_num_iters": list_of_num_iters,
        "list_of_lr": list_of_lr,
        "video_clip_start": video_clip_start,
        "video_clip_end": video_clip_end,
        "torch_seed": torch_seed,
        "np_seed": np_seed,
        "sample_T": sample_T,
        "quiver_scale": quiver_scale
    }

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/dynamic_loc/" + job_name)
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving to rlst_dir: ", rslt_dir)
    with open(rslt_dir + "/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    ##################### training ############################
    if train_model:
        # Save results for the untrained model first, for comparison.
        initial_dir = rslt_dir + "/initial"
        if not os.path.exists(initial_dir):
            os.makedirs(initial_dir)
            print("\nCreating initial directory...")
        rslt_saving(initial_dir, model, data, mouse, sample_T, False, [],
                    quiver_scale, x_grids, y_grids)

        print("start training")
        # Truthiness check: treats both "" and None as "no optimizer to load".
        if load_opt_dir:
            opt = joblib.load(load_opt_dir)
        else:
            opt = None
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters,
                                                list_of_lr)):
            losses, opt = model.fit(data,
                                    optimizer=opt,
                                    method='adam',
                                    num_iters=num_iters,
                                    lr=lr,
                                    pbar_update_interval=pbar_update_interval)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)

            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))
            # save model and opt
            joblib.dump(model, checkpoint_dir + "/model")
            joblib.dump(opt, checkpoint_dir + "/optimizer")
            # save rest
            rslt_saving(checkpoint_dir, model, data, mouse, sample_T,
                        train_model, losses, quiver_scale, x_grids, y_grids)

    else:
        # only save the results
        rslt_saving(rslt_dir, model, data, mouse, sample_T, False, [],
                    quiver_scale, x_grids, y_grids)

    print("Finish running!")
Beispiel #14
0
def main(job_name, cuda_num, data_type, downsample_n, filter_traj, load_model,
         load_model_dir, load_opt_dir, animal, transition, sticky_alpha,
         sticky_kappa, k, x_grids, y_grids, n_x, n_y, rs, train_rs, train_vs,
         train_model, pbar_update_interval, video_clips, prop_start_end,
         held_out_proportion, torch_seed, np_seed, list_of_num_iters,
         ckpts_not_to_save, list_of_lr, list_of_k_steps, sample_t,
         quiver_scale):
    """Build (or load) a GP-observation HMM, optionally train it, and save results.

    The trajectory is loaded from the repository's ``SocialBehaviorptc/data``
    directory, restricted to the requested animal's columns, downsampled,
    optionally speed-filtered, and split into a training prefix and a held-out
    validation suffix. Experiment parameters are written as
    ``exp_params.json`` into ``<repo>/rslts/gp/<animal>/<job_name>`` (the
    path is passed through ``addDateTime``).

    When ``train_model`` is truthy the model is fitted with Adam in
    consecutive stages, one per ``(num_iters, lr)`` pair. After stage ``i``,
    ``checkpoint_<i>/`` receives the pickled model and optimizer, the loss
    values and plots, a lengthscale plot, and (unless ``i`` is in
    ``ckpts_not_to_save``) the output of ``rslt_saving``. Otherwise only
    ``rslt_saving`` is run on the created/loaded model.

    Args:
        job_name: Name of the run; required.
        cuda_num: CUDA device index (anything ``int()`` accepts); the CPU is
            used when CUDA is unavailable.
        data_type: ``'full'`` or ``'selected_010_virgin'``.
        downsample_n: Factor passed to ``downsample``.
        filter_traj: If truthy, apply ``filter_traj_by_speed``.
        load_model: If truthy, load the model pickled at ``load_model_dir``
            instead of creating a new one.
        load_model_dir: Path of the pickled model to load.
        load_opt_dir: Path of a pickled optimizer to resume from; empty or
            ``None`` to start without one.
        animal: ``'both'``, ``'virgin'`` (columns 0:2) or ``'mother'``
            (columns 2:4).
        transition: Transition type passed to ``HMM``; ``'sticky'`` and
            ``'grid'`` receive extra keyword arguments.
        sticky_alpha, sticky_kappa: ``alpha``/``kappa`` of the sticky
            transition.
        k: Number of hidden states (replaced by ``model.K`` when loading).
        x_grids, y_grids: Comma-separated grid edges, or ``None`` to use
            ``n_x``/``n_y`` evenly spaced cells over the arena.
        n_x, n_y: Number of cells per axis when no edges are given.
        rs: Initial GP lengthscale (``float()``-convertible) or falsy for
            ``None``.
        train_rs: Forwarded to the observation; when falsy the GP caches
            (``Sigma``, ``A``) are precomputed once via ``get_gp_cache``.
        train_vs: Forwarded to ``GPObservation`` (``animal='both'`` only).
        train_model: Whether to run the training stages.
        pbar_update_interval: Progress-bar refresh interval for ``fit``.
        video_clips: ``"start,end"``; each value is multiplied by 36000 to
            slice frames (``data_type='full'`` only).
        prop_start_end: ``"start,end"`` fractions of the trajectory length
            (``data_type='selected_010_virgin'`` only).
        held_out_proportion: Fraction in ``[0, 0.4]`` held out at the end of
            the data for validation.
        torch_seed, np_seed: Random seeds.
        list_of_num_iters: Comma-separated iteration counts, one per stage.
        ckpts_not_to_save: Comma-separated stage indices whose
            ``rslt_saving`` call is skipped, or empty.
        list_of_lr: Comma-separated learning rates, one per stage, each <= 1.
        list_of_k_steps: Comma-separated k-step values, or empty.
        sample_t: Sample length; defaults to the training length if ``None``.
        quiver_scale: Forwarded to ``rslt_saving``.

    Raises:
        ValueError: If ``job_name`` is missing, a learning rate exceeds 1, or
            ``data_type`` is unsupported.
        AssertionError: For the remaining argument checks (done with
            ``assert``).
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")

    cuda_num = int(cuda_num)
    device = torch.device(
        "cuda:{}".format(cuda_num) if torch.cuda.is_available() else "cpu")
    print("Using device {} \n\n".format(device))

    assert animal in ['both', 'virgin', 'mother'], animal

    # ---- parse the (mostly comma-separated string) arguments ----
    K = k
    sample_T = sample_t
    rs = float(rs) if rs else None
    video_clip_start, video_clip_end = [
        float(x) for x in video_clips.split(",")
    ]
    start, end = [float(x) for x in prop_start_end.split(",")]
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    # TODO: fix for no k_steps, k > 0
    list_of_k_steps = ([int(x) for x in list_of_k_steps.split(",")]
                       if list_of_k_steps else [])
    assert len(list_of_num_iters) == len(list_of_lr), \
        "Length of list_of_num_iters must match length of list_of_lr."
    if any(lr > 1 for lr in list_of_lr):
        raise ValueError("Learning rate should not be larger than 1!")

    ckpts_not_to_save = ([int(x) for x in ckpts_not_to_save.split(',')]
                         if ckpts_not_to_save else [])

    # working tree root of the repository (SocialBehavior)
    repo = git.Repo('.', search_parent_directories=True)
    repo_dir = repo.working_tree_dir

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    ########################## data ########################
    if data_type == 'full':
        data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
        traj = joblib.load(data_dir)
        traj = traj[int(36000 * video_clip_start):int(36000 * video_clip_end)]
    elif data_type == 'selected_010_virgin':
        assert animal == 'virgin', \
            "animal must be 'virgin', but got {}.".format(animal)
        data_dir = repo_dir + '/SocialBehaviorptc/data/traj_010_virgin_selected'
        traj = joblib.load(data_dir)
        T = len(traj)
        traj = traj[int(T * start):int(T * end)]
    else:
        raise ValueError("unsupported data type: {}".format(data_type))

    # keep only the requested animal's columns ('both' keeps all of them)
    if animal == 'virgin':
        traj = traj[:, 0:2]
    elif animal == 'mother':
        traj = traj[:, 2:4]

    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)

    data = torch.tensor(traj, dtype=torch.float64, device=device)
    assert 0 <= held_out_proportion <= 0.4, \
        "held_out_proportion should be between 0 and 0.4 (inclusive), but is {}".format(held_out_proportion)
    # the last `held_out_proportion` of the time steps is used for validation
    T = data.shape[0]
    split_idx = int(T * (1 - held_out_proportion))
    training_data = data[:split_idx]
    valid_data = data[split_idx:]
    if sample_T is None:
        sample_T = training_data.shape[0]

    ######################### model ####################
    D = data.shape[1]
    assert D in (2, 4), D
    M = 0

    if load_model:
        print("Loading the model from ", load_model_dir)
        model = joblib.load(load_model_dir)
        obs = model.observation

        K = model.K
        assert D == model.D, "D = {}, model.D = {}".format(D, model.D)

        # Grid sizes come from the loaded observation. The x_grids / y_grids
        # arguments are left as given, so exp_params records those raw values.
        n_x = len(obs.x_grids) - 1
        n_y = len(obs.y_grids) - 1
        bounds = obs.bounds

    else:
        print("Creating the model...")

        if D == 4:
            bounds = np.array([[ARENA_XMIN, ARENA_XMAX],
                               [ARENA_YMIN, ARENA_YMAX],
                               [ARENA_XMIN, ARENA_XMAX],
                               [ARENA_YMIN, ARENA_YMAX]])
        else:
            bounds = np.array([[ARENA_XMIN, ARENA_XMAX],
                               [ARENA_YMIN, ARENA_YMAX]])

        # grids: explicit edges win over evenly spaced n_x / n_y cells
        if x_grids is None:
            x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
            x_grids = np.array(
                [ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)])
        else:
            x_grids = np.array([float(x) for x in x_grids.split(",")])
            n_x = len(x_grids) - 1

        if y_grids is None:
            y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
            y_grids = np.array(
                [ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)])
        else:
            y_grids = np.array([float(x) for x in y_grids.split(",")])
            n_y = len(y_grids) - 1

        # every state's initial mean is the first training observation
        mus_init = training_data[0] * torch.ones(
            K, D, dtype=torch.float64, device=device)
        if animal == 'both':
            obs = GPObservation(K=K,
                                D=D,
                                mus_init=mus_init,
                                x_grids=x_grids,
                                y_grids=y_grids,
                                bounds=bounds,
                                rs=rs,
                                train_rs=train_rs,
                                train_vs=train_vs,
                                device=device)
        else:
            obs = GPObservationSingle(K=K,
                                      D=D,
                                      mus_init=mus_init,
                                      x_grids=x_grids,
                                      y_grids=y_grids,
                                      bounds=bounds,
                                      rs=rs,
                                      train_rs=train_rs,
                                      device=device)

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        elif transition == 'grid':
            transition_kwargs = dict(x_grids=x_grids, y_grids=y_grids)
        else:
            transition_kwargs = None
        model = HMM(K=K,
                    D=D,
                    M=M,
                    transition=transition,
                    observation=obs,
                    transition_kwargs=transition_kwargs,
                    device=device)

    # save experiment params
    exp_params = {
        "job_name": job_name,
        "data_type": data_type,
        'downsample_n': downsample_n,
        "filter_traj": filter_traj,
        "load_model": load_model,
        "load_model_dir": load_model_dir,
        "load_opt_dir": load_opt_dir,
        "animal": animal,
        "transition": transition,
        "sticky_alpha": sticky_alpha,
        "sticky_kappa": sticky_kappa,
        "K": K,
        "x_grids": x_grids,
        "y_grids": y_grids,
        "n_x": n_x,
        "n_y": n_y,
        "rs": rs,
        "train_rs": train_rs,
        "train_vs": train_vs,
        "train_model": train_model,
        "pbar_update_interval": pbar_update_interval,
        "video_clip_start": video_clip_start,
        "video_clip_end": video_clip_end,
        "start_percentage": start,
        "end_percentage": end,
        "held_out_proportion": held_out_proportion,
        "torch_seed": torch_seed,
        "np_seed": np_seed,
        "list_of_num_iters": list_of_num_iters,
        "list_of_lr": list_of_lr,
        "list_of_k_steps": list_of_k_steps,
        "ckpts_not_to_save": ckpts_not_to_save,
        "sample_T": sample_T,
        "quiver_scale": quiver_scale
    }

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/gp/{}/{}".format(animal, job_name))
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving exp_params to rlst_dir: ", rslt_dir)
    with open(rslt_dir + "/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    # ---- precompute per-dataset quantities ("memory") reused by the model ----
    if transition == "grid":
        print("Computing transition memory...")
        joint_grid_idx = model.transition.get_grid_idx(training_data[:-1])
        transition_memory_kwargs = dict(joint_grid_idx=joint_grid_idx)
        valid_joint_grid_idx = model.transition.get_grid_idx(valid_data[:-1])
        valid_data_transition_memory_kwargs = dict(
            joint_grid_idx=valid_joint_grid_idx)
    else:
        transition_memory_kwargs = None
        valid_data_transition_memory_kwargs = None

    print("Computing observation memory...")

    def get_memory_kwargs(data, train_rs):
        """Return the observation memory kwargs for ``data`` ({} if empty).

        Always computes the kernel squared distances between ``data[:-1]``
        and the inducing points; when ``train_rs`` is falsy the GP caches
        (Sigma, A) are computed from them and returned instead.
        """
        if data is None or data.shape[0] == 0:
            return {}
        if animal == 'both':
            kernel_distsq_xg_a = kernel_distsq(data[:-1, 0:2],
                                               obs.inducing_points)
            kernel_distsq_xg_b = kernel_distsq(data[:-1, 2:4],
                                               obs.inducing_points)

            kernel_distsq_dict = dict(kernel_distsq_xg_a=kernel_distsq_xg_a,
                                      kernel_distsq_xg_b=kernel_distsq_xg_b)
        else:
            kernel_distsq_xg = kernel_distsq(data[:-1], obs.inducing_points)

            kernel_distsq_dict = dict(kernel_distsq_xg=kernel_distsq_xg)

        if train_rs:
            return kernel_distsq_dict

        if animal == 'both':
            Sigma_a, A_a = obs.get_gp_cache(data[:-1, 0:2], 0,
                                            **kernel_distsq_dict)
            Sigma_b, A_b = obs.get_gp_cache(data[:-1, 2:4], 1,
                                            **kernel_distsq_dict)
            return dict(Sigma_a=Sigma_a, A_a=A_a, Sigma_b=Sigma_b, A_b=A_b)
        Sigma, A = obs.get_gp_cache(data[:-1], **kernel_distsq_dict)
        return dict(Sigma=Sigma, A=A)

    memory_kwargs = get_memory_kwargs(training_data, train_rs)
    valid_data_memory_kwargs = get_memory_kwargs(valid_data, train_rs)

    log_prob = model.log_likelihood(
        training_data,
        transition_memory_kwargs=transition_memory_kwargs,
        **memory_kwargs)
    print("log_prob = {}".format(log_prob))

    # keyword arguments shared by every rslt_saving call below
    # (`model` is trained in place, so the reference stays current)
    saving_kwargs = dict(
        model=model,
        data=training_data,
        animal=animal,
        memory_kwargs=memory_kwargs,
        list_of_k_steps=list_of_k_steps,
        sample_T=sample_T,
        quiver_scale=quiver_scale,
        valid_data=valid_data,
        transition_memory_kwargs=transition_memory_kwargs,
        valid_data_transition_memory_kwargs=valid_data_transition_memory_kwargs,
        valid_data_memory_kwargs=valid_data_memory_kwargs,
        device=device)

    ##################### training ############################
    if train_model:
        print("start training")
        # an empty (or missing) path means "start without a saved optimizer"
        opt = joblib.load(load_opt_dir) if load_opt_dir else None
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters,
                                                list_of_lr)):
            training_losses, opt, valid_losses, rs_s = model.fit(
                training_data,
                optimizer=opt,
                method='adam',
                num_iters=num_iters,
                lr=lr,
                pbar_update_interval=pbar_update_interval,
                valid_data=valid_data,
                transition_memory_kwargs=transition_memory_kwargs,
                valid_data_transition_memory_kwargs=
                valid_data_transition_memory_kwargs,
                valid_data_memory_kwargs=valid_data_memory_kwargs,
                **memory_kwargs)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)

            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))
            # save model and opt
            joblib.dump(model, checkpoint_dir + "/model")
            joblib.dump(opt, checkpoint_dir + "/optimizer")

            # save losses (raw values and one plot per curve)
            losses = dict(training_loss=training_losses,
                          valid_loss=valid_losses)
            joblib.dump(losses, checkpoint_dir + "/losses")

            for curve, title, file_name in (
                    (training_losses, "training loss", "/training_losses.jpg"),
                    (valid_losses, "validation loss", "/valid_losses.jpg")):
                plt.figure()
                plt.plot(curve)
                plt.title(title)
                plt.savefig(checkpoint_dir + file_name)
                plt.close()

            # lengthscale traces: rs_s is indexed as [step, state, axis]
            rs_s = np.array(rs_s)
            plt.figure(figsize=(10, 5))
            for axis_idx, axis_name in enumerate(("x", "y")):
                plt.subplot(1, 2, axis_idx + 1)
                plt.title("lengthscale({})".format(axis_name))
                for state in range(K):
                    plt.plot(rs_s[:, state, axis_idx],
                             label='k={}'.format(state))
                plt.legend()
            plt.savefig(checkpoint_dir + "/observation_rs.jpg")
            plt.close()

            # save rest
            if i in ckpts_not_to_save:
                print("ckpt {}: skip!\n".format(i))
                continue
            with torch.no_grad():
                rslt_saving(rslt_dir=checkpoint_dir, **saving_kwargs)

    else:
        # only save the results; like the checkpoint path above, no
        # gradients are needed for this
        with torch.no_grad():
            rslt_saving(rslt_dir=rslt_dir, **saving_kwargs)

    print("Finish running!")
Beispiel #15
0
def main(job_name, lstm_model_type, downsample_n, filter_traj, load_model,
         load_model_dir, load_opt_dir, transition, sticky_alpha, sticky_kappa,
         acc_factor, lags, dh, dhs, k, x_grids, y_grids, n_x, n_y, train_model,
         pbar_update_interval, video_clips, torch_seed, np_seed,
         list_of_num_iters, list_of_lr, list_of_k_steps, sample_t,
         quiver_scale):
    """Build (or load) an LSTM-transformation AR HMM, optionally train it, and save results.

    The four-column trajectory ``trajs_all`` is sliced by ``video_clips``,
    downsampled and optionally speed-filtered. A new model is an ``HMM`` whose
    observation is an ``ARTruncatedNormalObservation`` driven by an
    ``LSTMTransformation``. Experiment parameters are written as
    ``exp_params.json`` into ``<repo>/rslts/lstm/<lstm_model_type>/<job_name>``
    (the path is passed through ``addDateTime``).

    When ``train_model`` is truthy the model is fitted with Adam in
    consecutive stages, one per ``(num_iters, lr)`` pair. After stage ``i``,
    ``checkpoint_<i>/`` receives the pickled model and optimizer plus the
    output of ``rslt_saving``. Otherwise only ``rslt_saving`` is run.

    Args:
        job_name: Name of the run; required.
        lstm_model_type: ``"uni"`` or ``"multi"``; overwritten from the
            loaded transformation when ``load_model`` is truthy.
        downsample_n: Factor passed to ``downsample``; ``acc_factor``
            defaults to ``downsample_n * 10``.
        filter_traj: If truthy, apply ``filter_traj_by_speed``.
        load_model, load_model_dir: Load the pickled model at
            ``load_model_dir`` instead of creating one.
        load_opt_dir: Path of a pickled optimizer to resume from; empty or
            ``None`` to start without one.
        transition: Transition type passed to ``HMM``.
        sticky_alpha, sticky_kappa: ``alpha``/``kappa`` when
            ``transition == 'sticky'``.
        acc_factor, lags, dh: Forwarded to the transformation / observation.
        dhs: Comma-separated hidden sizes, or ``'none'``/``None``.
        k: Number of hidden states (replaced by ``model.K`` when loading).
        x_grids, y_grids: Comma-separated grid edges, or ``None`` to use
            ``n_x``/``n_y`` evenly spaced cells over the arena.
        n_x, n_y: Number of cells per axis when no edges are given.
        train_model: Whether to run the training stages.
        pbar_update_interval: Progress-bar refresh interval for ``fit``.
        video_clips: ``"start,end"``; each value is multiplied by 36000 to
            slice frames.
        torch_seed, np_seed: Random seeds.
        list_of_num_iters: Comma-separated iteration counts, one per stage.
        list_of_lr: Comma-separated learning rates, one per stage, each <= 1.
        list_of_k_steps: Comma-separated k-step values, or empty.
        sample_t: Sample length forwarded to ``rslt_saving``.
        quiver_scale: Forwarded to ``rslt_saving``.

    Raises:
        ValueError: If ``job_name`` is missing or a learning rate exceeds 1.
        AssertionError: For the remaining argument checks (done with
            ``assert``).
    """
    if job_name is None:
        raise ValueError("Please provide the job name.")
    assert lstm_model_type in ["uni", "multi"], \
        "lstm_model_type must be chosen from 'uni' and 'multi'."

    # ---- parse the (mostly comma-separated string) arguments ----
    K = k
    sample_T = sample_t
    video_clip_start, video_clip_end = [
        float(x) for x in video_clips.split(",")
    ]
    if dhs is None or dhs == 'none':
        dhs = None
    else:
        dhs = [int(d_hidden) for d_hidden in dhs.split(",")]
    list_of_num_iters = [int(x) for x in list_of_num_iters.split(",")]
    list_of_lr = [float(x) for x in list_of_lr.split(",")]
    list_of_k_steps = ([int(x) for x in list_of_k_steps.split(",")]
                       if list_of_k_steps else [])
    assert len(list_of_num_iters) == len(list_of_lr), \
        "Length of list_of_num_iters must match length of list_of_lr," \
        " but we have list_of_num_iters = {}, and list_of_lr = {}".format(list_of_num_iters, list_of_lr)
    if any(lr > 1 for lr in list_of_lr):
        raise ValueError("Learning rate should not be larger than 1!")

    # working tree root of the repository (SocialBehavior)
    repo = git.Repo('.', search_parent_directories=True)
    repo_dir = repo.working_tree_dir

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    ########################## data ########################
    data_dir = repo_dir + '/SocialBehaviorptc/data/trajs_all'
    trajs = joblib.load(data_dir)

    start = int(36000 * video_clip_start)
    end = int(36000 * video_clip_end)
    traj = trajs[start:end]
    traj = downsample(traj, downsample_n)
    if filter_traj:
        traj = filter_traj_by_speed(traj, q1=0.99, q2=0.99)

    data = torch.tensor(traj, dtype=torch.float64)

    ######################### model ####################
    D = 4
    M = 0
    Df = 4

    if load_model:
        print("Loading the model from ", load_model_dir)
        model = joblib.load(load_model_dir)
        tran = model.observation.transformation
        assert isinstance(tran, (LSTMTransformation, UniLSTMTransformation)),\
            "tran should be {}, but is {}".format(LSTMTransformation, type(tran))

        K = model.K
        acc_factor = tran.acc_factor
        lstm_model_type = "uni" if isinstance(
            tran, UniLSTMTransformation) else "multi"
        # NOTE(review): lags / dh / dhs are not refreshed from the loaded
        # model, so packed_data below uses the command-line `lags` — confirm
        # they match the loaded model.

    else:
        print("Creating the model...")
        bounds = np.array([[ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX],
                           [ARENA_XMIN, ARENA_XMAX], [ARENA_YMIN, ARENA_YMAX]])

        if acc_factor is None:
            acc_factor = downsample_n * 10

        # NOTE(review): lstm_model_type is not consulted here — a new model
        # always uses LSTMTransformation, even for "uni" (whose results go to
        # rslts/lstm/uni/...). Confirm whether UniLSTMTransformation was meant.
        tran = LSTMTransformation(K=K,
                                  D=D,
                                  Df=Df,
                                  feature_vec_func=f_corner_vec_func,
                                  lags=lags,
                                  dh=dh,
                                  dhs=dhs,
                                  acc_factor=acc_factor)
        obs = ARTruncatedNormalObservation(K=K,
                                           D=D,
                                           M=M,
                                           lags=lags,
                                           bounds=bounds,
                                           transformation=tran)

        if transition == 'sticky':
            transition_kwargs = dict(alpha=sticky_alpha, kappa=sticky_kappa)
        else:
            transition_kwargs = None
        model = HMM(K=K,
                    D=D,
                    M=M,
                    transition=transition,
                    observation=obs,
                    transition_kwargs=transition_kwargs)
        # every state's initial mean is the first observation
        model.observation.mus_init = data[0] * torch.ones(
            K, D, dtype=torch.float64)

    # grids: explicit edges win over evenly spaced n_x / n_y cells
    if x_grids is None:
        x_grid_gap = (ARENA_XMAX - ARENA_XMIN) / n_x
        x_grids = np.array(
            [ARENA_XMIN + i * x_grid_gap for i in range(n_x + 1)])
    else:
        x_grids = np.array([float(x) for x in x_grids.split(",")])
        n_x = len(x_grids) - 1

    if y_grids is None:
        y_grid_gap = (ARENA_YMAX - ARENA_YMIN) / n_y
        y_grids = np.array(
            [ARENA_YMIN + i * y_grid_gap for i in range(n_y + 1)])
    else:
        y_grids = np.array([float(x) for x in y_grids.split(",")])
        n_y = len(y_grids) - 1

    # save experiment params
    # NOTE(review): the "lstm_model_tyoe" key spelling is kept unchanged so
    # existing exp_params.json files and their readers stay consistent.
    exp_params = {
        "job_name": job_name,
        "lstm_model_tyoe": lstm_model_type,
        'downsample_n': downsample_n,
        "filter_traj": filter_traj,
        "load_model": load_model,
        "load_model_dir": load_model_dir,
        "load_opt_dir": load_opt_dir,
        "transition": transition,
        "sticky_alpha": sticky_alpha,
        "sticky_kappa": sticky_kappa,
        "acc_factor": acc_factor,
        "K": K,
        "lags": lags,
        "dh": dh,
        "dhs": dhs,
        "n_x": n_x,
        "n_y": n_y,
        "x_grids": x_grids,
        "y_grids": y_grids,
        "train_model": train_model,
        "pbar_update_interval": pbar_update_interval,
        "list_of_num_iters": list_of_num_iters,
        "list_of_lr": list_of_lr,
        "video_clip_start": video_clip_start,
        "video_clip_end": video_clip_end,
        "list_of_k_steps": list_of_k_steps,
        "sample_T": sample_T,
        # also record the seeds and plotting scale so the run is reproducible
        "torch_seed": torch_seed,
        "np_seed": np_seed,
        "quiver_scale": quiver_scale
    }

    print("Experiment params:")
    print(exp_params)

    rslt_dir = addDateTime("rslts/lstm/{}/{}".format(lstm_model_type,
                                                     job_name))
    rslt_dir = os.path.join(repo_dir, rslt_dir)
    if not os.path.exists(rslt_dir):
        os.makedirs(rslt_dir)
        print("Making result directory...")
    print("Saving to rlst_dir: ", rslt_dir)
    with open(rslt_dir + "/exp_params.json", "w") as f:
        json.dump(exp_params, f, indent=4, cls=NumpyEncoder)

    # ---- precompute inputs ("memory") reused by fit / rslt_saving ----
    print("Computing memory...")

    # one feature set per animal: columns 0:2 and 2:4
    feature_vecs_a = f_corner_vec_func(data[:-1, 0:2])  # (T, Df, 2)
    feature_vecs_b = f_corner_vec_func(data[:-1, 2:4])  # (T, Df, 2)
    feature_vecs = (feature_vecs_a, feature_vecs_b)

    packed_data = get_packed_data(data[:-1], lags=lags)

    memory_kwargs = dict(packed_data=packed_data, feature_vecs=feature_vecs)

    ##################### training ############################
    if train_model:
        print("start training")
        # an empty (or missing) path means "start without a saved optimizer"
        opt = joblib.load(load_opt_dir) if load_opt_dir else None
        for i, (num_iters, lr) in enumerate(zip(list_of_num_iters,
                                                list_of_lr)):
            losses, opt = model.fit(data,
                                    optimizer=opt,
                                    method='adam',
                                    num_iters=num_iters,
                                    lr=lr,
                                    pbar_update_interval=pbar_update_interval,
                                    **memory_kwargs)

            checkpoint_dir = rslt_dir + "/checkpoint_{}".format(i)

            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
                print("Creating checkpoint_{} directory...".format(i))
            # save model and opt
            joblib.dump(model, checkpoint_dir + "/model")
            joblib.dump(opt, checkpoint_dir + "/optimizer")
            # save rest
            rslt_saving(checkpoint_dir,
                        model,
                        data,
                        memory_kwargs,
                        list_of_k_steps,
                        sample_T,
                        train_model,
                        losses,
                        quiver_scale,
                        x_grids=x_grids,
                        y_grids=y_grids)

    else:
        # only save the results
        rslt_saving(rslt_dir,
                    model,
                    data,
                    memory_kwargs,
                    list_of_k_steps,
                    sample_T,
                    False, [],
                    quiver_scale,
                    x_grids=x_grids,
                    y_grids=y_grids)

    print("Finish running!")