def inference_demo(screen):
    """Recover a mass by hand-rolled gradient descent on simulated positions.

    A reference trajectory is produced by simulating a world built with a
    mass of 7.  Starting from a guess of 1.3, every iteration rebuilds the
    world with the current guess, simulates it, takes the MSE between the
    simulated and the reference positions and applies one clipped gradient
    step to the mass.  Training stops after 100 iterations, or as soon as
    the loss changes by less than 1e-3 between consecutive iterations.
    The loss and mass histories are handed to ``plot`` at the end.

    Args:
        screen: Forwarded to ``positions_run_world`` for the reference run
            and for the final run; training runs pass ``screen=None``.
    """
    forces = [hor_impulse]
    ground_truth_mass = Variable(torch.DoubleTensor([7]))
    world, c = make_world(forces, ground_truth_mass, num_links=NUM_LINKS)

    # Recorder(DT, screen) was left as a commented-out alternative to None.
    rec = None
    ground_truth_pos = positions_run_world(world, run_time=10, screen=screen, recorder=rec)
    # Keep only the raw data so the reference carries no autograd history.
    ground_truth_pos = Variable(torch.cat([p.data for p in ground_truth_pos]))

    def trajectory_loss(sim_world, **run_kwargs):
        """Simulate ``sim_world`` and return its MSE against the reference."""
        positions = torch.cat(positions_run_world(sim_world, run_time=10, **run_kwargs))
        # Trim both trajectories to their common length so the loss never
        # sees mismatched shapes, whichever run turned out longer.
        positions = positions[:len(ground_truth_pos)]
        return MSELoss()(positions, ground_truth_pos[:len(positions)])

    learning_rate = 0.01
    max_iter = 100
    grad_limit = 100

    next_mass = Variable(torch.DoubleTensor([1.3]), requires_grad=True)
    # Histories hold plain floats, so no autograd graph is kept alive.
    loss_hist = []
    mass_hist = [next_mass.item()]
    last_loss = 1e10
    for i in range(max_iter):
        world, c = make_world(forces, next_mass, num_links=NUM_LINKS)
        loss = trajectory_loss(world, screen=None)
        loss.backward()
        loss_value = loss.item()

        # Clipping bounds each update to learning_rate * grad_limit.
        grad = c.mass.grad.data.clamp(-grad_limit, grad_limit)
        stepped = c.mass.data - learning_rate * grad
        # A fresh leaf per iteration: gradients never accumulate across
        # steps, and the mass is floored at MASS_EPS.
        next_mass = Variable(
            torch.DoubleTensor([max(MASS_EPS, float(stepped[0]))]),
            requires_grad=True,
        )

        print(i, '/', max_iter, loss_value)
        print(grad)
        print(next_mass)
        if abs(last_loss - loss_value) < 1e-3:
            break
        last_loss = loss_value
        loss_hist.append(loss_value)
        mass_hist.append(next_mass.item())

    # Report the loss of the final run itself, not of the last training
    # iteration's trajectory.
    world = make_world(forces, next_mass, num_links=NUM_LINKS)[0]
    rec = None
    final_loss = trajectory_loss(world, screen=screen, recorder=rec)
    print(final_loss.item())
    print(next_mass)

    plot(loss_hist)
    plot(mass_hist)
def grad_demo(screen):
    """Optimise an initial force so that body ``c`` ends up near ``target``.

    The force is applied while ``t < 0.1``; afterwards
    ``ExternalForce.ZEROS`` is used.  Each iteration rebuilds the world,
    simulates it for ``TIME``, measures the distance between ``target.pos``
    and ``c.pos`` and takes one gradient step on the force.  The learning
    rate is divided by 1.1 after every step.  Training stops after 100
    iterations, or once the distance changes by less than 1e-5 between
    consecutive iterations.  The distance history is handed to ``plot``.

    Args:
        screen: Forwarded to ``run_world`` for the initial and the final
            run; training runs pass ``screen=None``.
    """
    initial_force = Variable(torch.DoubleTensor([0, 3, 0]), requires_grad=True)

    def learned_force(t):
        # `initial_force` is looked up at call time, so every world built
        # after an update sees the updated force.
        return initial_force if t < 0.1 else ExternalForce.ZEROS

    # Initial demo with the untrained force.
    world, c, target = make_world(learned_force)
    run_world(world, run_time=TIME, screen=screen)

    learning_rate = 0.001
    max_iter = 100

    # History holds plain floats, so no autograd graph is kept alive.
    dist_hist = []
    last_dist = 1e10
    for i in range(max_iter):
        world, c, target = make_world(learned_force)
        run_world(world, run_time=TIME, screen=None)

        dist = (target.pos - c.pos).norm()
        dist.backward()
        dist_value = dist.item()

        grad = initial_force.grad.data
        # A fresh leaf per iteration: gradients never accumulate across steps.
        initial_force = Variable(initial_force.data - learning_rate * grad,
                                 requires_grad=True)
        learning_rate /= 1.1

        print(i, '/', max_iter, dist_value)
        print(grad)
        print(learned_force(0.05))
        print('=======')
        if abs(last_dist - dist_value) < 1e-5:
            break
        last_dist = dist_value
        dist_hist.append(dist_value)

    # Rebind `c` and `target` to the final world's bodies so the reported
    # distance belongs to the final run, not the last training iteration.
    world, c, target = make_world(learned_force)
    # Recorder(DT, screen) was left as a commented-out alternative to None.
    rec = None
    run_world(world, run_time=TIME, screen=screen, recorder=rec)
    dist = (target.pos - c.pos).norm()
    print(dist.item())

    plot(dist_hist)
def main(screen):
    """Estimate the mass behind a reference trajectory using RMSprop.

    A reference trajectory is simulated from a world built with
    ``TOTAL_MASS``.  A randomly initialised mass is then optimised by
    minimising the MSE between its simulated positions and the reference.
    Training runs for at most 100 iterations and stops early once the loss
    changes by less than ``STOP_DIFF`` between consecutive iterations.
    The loss and mass histories are handed to ``plot`` at the end.

    Args:
        screen: Forwarded to ``positions_run_world`` for the reference run
            and the final run; training runs pass ``screen=None``.
    """
    forces = [hor_impulse]
    ground_truth_mass = torch.tensor([TOTAL_MASS], dtype=DTYPE)
    world, chain = make_world(forces, ground_truth_mass, num_links=NUM_LINKS)

    # Recorder(DT, screen) was left as a commented-out alternative to None.
    rec = None
    reference_run = positions_run_world(world, run_time=10, screen=screen, recorder=rec)
    # Only the raw data is kept, so the reference carries no autograd history.
    ground_truth_pos = torch.cat([step.data for step in reference_run])

    def simulate_clipped(sim_world, **run_kwargs):
        """Simulate ``sim_world``; return (positions, reference) trimmed to a common length."""
        simulated = torch.cat(positions_run_world(sim_world, run_time=10, **run_kwargs))
        simulated = simulated[:len(ground_truth_pos)]
        return simulated, ground_truth_pos[:len(simulated)]

    learning_rate = 0.5
    max_iter = 100

    next_mass = torch.rand_like(ground_truth_mass, requires_grad=True)
    print('\rInitial mass:', next_mass.item())
    print('-----')

    optim = torch.optim.RMSprop([next_mass], lr=learning_rate)
    loss_hist = []
    mass_hist = [next_mass.item()]
    last_loss = 1e10
    for iteration in range(max_iter):
        # NOTE(review): this condition is always true, and nothing below
        # uses the result of this extra run on a detached copy of the
        # mass -- confirm whether it is still needed.
        if iteration % 1 == 0:
            detached_mass = next_mass.clone().detach()
            world, chain = make_world(forces, detached_mass, num_links=NUM_LINKS)
            run_world(world, run_time=10, print_time=False, screen=None, recorder=None)

        world, chain = make_world(forces, next_mass, num_links=NUM_LINKS)
        positions, clipped_ground_truth_pos = simulate_clipped(world, screen=None)

        optim.zero_grad()
        loss = MSELoss()(positions, clipped_ground_truth_pos)
        loss.backward()
        optim.step()

        print('Iteration: {} / {}'.format(iteration + 1, max_iter))
        print('Loss:', loss.item())
        print('Gradient:', next_mass.grad.item())
        print('Next mass:', next_mass.item())
        print('-----')

        loss_delta = abs((last_loss - loss).item())
        if loss_delta < STOP_DIFF:
            print('Loss changed by less than {} between iterations, stopping training.'
                  .format(STOP_DIFF))
            break
        last_loss = loss
        loss_hist.append(loss.item())
        mass_hist.append(next_mass.item())

    # Final run with the learned mass, shown on `screen`.
    world = make_world(forces, next_mass, num_links=NUM_LINKS)[0]
    rec = None
    positions, clipped_ground_truth_pos = simulate_clipped(world, screen=screen, recorder=rec)
    loss = MSELoss()(positions, clipped_ground_truth_pos)
    print('Final loss:', loss.item())
    print('Final mass:', next_mass.item())

    plot(loss_hist)
    plot(mass_hist)