コード例 #1
0
ファイル: main.py プロジェクト: jik0730/tendon-driven-system
def main():
    """Simulate the system and, in MLP mode, fit the friction estimator online.

    Configuration is read from the module-level ``args`` namespace, and the
    sampling time ``'del_t'`` (plus the controller torque ``'T_d'``) is
    written into the module-level ``const`` dict.

    During each simulation step:
      1. compute theta_hat with the observed friction model f1_OBS  (1)
      2. compute theta_hat with the estimated friction f1_EST       (2)
      3. compute the MSE loss between (1) and (2)
      4. optimize the parameters of f1_EST (only when ``args.ftype == 2``)

    ``args.ftype`` selects the friction estimate:
      0 -> f_est = 0,
      1 -> oracle (f1_EST uses the same RealFriction model as f1_OBS),
      2 -> MLP (Friction_EST trained online with Adam).

    Outputs under ``<args.model_dir>/<args.data_type>``: ``params.json`` and
    ``const.json``; logs and a plot (via ``store_logs`` / ``plot_theta``) in
    ``<mode dir>/training``; and, in MLP mode, the trained state_dict in
    ``MLP/f1_model``.

    Raises:
        NotImplementedError: if ``args.ftype`` is not 0, 1 or 2.
        ValueError: if ``args.data_type`` matches no known target trajectory.
    """
    print('start exp: {} and {}'.format(args.model_dir, args.data_type))

    # NOTE 0 if f_est=0, 1 if f_est is oracle, 2 if f_est is MLP.
    # The same mapping names the output sub-directory of each mode.
    friction_type = args.ftype
    mode_dirs = {0: 'f_est=0', 1: 'oracle', 2: 'MLP'}
    if friction_type not in mode_dirs:
        # Fail before simulating or writing anything. Otherwise an unknown
        # type only surfaces inside the loop, or (if the loop never runs)
        # as an UnboundLocalError on the log directory.
        raise NotImplementedError(
            'unknown friction type: {}'.format(friction_type))

    # Simulation parameters and initial values.
    # 1.0 / freq keeps true division even if freq is an int on Python 2.
    const['del_t'] = float(1.0 / args.freq)  # sampling time for dynamics
    T = args.freq * args.simT  # number of operation for dynamics

    # Target trajectory
    data_type = args.data_type
    if 'sine' in data_type and 'Hz' in data_type:
        target_traj = sin_target_traj(
            args.freq, args.simT, sine_type=data_type)
    elif 'random_walk' in data_type:
        target_traj = random_walk(T, data_type)
    elif data_type in ('sine_freq_variation',
                       'sine_freq_variation_with_step'):
        freq_from = 0.5
        freq_to = 10.
        sine_type = 'sine_1Hz_10deg_0offset'
        if data_type == 'sine_freq_variation_with_step':
            traj_fn = sin_freq_variation_with_step
        else:
            traj_fn = sin_freq_variation
        target_traj = traj_fn(freq_from, freq_to, args.freq, args.simT,
                              sine_type)
    elif 'step' in data_type:
        target_traj = step_target_traj(T, data_type)
    else:
        raise ValueError('I dont know your targets')

    # Seed the 4-sample window of observed theta with the first targets.
    t_OBS_vals = [torch.FloatTensor([target_traj[k]]) for k in range(4)]
    f1_EST_vals = [torch.zeros(1)]
    F_EST = torch.zeros(1)

    # for plotting
    time_stamp = []
    target_history = []
    obs_history = []
    est_history = []
    f1_obs_history = []
    f1_est_history = []
    F_est_history = []

    # Define PID controller. Kd and Ki are given relative to Kp. The torque
    # is only recomputed every 10th step (see the loop), hence del_t * 10.
    Kp = args.Kp
    Kd = args.Kd * Kp
    Ki = args.Ki * Kp
    pid_cont = PIDController(p=Kp, i=Ki, d=Kd, del_t=const['del_t'] * 10)

    # Define models TODO for now we ignore f2.
    f1_OBS_fn = RealFriction(const)
    f1_EST_fn = Friction_EST(args.hdim)

    # Define loss_fn, optimizer
    optimizer = torch.optim.Adam(f1_EST_fn.parameters(), lr=args.lr)
    loss_fn = nn.MSELoss()

    for t in range(4, T):
        # current target
        target = target_traj[t]

        # Feed back the previous estimate detached, so gradients never flow
        # into earlier simulation steps.
        f1 = f1_EST_vals[-1].detach()

        # compute input force (F) at t-2; F_EST is held between refreshes.
        if t % 10 == 3:
            const['T_d'] = pid_cont.compute_torque(target, t_OBS_vals[-1])
            F_EST = compute_input_force(const, t_OBS_vals[-1], f1)

        # compute frictions (f) at t-2 (velocity by finite difference)
        t_dot_OBS = (t_OBS_vals[-2] - t_OBS_vals[-3]) / const['del_t']
        f1_OBS = f1_OBS_fn(t_OBS_vals[-2], t_OBS_vals[-3], F_EST)

        if friction_type == 0:
            f1_EST = torch.zeros(1)
        elif friction_type == 1:
            f1_EST = f1_OBS_fn(t_OBS_vals[-2], t_OBS_vals[-3], F_EST)
        else:  # friction_type == 2 (validated before the loop)
            f1_EST = f1_EST_fn(torch.cat([t_OBS_vals[-2], t_dot_OBS, F_EST]))

        # compute theta_hat (t) at t
        t_OBS = compute_theta_hat(const, t_OBS_vals[-1], t_OBS_vals[-2],
                                  t_OBS_vals[-3], F_EST, f1_OBS)
        t_EST = compute_theta_hat(const, t_OBS_vals[-1], t_OBS_vals[-2],
                                  t_OBS_vals[-3], F_EST, f1_EST)

        # Optimization: one Adam step per simulation step in MLP mode.
        if friction_type == 2:
            loss = loss_fn(t_EST, t_OBS)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Slide the 4-sample window of observed theta by one step.
        t_OBS_vals = t_OBS_vals[1:] + [t_OBS]
        f1_EST_vals[-1] = f1_EST

        # store history for plotting
        time_stamp.append(float(t / float(args.freq)))
        target_history.append(float(target.numpy()))
        obs_history.append(float(t_OBS.numpy()))
        est_history.append(float(t_EST.detach().numpy()))
        f1_obs_history.append(float(f1_OBS.detach().numpy()))
        f1_est_history.append(float(f1_EST.detach().numpy()))
        F_est_history.append(float(F_EST.detach().numpy()))

        # Stop once the simulated state diverges to NaN.
        if np.isnan(t_OBS.numpy()):
            break

    # store hyper-parameters and settings
    params_dir = os.path.join(args.model_dir, args.data_type)
    if not os.path.exists(params_dir):
        os.makedirs(params_dir)
    params = OrderedDict(vars(args))
    const_ord = OrderedDict(cast_dict_to_float(const))
    with open(os.path.join(params_dir, 'params.json'), 'w') as f:
        json.dump(params, f)
    with open(os.path.join(params_dir, 'const.json'), 'w') as f:
        json.dump(const_ord, f)

    # store values for post-visualization
    mode_dir = os.path.join(params_dir, mode_dirs[friction_type])
    training_log_dir = os.path.join(mode_dir, 'training')
    if not os.path.exists(training_log_dir):
        os.makedirs(training_log_dir)
    store_logs(time_stamp, target_history, obs_history, est_history,
               f1_obs_history, f1_est_history, F_est_history, training_log_dir)

    # save trained model (mode_dir is <params_dir>/MLP in this mode)
    if friction_type == 2:
        model_name = os.path.join(mode_dir, 'f1_model')
        torch.save(f1_EST_fn.state_dict(), model_name)

    # visualize
    plot_theta(time_stamp, target_history, obs_history, est_history,
               training_log_dir)
コード例 #2
0
def evaluate(const, params, ftype):
    """Run a closed-loop simulation with a fixed (non-trained) friction estimate.

    Args:
        const: mapping of simulation constants; ``'del_t'`` is read here,
            and the whole mapping is passed to the project helpers. A
            shallow copy is used internally, so the controller's writes
            to ``'T_d'`` do not leak into the caller's mapping.
        params: hyper-parameter mapping; ``'freq'``, ``'simT'``, ``'Kp'``,
            ``'Kd'``, ``'Ki'`` and ``'hdim'`` are read. Presumably this is the
            ``params.json`` written at training time — confirm with caller.
        ftype: friction estimate to use — 0: f_est = 0, 1: oracle (same
            RealFriction model as the observation), 2: MLP whose state_dict
            is loaded from ``<args.model_dir>/<args.data_type>/MLP/f1_model``.

    The target trajectory is chosen by the module-level ``args.eval_type``.
    Logs and a plot (via ``store_logs`` / ``plot_theta``) are written to
    ``<args.model_dir>/<args.data_type>/<mode dir>/evaluation/<eval_type>``.

    Raises:
        NotImplementedError: if ``ftype`` is not 0, 1 or 2.
        ValueError: if ``args.eval_type`` matches no known target trajectory.
    """
    # Log the ftype actually used by this call (not args.ftype).
    print('Start evaluation exp: {} and {} and {}'.format(
        args.model_dir, args.data_type, ftype))

    mode_dirs = {0: 'f_est=0', 1: 'oracle', 2: 'MLP'}
    if ftype not in mode_dirs:
        # Fail before simulating; otherwise the log directory is unbound
        # when the loop does not run.
        raise NotImplementedError('unknown friction type: {}'.format(ftype))

    # Work on a copy: const['T_d'] is overwritten below and must not leak
    # into the caller's mapping (e.g. across successive evaluate calls).
    const = dict(const)

    # Total running steps
    T = params['freq'] * params['simT']

    # Target trajectory for evaluation
    eval_type = args.eval_type
    if 'sine' in eval_type and 'Hz' in eval_type:
        target_traj = sin_target_traj(params['freq'],
                                      params['simT'],
                                      sine_type=eval_type)
    elif 'random_walk' in eval_type:
        target_traj = random_walk(T, eval_type)
    elif eval_type in ('sine_freq_variation',
                       'sine_freq_variation_with_step'):
        freq_from = 0.5
        freq_to = 10.
        sine_type = 'sine_1Hz_10deg_0offset'
        if eval_type == 'sine_freq_variation_with_step':
            traj_fn = sin_freq_variation_with_step
        else:
            traj_fn = sin_freq_variation
        target_traj = traj_fn(freq_from, freq_to, params['freq'],
                              params['simT'], sine_type)
    elif 'step' in eval_type:
        target_traj = step_target_traj(T, eval_type)
    else:
        raise ValueError('I dont know your targets')

    # initiate values: the 4-sample window of observed theta.
    if 'step' in eval_type:
        # NOTE(review): this also matches 'sine_freq_variation_with_step';
        # confirm a zero initial state is intended for that trajectory too.
        t_OBS_vals = [torch.FloatTensor([0]) for _ in range(4)]
    else:
        t_OBS_vals = [torch.FloatTensor([target_traj[k]]) for k in range(4)]
    f1_EST_vals = [torch.zeros(1)]
    F_EST = torch.zeros(1)

    # for plotting
    time_stamp = []
    target_history = []
    obs_history = []
    est_history = []
    f1_obs_history = []
    f1_est_history = []
    F_est_history = []

    # Define PID controller. Kd and Ki are given relative to Kp. The torque
    # is only recomputed every 10th step (see the loop), hence del_t * 10.
    Kp = params['Kp']
    Kd = params['Kd'] * Kp
    Ki = params['Ki'] * Kp
    pid_cont = PIDController(p=Kp, i=Ki, d=Kd, del_t=const['del_t'] * 10)

    # Define models TODO for now we ignore f2.
    f1_OBS_fn = RealFriction(const)
    f1_EST_fn = Friction_EST(params['hdim'])
    if ftype == 2:
        state_dict_path = os.path.join(args.model_dir, args.data_type, 'MLP',
                                       'f1_model')
        state_dict = torch.load(state_dict_path)
        f1_EST_fn.load_state_dict(state_dict)

    for t in range(4, T):
        # current target
        target = target_traj[t]

        # Feed back the previous estimate detached from its graph.
        f1 = f1_EST_vals[-1].detach()

        # compute input force (F) at t-2; F_EST is held between refreshes.
        if t % 10 == 3:
            const['T_d'] = pid_cont.compute_torque(target, t_OBS_vals[-1])
            F_EST = compute_input_force(const, t_OBS_vals[-1], f1)

        # compute frictions (f) at t-2 (velocity by finite difference)
        t_dot_OBS = (t_OBS_vals[-2] - t_OBS_vals[-3]) / const['del_t']
        f1_OBS = f1_OBS_fn(t_OBS_vals[-2], t_OBS_vals[-3], F_EST)

        if ftype == 0:
            f1_EST = torch.zeros(1)
        elif ftype == 1:
            f1_EST = f1_OBS_fn(t_OBS_vals[-2], t_OBS_vals[-3], F_EST)
        else:  # ftype == 2 (validated above)
            f1_EST = f1_EST_fn(torch.cat([t_OBS_vals[-2], t_dot_OBS, F_EST]))

        # compute theta_hat (t) at t
        t_OBS = compute_theta_hat(const, t_OBS_vals[-1], t_OBS_vals[-2],
                                  t_OBS_vals[-3], F_EST, f1_OBS)
        t_EST = compute_theta_hat(const, t_OBS_vals[-1], t_OBS_vals[-2],
                                  t_OBS_vals[-3], F_EST, f1_EST)

        # Slide the 4-sample window of observed theta by one step.
        t_OBS_vals = t_OBS_vals[1:] + [t_OBS]
        f1_EST_vals[-1] = f1_EST

        # store history for plotting
        time_stamp.append(float(t / float(params['freq'])))
        target_history.append(float(target.numpy()))
        obs_history.append(float(t_OBS.numpy()))
        est_history.append(float(t_EST.detach().numpy()))
        f1_obs_history.append(float(f1_OBS.detach().numpy()))
        f1_est_history.append(float(f1_EST.detach().numpy()))
        F_est_history.append(float(F_EST.detach().numpy()))

    # store values for post-visualization
    params_dir = os.path.join(args.model_dir, args.data_type)
    eval_log_dir = os.path.join(params_dir, mode_dirs[ftype], 'evaluation',
                                eval_type)
    if not os.path.exists(eval_log_dir):
        os.makedirs(eval_log_dir)
    store_logs(time_stamp, target_history, obs_history, est_history,
               f1_obs_history, f1_est_history, F_est_history, eval_log_dir)

    # visualize
    plot_theta(time_stamp, target_history, obs_history, est_history,
               eval_log_dir)