def main():
    """Run the training experiment and store its logs, settings and plot.

    The function reads the module-level ``args`` namespace and ``const``
    dictionary (neither is passed in).  For every step ``t`` in
    ``range(4, T)`` it:

    1. computes theta_hat with the reference friction ``RealFriction`` (1),
    2. computes theta_hat with the estimated friction selected by
       ``args.ftype`` (2),
    3. computes the MSE loss between (1) and (2),
    4. optimizes the parameters of ``Friction_EST`` (only when
       ``args.ftype == 2``).

    Side effects:
        * writes ``const['del_t']`` and, every tenth step, ``const['T_d']``;
        * creates ``<args.model_dir>/<args.data_type>/`` and writes
          ``params.json`` and ``const.json`` there;
        * stores the step histories with ``store_logs`` and a plot with
          ``plot_theta`` in a ``training`` sub-directory;
        * saves the MLP ``state_dict`` when ``args.ftype == 2``.

    Raises:
        Exception: if ``args.data_type`` matches no known target trajectory.
        NotImplementedError: if ``args.ftype`` is not 0, 1 or 2.
    """
    print('start exp: {} and {}'.format(args.model_dir, args.data_type))

    # Simulation parameters and initial values
    const['del_t'] = float(1 / args.freq)  # sampling time for dynamics
    T = args.freq * args.simT  # number of operation for dynamics

    # Target trajectory
    data_type = args.data_type
    if 'sine' in data_type and 'Hz' in data_type:
        target_traj = sin_target_traj(
            args.freq, args.simT, sine_type=data_type)
    elif 'random_walk' in data_type:
        target_traj = random_walk(T, data_type)
    elif data_type == 'sine_freq_variation':
        target_traj = sin_freq_variation(
            0.5, 10., args.freq, args.simT, 'sine_1Hz_10deg_0offset')
    elif data_type == 'sine_freq_variation_with_step':
        # Must be tested before the generic 'step' substring check below.
        target_traj = sin_freq_variation_with_step(
            0.5, 10., args.freq, args.simT, 'sine_1Hz_10deg_0offset')
    elif 'step' in data_type:
        target_traj = step_target_traj(T, data_type)
    else:
        raise Exception('I dont know your targets')

    # NOTE 0 if f_est=0, 1 if f_est is oracle, 2 if f_est is MLP
    friction_type = args.ftype
    log_subdirs = {0: 'f_est=0', 1: 'oracle', 2: 'MLP'}
    # Validate once, up front: otherwise an unsupported type is only caught
    # inside the loop, and is missed entirely when the loop does not run.
    if friction_type not in log_subdirs:
        raise NotImplementedError(
            'unsupported friction type: {}'.format(friction_type))

    # Initial values: a sliding window holding the four latest observations
    # (oldest first), seeded from the first four target samples.
    t_OBS_vals = [torch.FloatTensor([target_traj[i]]) for i in range(4)]
    f1_EST_vals = [torch.zeros(1)]

    # for plotting
    time_stamp = []
    target_history = []
    obs_history = []
    est_history = []
    f1_obs_history = []
    f1_est_history = []
    F_est_history = []

    # Define PID controller.  Kd and Ki are given relative to Kp.  The
    # controller is built with ten times the dynamics sampling time, matching
    # the "every tenth step" torque update inside the loop.
    Kp = args.Kp
    Kd = args.Kd * Kp
    Ki = args.Ki * Kp
    pid_cont = PIDController(p=Kp, i=Ki, d=Kd, del_t=const['del_t'] * 10)

    # Define models TODO for now we ignore f2.
    f1_OBS_fn = RealFriction(const)
    f1_EST_fn = Friction_EST(args.hdim)

    # Define loss_fn, optimizer
    optimizer = torch.optim.Adam(f1_EST_fn.parameters(), lr=args.lr)
    loss_fn = nn.MSELoss()

    for t in range(4, T):
        # current target
        target = target_traj[t]

        # detach nodes for simplicity
        f1 = f1_EST_vals[-1].detach()

        # compute input force (F) at t-2; the desired torque is refreshed
        # only on every tenth step and held in between
        if t % 10 == 3:
            const['T_d'] = pid_cont.compute_torque(target, t_OBS_vals[-1])
        F_EST = compute_input_force(const, t_OBS_vals[-1], f1)

        # compute frictions (f) at t-2
        t_dot_OBS = (t_OBS_vals[-2] - t_OBS_vals[-3]) / const['del_t']
        f1_OBS = f1_OBS_fn(t_OBS_vals[-2], t_OBS_vals[-3], F_EST)
        if friction_type == 0:
            f1_EST = torch.zeros(1)
        elif friction_type == 1:
            f1_EST = f1_OBS_fn(t_OBS_vals[-2], t_OBS_vals[-3], F_EST)
        else:  # friction_type == 2, guaranteed by the check above
            f1_EST = f1_EST_fn(
                torch.cat([t_OBS_vals[-2], t_dot_OBS, F_EST]))

        # compute theta_hat (t) at t
        t_OBS = compute_theta_hat(const, t_OBS_vals[-1], t_OBS_vals[-2],
                                  t_OBS_vals[-3], F_EST, f1_OBS)
        t_EST = compute_theta_hat(const, t_OBS_vals[-1], t_OBS_vals[-2],
                                  t_OBS_vals[-3], F_EST, f1_EST)

        # Optimization
        if friction_type == 2:
            loss = loss_fn(t_EST, t_OBS)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # store values to containers: drop the oldest observation and append
        # the newest one
        t_OBS_vals = t_OBS_vals[1:] + [t_OBS]
        f1_EST_vals[-1] = f1_EST

        # store history for plotting
        time_stamp.append(float(t / args.freq))
        target_history.append(float(target.numpy()))
        obs_history.append(float(t_OBS.numpy()))
        est_history.append(float(t_EST.detach().numpy()))
        f1_obs_history.append(float(f1_OBS.detach().numpy()))
        f1_est_history.append(float(f1_EST.detach().numpy()))
        F_est_history.append(float(F_EST.detach().numpy()))

        # for debugging: stop as soon as the simulation diverges to NaN
        if np.isnan(t_OBS.numpy()):
            break

    # store hyper-parameters and settings
    params_dir = os.path.join(args.model_dir, args.data_type)
    os.makedirs(params_dir, exist_ok=True)
    params = OrderedDict(vars(args))
    const_ord = OrderedDict(cast_dict_to_float(const))
    with open(os.path.join(params_dir, 'params.json'), 'w') as f:
        json.dump(params, f)
    with open(os.path.join(params_dir, 'const.json'), 'w') as f:
        json.dump(const_ord, f)

    # store values for post-visualization
    training_log_dir = os.path.join(
        params_dir, log_subdirs[friction_type], 'training')
    os.makedirs(training_log_dir, exist_ok=True)
    store_logs(time_stamp, target_history, obs_history, est_history,
               f1_obs_history, f1_est_history, F_est_history,
               training_log_dir)

    # save trained model
    if friction_type == 2:
        model_name = os.path.join(params_dir, 'MLP', 'f1_model')
        torch.save(f1_EST_fn.state_dict(), model_name)

    # visualize
    plot_theta(time_stamp, target_history, obs_history, est_history,
               training_log_dir)
def evaluate(const, params, ftype):
    """Roll the simulation out on an evaluation trajectory and store the logs.

    No parameters are optimized here.  The evaluation trajectory is selected
    by the module-level ``args.eval_type``; output paths are built from
    ``args.model_dir`` and ``args.data_type``.

    Args:
        const: dictionary of simulation constants.  ``const['del_t']`` is
            read and ``const['T_d']`` is overwritten in place on every tenth
            step, so the caller's dictionary is modified.
        params: mapping providing ``'freq'``, ``'simT'``, ``'Kp'``, ``'Kd'``,
            ``'Ki'`` and ``'hdim'``.
        ftype: friction estimate to use -- 0 for zero friction, 1 for the
            oracle (``RealFriction``), 2 for the MLP whose ``state_dict`` is
            loaded from ``<model_dir>/<data_type>/MLP/f1_model``.

    Side effects:
        Creates ``<model_dir>/<data_type>/<kind>/evaluation/<eval_type>/`` and
        fills it through ``store_logs`` and ``plot_theta``.

    Raises:
        Exception: if ``args.eval_type`` matches no known target trajectory.
        NotImplementedError: if ``ftype`` is not 0, 1 or 2.
    """
    # Report the friction type actually evaluated (the ``ftype`` argument),
    # not the global command-line value, which may differ.
    print('Start evaluation exp: {} and {} and {}'.format(
        args.model_dir, args.data_type, ftype))

    # Total running steps
    T = params['freq'] * params['simT']

    # Target trajectory for evaluation
    eval_type = args.eval_type
    if 'sine' in eval_type and 'Hz' in eval_type:
        target_traj = sin_target_traj(params['freq'], params['simT'],
                                      sine_type=eval_type)
    elif 'random_walk' in eval_type:
        target_traj = random_walk(T, eval_type)
    elif eval_type == 'sine_freq_variation':
        target_traj = sin_freq_variation(
            0.5, 10., params['freq'], params['simT'],
            'sine_1Hz_10deg_0offset')
    elif eval_type == 'sine_freq_variation_with_step':
        # Must be tested before the generic 'step' substring check below.
        target_traj = sin_freq_variation_with_step(
            0.5, 10., params['freq'], params['simT'],
            'sine_1Hz_10deg_0offset')
    elif 'step' in eval_type:
        target_traj = step_target_traj(T, eval_type)
    else:
        raise Exception('I dont know your targets')

    log_subdirs = {0: 'f_est=0', 1: 'oracle', 2: 'MLP'}
    # Validate once, up front: otherwise an unsupported type is only caught
    # inside the loop, and is missed entirely when the loop does not run.
    if ftype not in log_subdirs:
        raise NotImplementedError(
            'unsupported friction type: {}'.format(ftype))

    # Initial values: a sliding window holding the four latest observations
    # (oldest first).  Any eval type containing 'step' starts from zero;
    # every other type is seeded from the first four target samples.
    if 'step' in eval_type:
        t_OBS_vals = [torch.FloatTensor([0]) for _ in range(4)]
    else:
        t_OBS_vals = [torch.FloatTensor([target_traj[i]]) for i in range(4)]
    f1_EST_vals = [torch.zeros(1)]

    # for plotting
    time_stamp = []
    target_history = []
    obs_history = []
    est_history = []
    f1_obs_history = []
    f1_est_history = []
    F_est_history = []

    # Define PID controller.  Kd and Ki are given relative to Kp.  The
    # controller is built with ten times the dynamics sampling time, matching
    # the "every tenth step" torque update inside the loop.
    Kp = params['Kp']
    Kd = params['Kd'] * Kp
    Ki = params['Ki'] * Kp
    pid_cont = PIDController(p=Kp, i=Ki, d=Kd, del_t=const['del_t'] * 10)

    # Define models TODO for now we ignore f2.
    f1_OBS_fn = RealFriction(const)
    f1_EST_fn = Friction_EST(params['hdim'])
    if ftype == 2:
        state_dict_path = os.path.join(args.model_dir, args.data_type,
                                       'MLP', 'f1_model')
        f1_EST_fn.load_state_dict(torch.load(state_dict_path))
    # Inference only: put the module in evaluation mode so that any
    # train-time-only layers behave deterministically.
    f1_EST_fn.eval()

    for t in range(4, T):
        # current target
        target = target_traj[t]

        # detach nodes for simplicity
        f1 = f1_EST_vals[-1].detach()

        # compute input force (F) at t-2; the desired torque is refreshed
        # only on every tenth step and held in between
        if t % 10 == 3:
            const['T_d'] = pid_cont.compute_torque(target, t_OBS_vals[-1])
        F_EST = compute_input_force(const, t_OBS_vals[-1], f1)

        # compute frictions (f) at t-2
        t_dot_OBS = (t_OBS_vals[-2] - t_OBS_vals[-3]) / const['del_t']
        f1_OBS = f1_OBS_fn(t_OBS_vals[-2], t_OBS_vals[-3], F_EST)
        if ftype == 0:
            f1_EST = torch.zeros(1)
        elif ftype == 1:
            f1_EST = f1_OBS_fn(t_OBS_vals[-2], t_OBS_vals[-3], F_EST)
        else:  # ftype == 2, guaranteed by the check above
            # No gradients are needed during evaluation, so skip building
            # the autograd graph for the MLP forward pass.
            with torch.no_grad():
                f1_EST = f1_EST_fn(
                    torch.cat([t_OBS_vals[-2], t_dot_OBS, F_EST]))

        # compute theta_hat (t) at t
        t_OBS = compute_theta_hat(const, t_OBS_vals[-1], t_OBS_vals[-2],
                                  t_OBS_vals[-3], F_EST, f1_OBS)
        t_EST = compute_theta_hat(const, t_OBS_vals[-1], t_OBS_vals[-2],
                                  t_OBS_vals[-3], F_EST, f1_EST)

        # store values to containers: drop the oldest observation and append
        # the newest one
        t_OBS_vals = t_OBS_vals[1:] + [t_OBS]
        f1_EST_vals[-1] = f1_EST

        # store history for plotting
        time_stamp.append(float(t / params['freq']))
        target_history.append(float(target.numpy()))
        obs_history.append(float(t_OBS.numpy()))
        est_history.append(float(t_EST.detach().numpy()))
        f1_obs_history.append(float(f1_OBS.detach().numpy()))
        f1_est_history.append(float(f1_EST.detach().numpy()))
        F_est_history.append(float(F_EST.detach().numpy()))

        # for debugging
        # if np.isnan(t_OBS.numpy()):
        #     break

    # store values for post-visualization
    params_dir = os.path.join(args.model_dir, args.data_type)
    eval_log_dir = os.path.join(params_dir, log_subdirs[ftype],
                                'evaluation', args.eval_type)
    os.makedirs(eval_log_dir, exist_ok=True)
    store_logs(time_stamp, target_history, obs_history, est_history,
               f1_obs_history, f1_est_history, F_est_history, eval_log_dir)

    # visualize
    plot_theta(time_stamp, target_history, obs_history, est_history,
               eval_log_dir)