Esempio n. 1
0
def inference_demo(screen):
    """Estimate a mass by gradient descent on simulated positions.

    A reference trajectory is produced by simulating a world built with a
    known mass (7).  Starting from a guess of 1.3, the mass is then updated
    with clipped gradient steps on the MSE between the positions of a fresh
    simulation and the reference positions.  Iteration stops after
    ``max_iter`` steps or once the loss changes by less than 1e-3 between
    two consecutive iterations.  Finally the world is simulated once more
    with the learned mass, its loss and the mass are printed, and the loss
    and mass histories are plotted.

    Args:
        screen: Passed through to ``positions_run_world`` for the reference
            run and the final run; the optimisation runs use ``screen=None``.
    """
    forces = [hor_impulse]
    ground_truth_mass = Variable(torch.DoubleTensor([7]))
    world, c = make_world(forces, ground_truth_mass, num_links=NUM_LINKS)

    rec = None
    # rec = Recorder(DT, screen)
    ground_truth_pos = positions_run_world(world, run_time=10, screen=screen, recorder=rec)
    # Keep only the raw data so the reference carries no autograd history.
    ground_truth_pos = [p.data for p in ground_truth_pos]
    ground_truth_pos = Variable(torch.cat(ground_truth_pos))

    learning_rate = 0.01
    max_iter = 100
    loss_fn = MSELoss()

    next_mass = Variable(torch.DoubleTensor([1.3]), requires_grad=True)
    loss_hist = []
    mass_hist = [next_mass]
    last_dist = 1e10
    for i in range(max_iter):
        world, c = make_world(forces, next_mass, num_links=NUM_LINKS)
        positions = positions_run_world(world, run_time=10, screen=None)
        positions = torch.cat(positions)
        # The two runs may differ in length; compare only the part that
        # both of them cover so the loss never sees mismatched shapes.
        positions = positions[:len(ground_truth_pos)]
        clipped_ground_truth_pos = ground_truth_pos[:len(positions)]

        loss = loss_fn(positions, clipped_ground_truth_pos)
        loss.backward()
        # Clip the gradient to [-100, 100]; clamp() returns a new tensor, so
        # the stored gradient itself is left untouched.
        # NOTE(review): presumably c.mass is the next_mass leaf handed to
        # make_world above -- confirm against make_world.
        grad = c.mass.grad.data.clamp(-100, 100)
        temp = c.mass.data - learning_rate * grad
        # Never let the mass drop below MASS_EPS.
        temp = max(MASS_EPS, temp[0])
        # A fresh leaf variable each step, so gradients do not accumulate.
        next_mass = Variable(torch.DoubleTensor([temp]), requires_grad=True)
        # learning_rate /= 1.1
        print(i, '/', max_iter, loss.data[0])
        print(grad)
        print(next_mass)
        if abs((last_dist - loss).data[0]) < 1e-3:
            break
        last_dist = loss
        loss_hist.append(loss)
        mass_hist.append(next_mass)

    world = make_world(forces, next_mass, num_links=NUM_LINKS)[0]
    rec = None
    # rec = Recorder(DT, screen)
    # Score the run made with the final mass, not the positions left over
    # from the last optimisation step.
    positions = positions_run_world(world, run_time=10, screen=screen, recorder=rec)
    positions = torch.cat(positions)
    positions = positions[:len(ground_truth_pos)]
    loss = loss_fn(positions, ground_truth_pos[:len(positions)])
    print(loss.data[0])
    print(next_mass)

    plot(loss_hist)
    plot(mass_hist)
Esempio n. 2
0
def grad_demo(screen):
    """Optimise an initial force by gradient descent on a final distance.

    The force ``initial_force`` is applied while ``t < 0.1`` and
    ``ExternalForce.ZEROS`` afterwards.  Each iteration builds a fresh world,
    simulates it for ``TIME``, measures the distance between ``target.pos``
    and ``c.pos`` and takes a gradient step on the force.  The learning rate
    is divided by 1.1 after every step.  Iteration stops after ``max_iter``
    steps or once the distance changes by less than 1e-5 between two
    consecutive iterations.  The learned force is then simulated once more,
    the resulting distance is printed and the distance history is plotted.

    Args:
        screen: Passed through to ``run_world`` for the initial and the final
            run; the optimisation runs use ``screen=None``.
    """
    initial_force = torch.DoubleTensor([0, 3, 0])
    # Third component explicitly zeroed (it is already 0 in the literal).
    initial_force[2] = 0
    initial_force = Variable(initial_force, requires_grad=True)

    def learned_force(t):
        # `initial_force` is looked up when the force is evaluated, not when
        # this function is defined, so rebinding the name in the loop below
        # changes the force that every later simulation applies.
        return initial_force if t < 0.1 else ExternalForce.ZEROS

    # Initial demo
    world, c, target = make_world(learned_force)
    run_world(world, run_time=TIME, screen=screen)

    learning_rate = 0.001
    max_iter = 100

    dist_hist = []
    last_dist = 1e10
    for i in range(max_iter):
        world, c, target = make_world(learned_force)
        run_world(world, run_time=TIME, screen=None)

        dist = (target.pos - c.pos).norm()
        dist.backward()
        grad = initial_force.grad.data
        # grad.clamp_(-10, 10)
        # A fresh leaf variable each step, so gradients do not accumulate.
        initial_force = Variable(initial_force.data - learning_rate * grad,
                                 requires_grad=True)
        learning_rate /= 1.1
        print(i, '/', max_iter, dist.data[0])
        print(grad)
        # Prints the already updated force (see learned_force above).
        print(learned_force(0.05))
        print('=======')
        if abs((last_dist - dist).data[0]) < 1e-5:
            break
        last_dist = dist
        dist_hist.append(dist)

    # Rebind c and target together with the world: the distance printed
    # below has to come from the bodies of the world that was just
    # simulated, not from those of the last optimisation step.
    world, c, target = make_world(learned_force)
    rec = None
    # rec = Recorder(DT, screen)
    run_world(world, run_time=TIME, screen=screen, recorder=rec)
    dist = (target.pos - c.pos).norm()
    print(dist.data[0])

    # import pickle
    # with open('control_balls_dist_hist.pkl', 'w') as f:
    #     pickle.dump(dist_hist, f)
    plot(dist_hist)
Esempio n. 3
0
def _clipped_mse(positions, ground_truth_pos):
    """Return the MSE between a rollout and the reference over their overlap.

    ``positions`` is the sequence returned by ``positions_run_world``; it is
    concatenated, cut to the reference length, and the reference is cut to
    the resulting length before the loss is taken.
    """
    stacked = torch.cat(positions)
    stacked = stacked[:len(ground_truth_pos)]
    reference = ground_truth_pos[:len(stacked)]
    return MSELoss()(stacked, reference)


def main(screen):
    """Fit a mass to a reference trajectory with RMSprop.

    A reference rollout is simulated with mass ``TOTAL_MASS``.  A randomly
    initialised mass is then optimised on the MSE between fresh rollouts and
    the reference, for at most ``max_iter`` iterations or until the loss
    changes by less than ``STOP_DIFF`` between two consecutive iterations.
    The final loss and mass are printed and both histories are plotted.

    Args:
        screen: Passed through to ``positions_run_world`` for the reference
            run and the final run; the optimisation runs use ``screen=None``.
    """
    forces = [hor_impulse]
    true_mass = torch.tensor([TOTAL_MASS], dtype=DTYPE)
    world, chain = make_world(forces, true_mass, num_links=NUM_LINKS)

    # rec = Recorder(DT, screen)
    rec = None
    reference_steps = positions_run_world(world, run_time=10, screen=screen, recorder=rec)
    # .data drops the autograd history of the reference rollout.
    ground_truth_pos = torch.cat([step.data for step in reference_steps])

    learning_rate = 0.5
    max_iter = 100

    next_mass = torch.rand_like(true_mass, requires_grad=True)
    print('\rInitial mass:', next_mass.item())
    print('-----')

    optim = torch.optim.RMSprop([next_mass], lr=learning_rate)
    loss_hist = []
    mass_hist = [next_mass.item()]
    last_loss = 1e10
    for iteration in range(1, max_iter + 1):
        # NOTE(review): this detached rollout runs on every iteration and its
        # world/chain are overwritten right below without being read.  With
        # screen=None and recorder=None it looks like pure extra work --
        # confirm against make_world/run_world before removing it.
        world, chain = make_world(forces, next_mass.clone().detach(), num_links=NUM_LINKS)
        run_world(world, run_time=10, print_time=False, screen=None, recorder=None)

        world, chain = make_world(forces, next_mass, num_links=NUM_LINKS)
        rollout = positions_run_world(world, run_time=10, screen=None)

        optim.zero_grad()
        loss = _clipped_mse(rollout, ground_truth_pos)
        loss.backward()
        optim.step()

        print('Iteration: {} / {}'.format(iteration, max_iter))
        print('Loss:', loss.item())
        print('Gradient:', next_mass.grad.item())
        print('Next mass:', next_mass.item())
        print('-----')

        loss_change = abs((last_loss - loss).item())
        if loss_change < STOP_DIFF:
            print('Loss changed by less than {} between iterations, stopping training.'
                  .format(STOP_DIFF))
            break
        # The iteration that triggers the stop is not recorded.
        last_loss = loss
        loss_hist.append(loss.item())
        mass_hist.append(next_mass.item())

    # Evaluate (and display, if a screen was given) the learned mass.
    world = make_world(forces, next_mass, num_links=NUM_LINKS)[0]
    rec = None
    final_rollout = positions_run_world(world, run_time=10, screen=screen, recorder=rec)
    loss = _clipped_mse(final_rollout, ground_truth_pos)
    print('Final loss:', loss.item())
    print('Final mass:', next_mass.item())

    plot(loss_hist)
    plot(mass_hist)