def _enjoy():
    """Load the trained imitation policy and run it in the simulator forever.

    A fresh ``Model`` is built and given the weights from
    ``trained_models/imitate.pt``. The environment is then wrapped. In an
    endless loop, each observation is fed to the model and the environment is
    stepped with the predicted action, rendering every frame. When an episode
    ends with a negative reward, ``*** FAILED ***`` is printed and the loop
    pauses for 0.7 s. The environment is then reset.

    Never returns normally. If the weights cannot be loaded, the error is
    reported on stderr and the process exits with status 1.
    """
    import sys

    model = Model(action_dim=2, max_action=1.)

    try:
        # NOTE(review): _train saves to 'imitation/pytorch/models/imitate.pt';
        # confirm this path refers to the same checkpoint.
        state_dict = torch.load('trained_models/imitate.pt', map_location=device)
        model.load_state_dict(state_dict)
    except Exception as err:  # e.g. missing file or mismatched state_dict keys
        print('failed to load model')
        print(err, file=sys.stderr)
        sys.exit(1)

    model.eval().to(device)

    env = launch_env()
    env = ResizeWrapper(env)
    env = NormalizeWrapper(env)
    env = ImgWrapper(env)
    env = ActionWrapper(env)
    env = DtRewardWrapper(env)

    obs = env.reset()

    while True:
        # Add a leading batch dimension of size 1 for the model.
        obs_tensor = torch.from_numpy(obs).float().to(device).unsqueeze(0)

        # Inference only: no_grad stops autograd from recording a graph each step.
        with torch.no_grad():
            action = model(obs_tensor)
        action = action.squeeze().cpu().numpy()

        obs, reward, done, info = env.step(action)
        env.render()

        if done:
            if reward < 0:
                print('*** FAILED ***')
                time.sleep(0.7)  # pause so the failure is visible before reset

            obs = env.reset()
            env.render()
def _train(args):
    """Collect expert demonstrations, then fit ``Model`` to them by behaviour cloning.

    ``PurePursuitExpert`` is run in the wrapped simulator for ``args.episodes``
    episodes of ``args.steps`` steps each. Every (observation, action) pair is
    recorded. ``Model`` is then trained with SGD on randomly sampled
    mini-batches for ``args.epochs`` iterations. The weights are written to
    ``imitation/pytorch/models/imitate.pt`` every 200 iterations and after the
    final iteration.

    Args:
        args: Parsed options. Reads ``episodes``, ``steps``, ``epochs`` and
            ``batch_size``.

    Raises:
        ValueError: If ``args.epochs`` > 0 but no samples were collected,
            so there is nothing to train on.
    """
    import os

    save_path = 'imitation/pytorch/models/imitate.pt'

    env = launch_env()
    env = ResizeWrapper(env)
    env = NormalizeWrapper(env)
    env = ImgWrapper(env)
    env = DtRewardWrapper(env)
    env = MetricsWrapper(env)
    env = ActionWrapper(env)
    print("Initialized Wrappers")

    # Create an imperfect demonstrator
    expert = PurePursuitExpert(env=env)

    observations = []
    actions = []

    # let's collect our samples
    for episode in range(args.episodes):
        print("Starting episode", episode)
        # NOTE(review): the `done` flag from env.step is ignored, so stepping
        # continues for the full args.steps even after an episode ends.
        # Confirm this is intended.
        for _ in range(args.steps):
            # use our 'expert' to predict the next action.
            action = expert.predict(None)
            observation, _, _, _ = env.step(action)
            observations.append(observation)
            actions.append(action)
        env.reset()

    env.close()

    actions = np.array(actions)
    observations = np.array(observations)

    if args.epochs > 0 and observations.shape[0] == 0:
        raise ValueError(
            'no demonstration samples collected; check args.episodes and args.steps')

    model = Model(action_dim=2, max_action=1.)
    model.train().to(device)

    # weight_decay is L2 regularization, helps avoid overfitting
    optimizer = optim.SGD(model.parameters(), lr=0.0004, weight_decay=1e-3)

    # Ensure the checkpoint directory exists before the first save.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    avg_loss = 0
    for epoch in range(args.epochs):
        optimizer.zero_grad()

        # Sample a mini-batch uniformly with replacement.
        batch_indices = np.random.randint(0, observations.shape[0], args.batch_size)
        obs_batch = torch.from_numpy(observations[batch_indices]).float().to(device)
        act_batch = torch.from_numpy(actions[batch_indices]).float().to(device)

        model_actions = model(obs_batch)

        # norm(2) without `dim` reduces the whole batch to one scalar, so the
        # trailing .mean() is a no-op. It is kept so the loss value (and the
        # learning rate tuned against it) stays unchanged.
        loss = (model_actions - act_batch).norm(2).mean()
        loss.backward()
        optimizer.step()

        # .item() replaces `loss.data[0]`, which raises IndexError on a
        # 0-dim tensor in PyTorch >= 0.4.
        loss_value = loss.item()
        # Exponential moving average for a smoother progress readout.
        avg_loss = avg_loss * 0.995 + loss_value * 0.005

        print('epoch %d, loss=%.3f' % (epoch, avg_loss))

        # Periodically save the trained model, and always after the last epoch
        # so the final weights are never lost.
        if epoch % 200 == 0 or epoch == args.epochs - 1:
            torch.save(model.state_dict(), save_path)