Example 1
def gen_L(grid_width, grid_height, path='L_expert_trajectories'):
    ''' Generates trajectories of shape L, with right turn '''
    t = 3  # steps per segment
    n = 2  # segments: initial direction, then right turn
    num_traj = 50  # sampled start states (each yields num_actions trajectories)

    obstacles = create_obstacles(grid_width, grid_height)
    set_diff = list(set(product(tuple(range(3, grid_width-3)),
                                tuple(range(3, grid_height-3)))) \
                                        - set(obstacles))

    T = TransitionFunction(grid_width, grid_height, obstacle_movement)
    expert_data_dict = {}
    # Number of goals is the same as number of actions
    num_actions, num_goals = 4, 4
    env_data_dict = {'num_actions': num_actions, 'num_goals': num_goals}

    for i in range(num_traj):
        start_state = State(sample_start(set_diff), obstacles)
        for action_idx in range(num_actions):

            path_key = str(i) + '_' + str(action_idx)
            expert_data_dict[path_key] = {
                'state': [],
                'action': [],
                'goal': []
            }

            state = start_state

            for j in range(n):
                # Set initial direction
                if j == 0:
                    action = Action(action_idx)
                else:
                    if action.delta == 0:
                        action = Action(3)
                    elif action.delta == 1:
                        action = Action(2)
                    elif action.delta == 2:
                        action = Action(0)
                    elif action.delta == 3:
                        action = Action(1)
                    else:
                        raise ValueError("Invalid action delta {}".format(
                            action.delta))

                for k in range(t):
                    expert_data_dict[path_key]['state'].append(state.state)
                    expert_data_dict[path_key]['action'].append(action.delta)
                    expert_data_dict[path_key]['goal'].append(action.delta)
                    state = T(state, action, j)
        # print(expert_data_dict[path_key]['goal'])

    return env_data_dict, expert_data_dict, obstacles, set_diff
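
The inner elif chain implements a fixed right-turn mapping between action deltas: 0 -> 3, 1 -> 2, 2 -> 0, 3 -> 1. A minimal standalone sketch of the same mapping as a lookup table (RIGHT_TURN and turn_right are illustrative names, not part of the original code):

RIGHT_TURN = {0: 3, 1: 2, 2: 0, 3: 1}

def turn_right(delta):
    """Return the action delta after a 90-degree clockwise turn."""
    if delta not in RIGHT_TURN:
        raise ValueError("Invalid action delta {}".format(delta))
    return RIGHT_TURN[delta]

assert [turn_right(d) for d in range(4)] == [3, 2, 0, 1]

Examples 1 and 2 apply this mapping once at the start of each new segment.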
Example 2
def gen_L(grid_width, grid_height, path='L_expert_trajectories'):
    ''' Generates trajectories of shape L, with right turn '''
    t = 3  # steps per segment
    n = 2  # segments: initial direction, then right turn
    N = 200  # number of trajectory files to write

    obstacles = create_obstacles(grid_width, grid_height)
    set_diff = list(
        set(
            product(tuple(range(3, grid_width -
                                3)), tuple(range(3, grid_height - 3)))) -
        set(obstacles))

    if not os.path.exists(path):
        os.makedirs(path)

    T = TransitionFunction(grid_width, grid_height, obstacle_movement)

    for i in range(N):
        filename = os.path.join(path, str(i) + '.txt')
        f = open(filename, 'w')
        for j in range(n):
            if j == 0:
                action = Action(random.choice(range(0, 4)))
                state = State(sample_start(set_diff), obstacles)
            else:  # take right turn
                if action.delta == 0:
                    action = Action(3)
                elif action.delta == 1:
                    action = Action(2)
                elif action.delta == 2:
                    action = Action(0)
                elif action.delta == 3:
                    action = Action(1)
            for k in range(t):
                f.write(' '.join([str(e)
                                  for e in state.state]) + '\n')  # write state
                f.write(
                    ' '.join([str(e)
                              for e in oned_to_onehot(action.delta, 4)]) +
                    '\n')  # write action
                f.write(
                    ' '.join([str(e)
                              for e in oned_to_onehot(action.delta, 4)]) +
                    '\n')  # write c[t]s
                state = T(state, action, j)

        f.close()
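
Each time step writes three lines to the trajectory file: the state coordinates, the one-hot action, and the one-hot latent code c[t]. oned_to_onehot is not defined in this snippet; a minimal sketch of what such a helper could look like, assuming it maps an integer index to a one-hot numpy vector of the given size:

import numpy as np

def oned_to_onehot(index, size):
    """Sketch: encode an integer index as a one-hot vector of length `size`."""
    onehot = np.zeros(size, dtype=np.float32)
    onehot[index] = 1.0
    return onehot

# e.g. oned_to_onehot(2, 4) -> array([0., 0., 1., 0.], dtype=float32)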
Example 3
def gen_sq_rec(grid_width, grid_height, path='SR_expert_trajectories'):
    ''' Generates squares if starting in quadrants 1 and 4, and rectangles if starting in quadrants 2 and 3 '''
    N = 200

    obstacles = create_obstacles(grid_width, grid_height)

    if not os.path.exists(path):
        os.makedirs(path)

    T = TransitionFunction(grid_width, grid_height, obstacle_movement)

    for i in range(N):
        filename = os.path.join(path, str(i) + '.txt')
        f = open(filename, 'w')
        half = random.choice(range(0, 2))
        if half == 0:  # left half
            set_diff = list(
                set(
                    product(tuple(range(0, (grid_width // 2) -
                                        3)), tuple(range(1, grid_height)))) -
                set(obstacles))
            start_loc = sample_start(set_diff)
        elif half == 1:  # right half
            set_diff = list(
                set(
                    product(tuple(range(grid_width // 2, grid_width -
                                        2)), tuple(range(2, grid_height)))) -
                set(obstacles))
            start_loc = sample_start(set_diff)

        state = State(start_loc, obstacles)

        if start_loc[0] >= grid_width / 2:  # quadrants 1 and 4
            # generate 2x2 square clockwise
            t = 2
            n = 4
            delta = 3

            for j in range(n):
                for k in range(t):
                    action = Action(delta)
                    f.write(' '.join([str(e) for e in state.state]) +
                            '\n')  # write state
                    f.write(' '.join(
                        [str(e) for e in oned_to_onehot(action.delta, 4)]) +
                            '\n')  # write action
                    f.write(' '.join(
                        [str(e) for e in oned_to_onehot(action.delta, 4)]) +
                            '\n')  # write c[t]s
                    state = T(state, action, j * 2 + k)

                if delta == 3:
                    delta = 1
                elif delta == 1:
                    delta = 2
                elif delta == 2:
                    delta = 0

        else:  # quadrants 2 and 3
            # generate 3x1 rectangle anti-clockwise
            t = [1, 3, 1, 3]
            delta = 1

            for j in range(len(t)):
                for k in range(t[j]):
                    action = Action(delta)
                    f.write(' '.join([str(e) for e in state.state]) +
                            '\n')  # write state
                    f.write(' '.join(
                        [str(e) for e in oned_to_onehot(action.delta, 4)]) +
                            '\n')  # write action
                    f.write(' '.join(
                        [str(e) for e in oned_to_onehot(action.delta, 4)]) +
                            '\n')  # write c[t]s
                    state = T(state, action, sum(t[0:j]) + k)

                if delta == 1:
                    delta = 3
                elif delta == 3:
                    delta = 0
                elif delta == 0:
                    delta = 2

        f.close()
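
The two branches differ only in the per-side step counts and in the turn order of the deltas (3 -> 1 -> 2 -> 0 clockwise for the square, 1 -> 3 -> 0 -> 2 anti-clockwise for the rectangle). A small standalone sketch that expands those schedules into flat action sequences (expand_schedule is an illustrative helper, not part of the original code):

def expand_schedule(deltas, steps_per_side):
    """Sketch: flatten a (delta, step-count) schedule into a list of action deltas."""
    sequence = []
    for delta, steps in zip(deltas, steps_per_side):
        sequence.extend([delta] * steps)
    return sequence

# 2x2 square, clockwise: each side takes 2 steps.
square = expand_schedule([3, 1, 2, 0], [2, 2, 2, 2])
# 3x1 rectangle, anti-clockwise: sides alternate 1 and 3 steps.
rectangle = expand_schedule([1, 3, 0, 2], [1, 3, 1, 3])
assert square == [3, 3, 1, 1, 2, 2, 0, 0]
assert rectangle == [1, 3, 3, 3, 0, 2, 2, 2]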
Example 4
def gen_diverse_trajs(grid_width, grid_height):
    '''Generate diverse trajectories in a 21x21 grid with 4 goals.

    Return: Dictionary with keys as text filenames and values as dictionaries.
        Each value dictionary contains three keys: 'state' with a list of
        states, 'action' with a list of actions, and 'goal' with a list of
        goal indices.
    '''

    assert grid_width == 21 and grid_height == 21, "Incorrect grid width/height"
    N = 20
    goals = [(0, 0), (20, 20), (20, 0), (0, 20)]
    n_goals = len(goals)

    obstacles = create_obstacles(21, 21, 'diverse')

    T = TransitionFunction(grid_width, grid_height, obstacle_movement)

    set_diff = list(set(product(tuple(range(7, 13)), tuple(range(7, 13)))) \
            - set(obstacles))
    expert_data_dict = {}
    env_data_dict = {
        'num_actions': 8,
        'num_goals': n_goals,
        'goals': np.array(goals),
    }

    for n in range(N):

        start_state = State(sample_start(set_diff), obstacles)

        for g in range(n_goals):  # loop over goals
            # path 1 - go up/down till boundary and then move right/left

            if g == 0 or g == 2:  # do path 1 only for goal 0 and goal 2

                state = start_state
                path_key = str(n) + '_' + str(g) + '_' + str(1) + '.txt'
                expert_data_dict[path_key] = {
                    'state': [],
                    'action': [],
                    'goal': []
                }

                delta = 0 if g < 2 else 1
                action = Action(delta)

                while state.state[1] != grid_height - 1 and state.state[1] != 0:
                    expert_data_dict[path_key]['state'].append(state.state)
                    expert_data_dict[path_key]['action'].append(action.delta)
                    expert_data_dict[path_key]['goal'].append(g)
                    state = T(state, action, 0)

                delta = 3 if g == 0 or g == 3 else 2
                action = Action(delta)

                while state.state[0] != grid_width - 1 and state.state[0] != 0:
                    expert_data_dict[path_key]['state'].append(state.state)
                    expert_data_dict[path_key]['action'].append(action.delta)
                    expert_data_dict[path_key]['goal'].append(g)
                    state = T(state, action, 0)

                assert (state.coordinates in goals)

            # path 2 - go right/left till boundary and then move up/down

            if g == 1:  # do path 2 only for goal 1

                state = start_state
                path_key = str(n) + '_' + str(g) + '_' + str(2) + '.txt'
                expert_data_dict[path_key] = {
                    'state': [],
                    'action': [],
                    'goal': []
                }

                delta = 3 if g == 0 or g == 3 else 2
                action = Action(delta)

                while state.state[0] != grid_width - 1 and state.state[0] != 0:
                    expert_data_dict[path_key]['state'].append(state.state)
                    expert_data_dict[path_key]['action'].append(action.delta)
                    expert_data_dict[path_key]['goal'].append(g)
                    state = T(state, action, 0)

                delta = 0 if g < 2 else 1
                action = Action(delta)

                while state.state[1] != grid_height - 1 and state.state[1] != 0:
                    expert_data_dict[path_key]['state'].append(state.state)
                    expert_data_dict[path_key]['action'].append(action.delta)
                    expert_data_dict[path_key]['goal'].append(g)
                    state = T(state, action, 0)

                assert (state.coordinates in goals)

            # path 3 - go diagonally till obstacle and then
            #          move up/down if x > 10 or right/left if y > 10
            #          and then move right/left or up/down till goal

            if g == 3:  # do path 3 only for goal 3

                state = start_state
                path_key = str(n) + '_' + str(g) + '_' + str(3) + '.txt'
                expert_data_dict[path_key] = {
                    'state': [],
                    'action': [],
                    'goal': []
                }

                delta = g + 4
                action = Action(delta)

                while True:
                    new_state = T(state, action, 0)
                    if new_state.coordinates == state.coordinates:
                        break
                    expert_data_dict[path_key]['state'].append(state.state)
                    expert_data_dict[path_key]['action'].append(action.delta)
                    expert_data_dict[path_key]['goal'].append(g)
                    state = new_state

                if T(state, Action(2), 0).coordinates == state.coordinates \
                    or T(state, Action(3), 0).coordinates == state.coordinates:

                    delta = 0 if g < 2 else 1
                    action = Action(delta)

                    while state.state[1] != grid_height - 1 and state.state[
                            1] != 0:
                        expert_data_dict[path_key]['state'].append(state.state)
                        expert_data_dict[path_key]['action'].append(
                            action.delta)
                        expert_data_dict[path_key]['goal'].append(g)
                        state = T(state, action, 0)

                    delta = 3 if g == 0 or g == 3 else 2
                    action = Action(delta)

                    while state.state[0] != grid_width - 1 and state.state[
                            0] != 0:
                        expert_data_dict[path_key]['state'].append(state.state)
                        expert_data_dict[path_key]['action'].append(
                            action.delta)
                        expert_data_dict[path_key]['goal'].append(g)
                        state = T(state, action, 0)

                else:

                    delta = 3 if g == 0 or g == 3 else 2
                    action = Action(delta)

                    while state.state[0] != grid_width - 1 and state.state[
                            0] != 0:
                        expert_data_dict[path_key]['state'].append(state.state)
                        expert_data_dict[path_key]['action'].append(
                            action.delta)
                        expert_data_dict[path_key]['goal'].append(g)
                        state = T(state, action, 0)

                    delta = 0 if g < 2 else 1
                    action = Action(delta)

                    while state.state[1] != grid_height - 1 and state.state[
                            1] != 0:
                        expert_data_dict[path_key]['state'].append(state.state)
                        expert_data_dict[path_key]['action'].append(
                            action.delta)
                        expert_data_dict[path_key]['goal'].append(g)
                        state = T(state, action, 0)

                assert (state.coordinates in goals)

    return env_data_dict, expert_data_dict, obstacles, set_diff
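
Each key of the returned expert_data_dict has the form '<trajectory>_<goal>_<path id>.txt', and each value holds parallel 'state', 'action', and 'goal' lists. A minimal sketch of reading such entries back, using a toy dictionary in the same format (the entry shown is made up):

# Toy entry in the same format as the dictionary returned above.
expert_data_dict = {
    '0_0_1.txt': {'state': [(10, 10), (10, 9)], 'action': [0, 0], 'goal': [0, 0]},
}

for path_key, traj in expert_data_dict.items():
    traj_idx, goal_idx, path_id = path_key.replace('.txt', '').split('_')
    assert len(traj['state']) == len(traj['action']) == len(traj['goal'])
    print(traj_idx, goal_idx, path_id, len(traj['state']))  # 0 0 1 2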
Example 5
                Variable(torch.from_numpy(
                    ct).unsqueeze(0)).type(dtype)),
                1)).data.cpu().numpy()[0,0])

            if t < args.max_ep_length-1:
                reward += math.exp(np.sum(np.multiply(
                    posterior_net(torch.cat((
                        Variable(torch.from_numpy(
                            s.state).unsqueeze(0)).type(dtype),
                        Variable(torch.from_numpy(
                            oned_to_onehot(action)).unsqueeze(0)).type(dtype),
                        Variable(torch.from_numpy(
                            ct).unsqueeze(0)).type(dtype)),
                        1)).data.cpu().numpy()[0,:], c[t+1,:])))

            next_s = T(s, Action(action), R.t)
            true_reward = R(s, Action(action), ct)
            reward_sum += reward
            true_reward_sum += true_reward

            #next_state = running_state(next_state)

            mask = 1
            if t == args.max_ep_length-1:
                R.terminal = True
                mask = 0

            memory.push(s.state,
                        np.array([oned_to_onehot(action)]),
                        mask,
                        next_s.state,
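
The posterior term in this fragment rewards agreement between the posterior network's output and the next latent code via exp(sum(q * c[t+1])). A minimal numpy sketch of just that term, with made-up values standing in for the network output and the one-hot code:

import math
import numpy as np

# Made-up posterior output over 4 latent codes and a one-hot next code c[t+1];
# both stand in for the network output and c used above.
posterior_out = np.array([-2.3, -0.1, -3.0, -2.0])
c_next = np.array([0.0, 1.0, 0.0, 0.0])

posterior_reward = math.exp(np.sum(np.multiply(posterior_out, c_next)))
print(posterior_reward)  # exp(-0.1) ~= 0.905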
Example 6
    def train_gail(self, expert):
        '''Train Info-GAIL.'''
        args, dtype = self.args, self.dtype
        results = {
            'average_reward': [],
            'episode_reward': [],
            'true_traj': {},
            'pred_traj': {}
        }
        self.train_step_count, self.gail_step_count = 0, 0

        for ep_idx in range(args.num_epochs):
            memory = Memory()

            num_steps = 0
            reward_batch, true_reward_batch = [], []
            expert_true_reward_batch = []
            true_traj_curr_episode, gen_traj_curr_episode = [], []

            while num_steps < args.batch_size:
                traj_expert = expert.sample(size=1)
                state_expert, action_expert, _, _ = traj_expert

                # Expert state and actions
                state_expert = state_expert[0]
                action_expert = action_expert[0]
                expert_episode_len = len(state_expert)

                # Use the start state of the expert trajectory sampled above
                # (alternatively, sample a random start state):
                # curr_state_obj = self.sample_start_state()
                curr_state_obj = State(state_expert[0], self.obstacles)
                curr_state_feat = self.get_state_features(
                    curr_state_obj, self.args.use_state_features)

                # Add history to state
                if args.history_size > 1:
                    curr_state = -1 * np.ones(
                        (args.history_size * curr_state_feat.shape[0]),
                        dtype=np.float32)
                    curr_state[(args.history_size-1) \
                            * curr_state_feat.shape[0]:] = curr_state_feat
                else:
                    curr_state = curr_state_feat

                # TODO: Make this a separate function. Can be parallelized.
                ep_reward, ep_true_reward, expert_true_reward = 0, 0, 0
                true_traj, gen_traj = [], []
                gen_traj_dict = {
                    'features': [],
                    'actions': [],
                    'c': [],
                    'mask': []
                }
                disc_reward, posterior_reward = 0.0, 0.0
                # Gather experience in a plain list first, since we need to
                # mutate it before finally pushing it into the Memory object.

                c_sampled = np.zeros((self.num_goals), dtype=np.float32)
                c_sampled[np.random.randint(0, self.num_goals)] = 1.0
                c_sampled_tensor = torch.zeros((1)).type(torch.LongTensor)
                c_sampled_tensor[0] = int(np.argmax(c_sampled))
                if self.args.cuda:
                    c_sampled_tensor = torch.cuda.LongTensor(c_sampled_tensor)

                memory_list = []
                for t in range(expert_episode_len):
                    action = self.select_action(
                        np.concatenate((curr_state, c_sampled)))
                    action_numpy = action.data.cpu().numpy()

                    # Save generated and true trajectories
                    true_traj.append((state_expert[t], action_expert[t]))
                    gen_traj.append((curr_state_obj.coordinates, action_numpy))
                    gen_traj_dict['features'].append(
                        self.get_state_features(curr_state_obj,
                                                self.args.use_state_features))
                    gen_traj_dict['actions'].append(action_numpy)
                    gen_traj_dict['c'].append(c_sampled)

                    action = epsilon_greedy_linear_decay(action_numpy,
                                                         args.num_epochs * 0.5,
                                                         ep_idx,
                                                         self.action_size,
                                                         low=0.05,
                                                         high=0.3)

                    # Get the discriminator reward
                    disc_reward_t = float(
                        self.reward_net(
                            torch.cat((Variable(
                                torch.from_numpy(curr_state).unsqueeze(
                                    0)).type(dtype),
                                       Variable(
                                           torch.from_numpy(
                                               oned_to_onehot(
                                                   action, self.action_size)).
                                           unsqueeze(0)).type(dtype)),
                                      1)).data.cpu().numpy()[0, 0])

                    if args.use_log_rewards and disc_reward_t < 1e-6:
                        disc_reward_t += 1e-6

                    disc_reward_t = -math.log(disc_reward_t) \
                            if args.use_log_rewards else -disc_reward_t
                    disc_reward += disc_reward_t

                    # Predict c given (x_t)
                    predicted_posterior = self.posterior_net(
                        Variable(torch.from_numpy(curr_state).unsqueeze(
                            0)).type(dtype))
                    posterior_reward_t = self.criterion_posterior(
                        predicted_posterior,
                        Variable(c_sampled_tensor)).data.cpu().numpy()[0]

                    posterior_reward += (self.args.lambda_posterior *
                                         posterior_reward_t)

                    # Update Rewards
                    ep_reward += (disc_reward_t + posterior_reward_t)
                    true_goal_state = [
                        int(x) for x in state_expert[-1].tolist()
                    ]
                    if self.args.flag_true_reward == 'grid_reward':
                        ep_true_reward += self.true_reward.reward_at_location(
                            curr_state_obj.coordinates,
                            goals=[true_goal_state])
                        expert_true_reward += self.true_reward.reward_at_location(
                            state_expert[t], goals=[true_goal_state])
                    elif self.args.flag_true_reward == 'action_reward':
                        ep_true_reward += self.true_reward.reward_at_location(
                            np.argmax(action_expert[t]), action)
                        expert_true_reward += self.true_reward.corret_action_reward
                    else:
                        raise ValueError("Incorrect true reward type")

                    # Update next state
                    next_state_obj = self.transition_func(
                        curr_state_obj, Action(action), 0)
                    next_state_feat = self.get_state_features(
                        next_state_obj, self.args.use_state_features)
                    #next_state = running_state(next_state)

                    mask = 0 if t == expert_episode_len - 1 else 1

                    # Push to memory
                    memory_list.append([
                        curr_state,
                        np.array([oned_to_onehot(action,
                                                 self.action_size)]), mask,
                        next_state_feat, disc_reward_t + posterior_reward_t,
                        c_sampled, c_sampled
                    ])

                    if args.render:
                        env.render()

                    if not mask:
                        break

                    curr_state_obj = next_state_obj
                    curr_state_feat = next_state_feat

                    if args.history_size > 1:
                        curr_state[:(args.history_size-1) \
                                * curr_state_feat.shape[0]] = \
                                curr_state[curr_state_feat.shape[0]:]
                        curr_state[(args.history_size-1) \
                                * curr_state_feat.shape[0]:] = curr_state_feat
                    else:
                        curr_state = curr_state_feat



                assert memory_list[-1][2] == 0, \
                        "Mask for final end state is not 0."
                for memory_t in memory_list:
                    memory.push(*memory_t)

                self.logger.summary_writer.add_scalars(
                    'gen_traj/gen_reward', {
                        'discriminator': disc_reward,
                        'posterior': posterior_reward,
                    }, self.train_step_count)

                num_steps += (t - 1)
                reward_batch.append(ep_reward)
                true_reward_batch.append(ep_true_reward)
                expert_true_reward_batch.append(expert_true_reward)
                results['episode_reward'].append(ep_reward)

                # Append trajectories
                true_traj_curr_episode.append(true_traj)
                gen_traj_curr_episode.append(gen_traj)

            results['average_reward'].append(np.mean(reward_batch))

            # Add to tensorboard
            self.logger.summary_writer.add_scalars(
                'gen_traj/reward', {
                    'average': np.mean(reward_batch),
                    'max': np.max(reward_batch),
                    'min': np.min(reward_batch)
                }, self.train_step_count)
            self.logger.summary_writer.add_scalars(
                'gen_traj/true_reward', {
                    'average': np.mean(true_reward_batch),
                    'max': np.max(true_reward_batch),
                    'min': np.min(true_reward_batch),
                    'expert_true': np.mean(expert_true_reward_batch)
                }, self.train_step_count)

            # Add predicted and generated trajectories to results
            if ep_idx % self.args.save_interval == 0:
                results['true_traj'][ep_idx] = copy.deepcopy(
                    true_traj_curr_episode)
                results['pred_traj'][ep_idx] = copy.deepcopy(
                    gen_traj_curr_episode)

            # Update parameters
            gen_batch = memory.sample()

            # We do not get the context variable from expert trajectories.
            # Hence we need to fill it in later.
            expert_batch = expert.sample(size=args.num_expert_trajs)

            self.update_params(gen_batch, expert_batch, ep_idx,
                               args.optim_epochs, args.optim_batch_size)

            self.train_step_count += 1

            if ep_idx > 0 and ep_idx % args.log_interval == 0:
                print('Episode [{}/{}]  Avg R: {:.2f}   Max R: {:.2f} \t' \
                      'True Avg {:.2f}   True Max R: {:.2f}   ' \
                      'Expert (Avg): {:.2f}'.format(
                          ep_idx, args.num_epochs, np.mean(reward_batch),
                          np.max(reward_batch), np.mean(true_reward_batch),
                          np.max(true_reward_batch),
                          np.mean(expert_true_reward_batch)))

            results_path = os.path.join(args.results_dir, 'results.pkl')
            with open(results_path, 'wb') as results_f:
                pickle.dump(results, results_f, protocol=2)
                # print("Did save results to {}".format(results_path))

            if ep_idx % args.save_interval == 0:
                checkpoint_filepath = self.model_checkpoint_filepath(ep_idx)
                torch.save(self.checkpoint_data_to_save(), checkpoint_filepath)
                print("Did save checkpoint: {}".format(checkpoint_filepath))
Example 7
def test(Transition):
    model.eval()
    #test_loss = 0

    for _ in range(20):
        c = expert.sample_c()
        N = c.shape[0]
        c = np.argmax(c[0, :])
        if args.expert_path == 'SR_expert_trajectories/':
            if c == 1:
                half = 0
            elif c == 3:
                half = 1
        elif args.expert_path == 'SR2_expert_trajectories/':
            half = c
        if args.expert_path == 'SR_expert_trajectories/' or args.expert_path == 'SR2_expert_trajectories/':
            if half == 0:  # left half
                set_diff = list(
                    set(
                    product(tuple(range(0, (width // 2) -
                                            3)), tuple(range(1, height)))) -
                    set(obstacles))
            elif half == 1:  # right half
                set_diff = list(
                    set(
                    product(tuple(range(width // 2, width -
                                            2)), tuple(range(2, height)))) -
                    set(obstacles))
        else:
            set_diff = list(
                set(product(tuple(range(3, width - 3)), repeat=2)) -
                set(obstacles))

        start_loc = sample_start(set_diff)
        s = State(start_loc, obstacles)
        R.reset()
        c = torch.from_numpy(np.array([-1.0, c])).unsqueeze(0).float()

        print('c is', c[0, 1])

        c = Variable(c)

        x = -1 * torch.ones(1, 4, 2)

        if args.cuda:
            x = x.cuda()
            c = c.cuda()

        for t in range(N):

            x[:, :3, :] = x[:, 1:, :]
            curr_x = torch.from_numpy(s.state).unsqueeze(0)
            if args.cuda:
                curr_x = curr_x.cuda()

            x[:, 3:, :] = curr_x

            x_t0 = Variable(x[:, 0, :])
            x_t1 = Variable(x[:, 1, :])
            x_t2 = Variable(x[:, 2, :])
            x_t3 = Variable(x[:, 3, :])

            mu, logvar = model.encode(torch.cat((x_t0, x_t1, x_t2, x_t3), 1),
                                      c)
            c[:, 0] = model.reparameterize(mu, logvar)
            pred_a = model.decode(torch.cat((x_t0, x_t1, x_t2, x_t3), 1),
                                  c).data.cpu().numpy()
            pred_a = np.argmax(pred_a)
            print(pred_a)
            next_s = Transition(s, Action(pred_a), R.t)

            s = next_s
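
The context vector here packs two channels: channel 0 is a -1.0 placeholder that the encoder's reparameterized sample overwrites at every step, and channel 1 is the goal index taken from the sampled c. A minimal sketch of that construction (the goal one-hot and the 0.42 sample are made-up stand-ins):

import numpy as np
import torch

# Goal index taken from a (made-up) one-hot goal vector.
goal_index = float(np.argmax([0, 0, 0, 1]))  # -> 3.0
c = torch.from_numpy(np.array([-1.0, goal_index])).unsqueeze(0).float()
print(c)        # tensor([[-1., 3.]])
c[:, 0] = 0.42  # stand-in for model.reparameterize(mu, logvar)
print(c)        # tensor([[0.4200, 3.0000]])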
Example 8
def train(epoch, expert, Transition):
    model.train()
    train_loss = 0
    for batch_idx in range(10):  # 10 batches per epoch
        batch = expert.sample(args.batch_size)
        x_data = torch.Tensor(batch.state)
        N = x_data.size(1)
        x = -1 * torch.ones(x_data.size(0), 4, x_data.size(2))
        x[:, 3, :] = x_data[:, 0, :]

        a = Variable(torch.Tensor(batch.action))

        _, c2 = torch.Tensor(batch.c).max(2)
        c2 = c2.float()[:, 0].unsqueeze(1)
        c1 = -1 * torch.ones(c2.size())
        c = torch.cat((c1, c2), 1)

        #c_t0 = Variable(c[:,0].clone().view(c.size(0), 1))

        if args.cuda:
            a = a.cuda()
            #c_t0 = c_t0.cuda()

        optimizer.zero_grad()
        for t in range(N):
            x_t0 = Variable(x[:, 0, :].clone().view(x.size(0), x.size(2)))
            x_t1 = Variable(x[:, 1, :].clone().view(x.size(0), x.size(2)))
            x_t2 = Variable(x[:, 2, :].clone().view(x.size(0), x.size(2)))
            x_t3 = Variable(x[:, 3, :].clone().view(x.size(0), x.size(2)))
            c_t0 = Variable(c)

            if args.cuda:
                x_t0 = x_t0.cuda()
                x_t1 = x_t1.cuda()
                x_t2 = x_t2.cuda()
                x_t3 = x_t3.cuda()
                c_t0 = c_t0.cuda()

            recon_batch, mu, logvar = model(x_t0, x_t1, x_t2, x_t3, c_t0)
            loss = loss_function(recon_batch, a[:, t, :], mu, logvar)
            loss.backward()
            train_loss += loss.data[0]

            pred_actions = recon_batch.data.cpu().numpy()

            x[:, :3, :] = x[:, 1:, :]
            # get next state and update x
            for b_id in range(pred_actions.shape[0]):
                action = Action(np.argmax(pred_actions[b_id, :]))
                state = State(x[b_id, 3, :].cpu().numpy(), obstacles)
                next_state = Transition(state, action, 0)
                x[b_id, 3, :] = torch.Tensor(next_state.state)

            # update c
            c[:, 0] = model.reparameterize(mu, logvar).data.cpu()

        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, 200.0,
                100. * batch_idx / 20.0, loss.data[0] / args.batch_size))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / 200.0))
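
The context built in train() follows the same two-channel layout: batch.c is assumed to hold one-hot codes of shape (batch, T, num_codes); the code index at t = 0 is paired with a -1 placeholder channel for the latent. A minimal sketch with a toy batch (shapes and values are made up):

import torch

batch_c = torch.zeros(2, 5, 4)  # (batch, T, num_codes), made-up batch
batch_c[0, :, 1] = 1.0          # trajectory 0 uses code 1
batch_c[1, :, 3] = 1.0          # trajectory 1 uses code 3

_, c2 = batch_c.max(2)               # code indices, shape (batch, T)
c2 = c2.float()[:, 0].unsqueeze(1)   # index at the first time step -> (batch, 1)
c1 = -1 * torch.ones(c2.size())      # placeholder channel for the latent
c = torch.cat((c1, c2), 1)
print(c)  # [[-1., 1.], [-1., 3.]]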