def main(args):
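    """Evaluate a trained DRQN tactile policy on the pickled test cases for one case id.

    Loads velcro parameters from the test pickle, restores the policy network and
    the state normalizer, rolls out test_network() on each parameter set in the
    MuJoCo simulation, and prints/writes summary statistics to
    <result_dir>/case<case>.txt.
    """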
    if not os.path.isdir(args.result_dir):
        os.makedirs(args.result_dir)
    parent = os.path.dirname(os.path.abspath(__file__))
    # load test xml files
    test_file = os.path.join(parent, 'tests/test_xmls/temp_1_{}.pickle'.format(args.case))
    with open(test_file, 'rb') as f:
        params = pickle.load(f)
    # params = params[:6]
    if args.shuffle:
        random.shuffle(params)

    num_test = len(params)
    print('                    ++++++++++++++++++++++++++')
    print('                    +++ Now running case {} +++'.format(args.case))
    print('                    ++++++++++++++++++++++++++\n\n')

    policy_net = DRQN(args.indim, args.outdim)
    policy_net.load_state_dict(torch.load(args.weight_path)['policy_net_1'])
    policy_net.eval()

    # Set up the state normalizer
    normalizer = Multimodal_Normalizer(num_inputs=args.indim, device=args.device)
    if args.normalizer_file:
        if os.path.exists(args.normalizer_file):
            normalizer.restore_state(args.normalizer_file)

    # Create robot, reset simulation and grasp handle
    model = load_model_from_path(args.model_path)
    sim = MjSim(model)
    sim_param = SimParameter(sim)
    sim.step()
    if args.render:
        viewer = MjViewer(sim)
    else:
        viewer = None

    robot = RobotSim(sim, viewer, sim_param, args.render, args.break_thresh)

    tactile_obs_space = TactileObs(robot.get_gripper_xpos(),            # 24
                         robot.get_all_touch_buffer(args.hap_sample))   # 30 x 6

    performance = {'time': [], 'success': [], 'num_broken': [],
                   'tendon_hist': [0, 0, 0, 0, 0], 'collision': [],
                   'action_hist': [0, 0, 0, 0, 0, 0]}

    for i in range(num_test):
        velcro_params = params[i]
        geom, origin_offset, euler, radius = velcro_params
        print('\n\nTest {} Velcro parameters are: {}, {}, {}, {}'.format(i, geom, origin_offset, euler, radius))
        change_sim(robot.mj_sim, geom, origin_offset, euler, radius)
        robot.reset_simulation()
        ret = robot.grasp_handle()
        performance = test_network(args, policy_net, normalizer, robot, tactile_obs_space, performance)
        print('Success: {}, time: {}, num_broken: {}, collision:{} '.format(
                performance['success'][-1], performance['time'][-1], performance['num_broken'][-1], performance['collision'][-1]))

    print('Finished opening velcro with haptics test \n')
    success = np.array(performance['success'])
    time = np.array(performance['time'])
    print('Successfully opened the velcro in: {}% of cases'.format(100 * np.sum(success) / len(performance['success'])))
    print('Average time to open: {}'.format(np.average(time[success>0])))
    print('Action histogram for the test is: {}'.format(performance['action_hist']))

    out_fname = 'case{}.txt'.format(args.case)
    with open(os.path.join(args.result_dir, out_fname), 'w+') as f:
        f.write('Time: {}\n'.format(performance['time']))
        f.write('Success: {}\n'.format(performance['success']))
        f.write('Successfully opened the velcro in: {}% of cases\n'.format(100 * np.sum(success) / len(performance['success'])))
        f.write('Average time to open: {}\n'.format(np.average(time[success>0])))
        f.write('Num_broken: {}\n'.format(performance['num_broken']))
        f.write('Tendon histogram: {}\n'.format(performance['tendon_hist']))
        f.write('collision: {}\n'.format(performance['collision']))
        # f.write('high_success: {} low_success: {} '.format(high_success, low_success))


class POMDP:
    def __init__(self, args):
        self.args = args
        self.ACTIONS = ['left', 'right', 'forward', 'backward', 'up',
                        'down']  # 'open', 'close']
        self.P_START = 0.999
        self.P_END = 0.05
        self.P_DECAY = 600
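        # Epsilon-greedy schedule used in select_action():
        #   p_threshold = P_END + (P_START - P_END) * exp(-steps_done / P_DECAY)
        # so the exploration probability decays from ~P_START towards P_END.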
        self.max_iter = args.max_iter
        self.gripping_force = args.grip_force
        self.break_threshold = args.break_thresh

        # Prepare the drawing figure
        fig, (ax1, ax2) = plt.subplots(1, 2)
        self.figure = (fig, ax1, ax2)

    # Function to select an action from our policy or a random one
    def select_action(self, observation, hidden_state, cell_state):
        sample = random.random()
        p_threshold = self.P_END + (self.P_START - self.P_END) * math.exp(
            -1. * self.steps_done / self.P_DECAY)

        with torch.no_grad():
            # Run the policy net and pick the action with the largest predicted
            # Q-value (greedy); with probability p_threshold a random action is
            # returned instead (epsilon-greedy exploration).
            self.policy_net_1.eval()
            torch_observation = torch.from_numpy(observation).float().to(
                self.args.device).unsqueeze(0)
            model_out = self.policy_net_1(torch_observation,
                                          batch_size=1,
                                          time_step=1,
                                          hidden_state=hidden_state,
                                          cell_state=cell_state)
            out = model_out[0]
            hidden_state = model_out[1][0]
            cell_state = model_out[1][1]
            self.policy_net_1.train()

            if sample > p_threshold:
                action = int(torch.argmax(out[0]))
                return action, hidden_state, cell_state
            else:
                return random.randrange(
                    0, self.args.outdim), hidden_state, cell_state

    def optimize_model(self):
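        """Run one optimization step for the three DRQN policy networks.

        Samples sequences of length args.time_step from recurrent replay memory,
        computes Q(s_t, a_t) for each network, forms the shared bootstrap target
        r_t + gamma * min_i max_a Q_i(s_{t+1}, a), and minimizes the Huber loss
        with per-element gradient clamping to [-1, 1].
        """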
        args = self.args
        if len(self.memory) < (args.batch_size):
            return

        hidden_batch_1, cell_batch_1 = self.policy_net_1.init_hidden_states(
            args.batch_size, args.device)
        hidden_batch_2, cell_batch_2 = self.policy_net_2.init_hidden_states(
            args.batch_size, args.device)
        hidden_batch_3, cell_batch_3 = self.policy_net_3.init_hidden_states(
            args.batch_size, args.device)
        batch = self.memory.sample(args.batch_size, args.time_step)
        if not batch:
            return

        current_states, actions, rewards, next_states = [], [], [], []
        for b in batch:
            cs, ac, rw, ns = [], [], [], []
            for element in b:
                cs.append(element.state)
                ac.append(element.action)
                rw.append(element.reward)
                ns.append(element.next_state)
            current_states.append(cs)
            actions.append(ac)
            rewards.append(rw)
            next_states.append(ns)

        current_states = torch.from_numpy(np.array(current_states)).float().to(
            args.device)
        actions = torch.from_numpy(np.array(actions)).long().to(args.device)
        rewards = torch.from_numpy(np.array(rewards)).float().to(args.device)
        next_states = torch.from_numpy(np.array(next_states)).float().to(
            args.device)

        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of actions taken. These are the actions which would've been taken
        # for each batch state according to policy_net
        Q_s_1, _ = self.policy_net_1.forward(current_states,
                                             batch_size=args.batch_size,
                                             time_step=args.time_step,
                                             hidden_state=hidden_batch_1,
                                             cell_state=cell_batch_1)
        Q_s_a_1 = Q_s_1.gather(
            dim=1,
            index=actions[:,
                          args.time_step - 1].unsqueeze(dim=1)).squeeze(dim=1)

        Q_s_2, _ = self.policy_net_2.forward(current_states,
                                             batch_size=args.batch_size,
                                             time_step=args.time_step,
                                             hidden_state=hidden_batch_2,
                                             cell_state=cell_batch_2)
        Q_s_a_2 = Q_s_2.gather(
            dim=1,
            index=actions[:,
                          args.time_step - 1].unsqueeze(dim=1)).squeeze(dim=1)

        Q_s_3, _ = self.policy_net_3.forward(current_states,
                                             batch_size=args.batch_size,
                                             time_step=args.time_step,
                                             hidden_state=hidden_batch_3,
                                             cell_state=cell_batch_3)
        Q_s_a_3 = Q_s_3.gather(
            dim=1,
            index=actions[:,
                          args.time_step - 1].unsqueeze(dim=1)).squeeze(dim=1)

        Q_next_1, _ = self.policy_net_1.forward(next_states,
                                                batch_size=args.batch_size,
                                                time_step=args.time_step,
                                                hidden_state=hidden_batch_1,
                                                cell_state=cell_batch_1)
        Q_next_max_1 = Q_next_1.detach().max(dim=1)[0]

        Q_next_2, _ = self.policy_net_2.forward(next_states,
                                                batch_size=args.batch_size,
                                                time_step=args.time_step,
                                                hidden_state=hidden_batch_2,
                                                cell_state=cell_batch_2)
        Q_next_max_2 = Q_next_2.detach().max(dim=1)[0]

        Q_next_3, _ = self.policy_net_3.forward(next_states,
                                                batch_size=args.batch_size,
                                                time_step=args.time_step,
                                                hidden_state=hidden_batch_3,
                                                cell_state=cell_batch_3)
        Q_next_max_3 = Q_next_3.detach().max(dim=1)[0]

        Q_next_max = torch.min(torch.min(Q_next_max_1, Q_next_max_2),
                               Q_next_max_3)
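        # Taking the element-wise min across the three estimates gives a
        # conservative bootstrap target (in the spirit of clipped double
        # Q-learning) and counteracts Q-value overestimation.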

        # Compute the expected Q values
        target_values = rewards[:,
                                args.time_step - 1] + (args.gamma * Q_next_max)

        # Compute Huber loss
        loss_1 = F.smooth_l1_loss(Q_s_a_1, target_values)
        loss_2 = F.smooth_l1_loss(Q_s_a_2, target_values)
        loss_3 = F.smooth_l1_loss(Q_s_a_3, target_values)

        # Optimize the model
        self.optimizer_1.zero_grad()
        self.optimizer_2.zero_grad()
        self.optimizer_3.zero_grad()
        loss_1.backward()
        loss_2.backward()
        loss_3.backward()
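        # Clamp each gradient element to [-1, 1] before stepping, the usual
        # DQN-style gradient clipping for stability.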
        for param in self.policy_net_1.parameters():
            param.grad.data.clamp_(-1, 1)
        for param in self.policy_net_2.parameters():
            param.grad.data.clamp_(-1, 1)
        for param in self.policy_net_3.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer_1.step()
        self.optimizer_2.step()
        self.optimizer_3.step()
        return [loss_1, loss_2, loss_3]

    def train_POMDP(self):
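        """Main training loop.

        Rolls out episodes in the MuJoCo simulation, stores each episode in the
        recurrent replay memory, periodically calls optimize_model(), and saves
        checkpoints, replay memory, and training curves to the output directory.
        """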
        args = self.args
        ROOT_DIR = os.path.dirname(os.path.dirname(
            os.path.abspath(__file__)))  # corl2019
        PARENT_DIR = os.path.dirname(ROOT_DIR)  # research
        # Create the output directory if it does not exist
        output_dir = os.path.join(PARENT_DIR, 'multistep_pomdp',
                                  args.output_dir)
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        # write args to file
        with open(os.path.join(output_dir, 'args.txt'), 'w+') as f:
            json.dump(args.__dict__, f, indent=2)

        # Create our policy net and a target net
        self.policy_net_1 = DRQN(args.indim, args.outdim).to(args.device)
        self.policy_net_2 = DRQN(args.indim, args.outdim).to(args.device)
        self.policy_net_3 = DRQN(args.indim, args.outdim).to(args.device)

        # Set up the optimizer
        self.optimizer_1 = optim.RMSprop(self.policy_net_1.parameters())
        self.optimizer_2 = optim.RMSprop(self.policy_net_2.parameters())
        self.optimizer_3 = optim.RMSprop(self.policy_net_3.parameters())
        self.memory = RecurrentMemory(800)
        self.steps_done = 0

        # Setup the state normalizer
        normalizer = Multimodal_Normalizer(num_inputs=args.indim,
                                           device=args.device)

        print_variables = {'durations': [], 'rewards': [], 'loss': []}
        start_episode = 0
        if args.checkpoint_file:
            if os.path.exists(args.checkpoint_file):
                checkpoint = torch.load(args.checkpoint_file)
                self.policy_net_1.load_state_dict(checkpoint['policy_net_1'])
                self.policy_net_2.load_state_dict(checkpoint['policy_net_2'])
                self.policy_net_3.load_state_dict(checkpoint['policy_net_3'])
                self.optimizer_1.load_state_dict(checkpoint['optimizer_1'])
                self.optimizer_2.load_state_dict(checkpoint['optimizer_2'])
                self.optimizer_3.load_state_dict(checkpoint['optimizer_3'])
                start_episode = checkpoint['epochs']
                self.steps_done = checkpoint['steps_done']
                with open(
                        os.path.join(os.path.dirname(args.checkpoint_file),
                                     'results_pomdp.pkl'), 'rb') as file:
                    plot_dict = pickle.load(file)
                    print_variables['durations'] = plot_dict['durations']
                    print_variables['rewards'] = plot_dict['rewards']

        if args.normalizer_file:
            if os.path.exists(args.normalizer_file):
                normalizer.restore_state(args.normalizer_file)

        if args.memory:
            if os.path.exists(args.memory):
                self.memory.load(args.memory)

        action_space = ActionSpace(dp=0.06, df=10)

        # Create robot, reset simulation and grasp handle
        model = load_model_from_path(args.model_path)
        sim = MjSim(model)
        sim_param = SimParameter(sim)
        sim.step()
        if args.render:
            viewer = MjViewer(sim)
        else:
            viewer = None

        robot = RobotSim(sim, viewer, sim_param, args.render,
                         self.break_threshold)

        # Main training loop
        for ii in range(start_episode, args.epochs):
            start_time = time.time()
            self.steps_done += 1
            act_sequence = []
            if args.sim:
                sim_params = init_model(robot.mj_sim)
                robot.reset_simulation()
                ret = robot.grasp_handle()
                if not ret:
                    continue

                # Local memory for current episode
                localMemory = []

                # Get current observation
                hidden_state_1, cell_state_1 = self.policy_net_1.init_hidden_states(
                    batch_size=1, device=args.device)
                hidden_state_2, cell_state_2 = self.policy_net_2.init_hidden_states(
                    batch_size=1, device=args.device)
                hidden_state_3, cell_state_3 = self.policy_net_3.init_hidden_states(
                    batch_size=1, device=args.device)
                observation_space = TactileObs(
                    robot.get_gripper_xpos(),  # 24
                    robot.get_all_touch_buffer(args.hap_sample))  # 30 x 7
                broken_so_far = 0

            for t in count():
                if not args.quiet and t % 50 == 0:
                    print("Running training episode: {}, iteration: {}".format(
                        ii, t))

                # Select action
                observation = observation_space.get_state()
                if args.position:
                    observation = observation[6:]
                if args.shear:
                    indices = np.ones(len(observation), dtype=bool)
                    indices[6:166] = False
                    observation = observation[indices]
                if args.force:
                    observation = observation[:166]
                normalizer.observe(observation)
                observation = normalizer.normalize(observation)
                action, hidden_state_1, cell_state_1 = self.select_action(
                    observation, hidden_state_1, cell_state_1)

                # record actions in this epoch
                act_sequence.append(action)

                # Perform action
                delta = action_space.get_action(
                    self.ACTIONS[action])['delta'][:3]
                target_position = np.add(robot.get_gripper_jpos()[:3],
                                         np.array(delta))
                target_pose = np.hstack(
                    (target_position, robot.get_gripper_jpos()[3:]))

                if args.sim:
                    robot.move_joint(target_pose,
                                     True,
                                     self.gripping_force,
                                     hap_sample=args.hap_sample)

                    # Get reward
                    done, num = robot.update_tendons()
                    failure = robot.check_slippage()
                    if num > broken_so_far:
                        reward = num - broken_so_far
                        broken_so_far = num
                    else:
                        reward = 0

                    # # Add a movement reward
                    # reward -= 0.05 * np.linalg.norm(target_position - robot.get_gripper_jpos()[:3]) / np.linalg.norm(delta)

                    # Observe new state
                    observation_space.update(
                        robot.get_gripper_xpos(),  # 24
                        robot.get_all_touch_buffer(args.hap_sample))  # 30x7

                # Set max number of iterations
                if t >= self.max_iter:
                    done = True

                # Check if done
                if not done and not failure:
                    next_state = observation_space.get_state()
                    if args.position:
                        next_state = next_state[6:]
                    if args.shear:
                        indices = np.ones(len(next_state), dtype=bool)
                        indices[6:166] = False
                        next_state = next_state[indices]
                    if args.force:
                        next_state = next_state[:166]
                    normalizer.observe(next_state)
                    next_state = normalizer.normalize(next_state)
                else:
                    next_state = None

                # Push new Transition into memory
                localMemory.append(
                    Transition(observation, action, next_state, reward))

                # Optimize the model
                if t % 10 == 0:
                    loss = self.optimize_model()
                    # if loss:
                    #     print_variables['loss'].append(loss.item())

                # If we are done, reset the model
                if done or failure:
                    self.memory.push(localMemory)
                    if failure:
                        print_variables['durations'].append(self.max_iter)
                    else:
                        print_variables['durations'].append(t)
                    print_variables['rewards'].append(broken_so_far)
                    plot_variables(self.figure, print_variables,
                                   "Training POMDP")
                    print("Model parameters: {}".format(sim_params))
                    print("Actions in this epoch are: {}".format(act_sequence))
                    print("Epoch {} took {}s, total number broken: {}\n\n".
                          format(ii,
                                 time.time() - start_time, broken_so_far))

                    break

            # Save checkpoints every few iterations
            if ii % args.save_freq == 0:
                save_path = os.path.join(
                    output_dir, 'checkpoint_model_' + str(ii) + '.pth')
                torch.save(
                    {
                        'epochs': ii,
                        'steps_done': self.steps_done,
                        'policy_net_1': self.policy_net_1.state_dict(),
                        'policy_net_2': self.policy_net_2.state_dict(),
                        'policy_net_3': self.policy_net_3.state_dict(),
                        'optimizer_1': self.optimizer_1.state_dict(),
                        'optimizer_2': self.optimizer_2.state_dict(),
                        'optimizer_3': self.optimizer_3.state_dict(),
                    }, save_path)

            self.memory.save_memory(os.path.join(output_dir, 'memory.pickle'))

        if args.savefig_path:
            now = dt.datetime.now()
            self.figure[0].savefig(
                args.savefig_path +
                '{}_{}_{}.png'.format(now.month, now.day, now.hour),
                format='png')

        print('Training done')
        plt.show()
        return print_variables


class POMDP:
    def __init__(self, args):
        self.args = args
        self.ACTIONS = ['left', 'right', 'forward', 'backward', 'up',
                        'down']  # 'open', 'close']
        self.P_START = 0.999
        self.P_END = 0.1
        self.P_DECAY = 60000
        self.max_iter = args.max_iter
        self.gripping_force = args.grip_force
        self.break_threshold = args.break_thresh

        # Prepare the drawing figure
        fig, (ax1, ax2) = plt.subplots(1, 2)
        self.figure = (fig, ax1, ax2)

    # Function to select an action from our policy or a random one
    def select_action(self, observation, hidden_state, cell_state):
        args = self.args
        sample = random.random()
        p_threshold = self.P_END + (self.P_START - self.P_END) * math.exp(
            -1. * self.steps_done / self.P_DECAY)
        self.steps_done += 1

        tactile_obs, img_norm = observation
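        # The observation is a (tactile vector, normalized image) pair; the image
        # is encoded by conv_net_1 and concatenated with the tactile input before
        # being fed to the recurrent policy net.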

        with torch.no_grad():
            # Run the policy net and pick the action with the largest predicted
            # Q-value (greedy); with probability p_threshold a random action is
            # returned instead (epsilon-greedy exploration).
            self.policy_net_1.eval()
            torch_tactile_obs = torch.from_numpy(tactile_obs).float().to(
                args.device).unsqueeze(0)
            torch_img_norm = torch.from_numpy(img_norm).float().to(
                args.device).unsqueeze(0)
            img_ft = self.conv_net_1.forward(torch_img_norm)

            torch_observation = torch.cat((torch_tactile_obs, img_ft), dim=1)
            model_out = self.policy_net_1(torch_observation,
                                          batch_size=1,
                                          time_step=1,
                                          hidden_state=hidden_state,
                                          cell_state=cell_state)
            out = model_out[0]
            hidden_state = model_out[1][0]
            cell_state = model_out[1][1]
            self.policy_net_1.train()

            if sample > p_threshold:
                action = int(torch.argmax(out[0]))
                return action, hidden_state, cell_state
            else:
                return random.randrange(0,
                                        args.outdim), hidden_state, cell_state

    def sample_memory(self):
        batch = self.memory.sample(self.args.batch_size, self.args.time_step)
        if not batch:
            return

        current_states, actions, rewards, next_states = [], [], [], []
        for b in batch:
            cs, ac, rw, ns = [], [], [], []
            for element in b:
                cs.append(element.state)
                ac.append(element.action)
                rw.append(element.reward)
                ns.append(element.next_state)
            current_states.append(cs)
            actions.append(ac)
            rewards.append(rw)
            next_states.append(ns)
        return current_states, actions, rewards, next_states

    def extract_ft(self, conv_net, obs_batch):
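        """Encode a batch of (tactile, image) observation sequences.

        Each image is passed through conv_net and concatenated with its tactile
        vector; the sequences are stacked into a tensor of shape
        (batch_size, time_step, -1) for the recurrent policy net.
        """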
        args = self.args
        assert len(obs_batch) == args.batch_size
        assert len(obs_batch[0]) == args.time_step
        result = []
        for b in obs_batch:
            obs_sequence = torch.tensor([]).to(args.device)
            for item in b:
                tactile_obs, img_norm = item
                torch_tactile_obs = torch.from_numpy(tactile_obs).float().to(
                    args.device)
                torch_img_norm = torch.from_numpy(img_norm).float().to(
                    args.device).unsqueeze(0)
                img_ft = conv_net.forward(torch_img_norm)
                torch_full_obs = torch.cat((torch_tactile_obs, img_ft[0]))
                obs_sequence = torch.cat((obs_sequence, torch_full_obs))
            torch_obs_sequence = obs_sequence.view(args.time_step, -1)
            result.append(torch_obs_sequence)
        return torch.stack(result)

    def bootstrap(self, policy_net, conv_net, memory_subset):
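        """Compute Q(s_t, a_t) and max_a Q(s_{t+1}, a) for one (policy, conv) pair.

        The returned pair is combined in optimize() with the other network's
        estimates to form the shared min-bootstrap target.
        """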
        args = self.args
        if len(self.memory) < (args.batch_size):
            return

        current_states, actions, rewards, next_states = memory_subset

        # process observation (tactile_obs, img) to 1d tensor of (tactile_obs, feature)
        # and then stack them to corresponding dimension: (batch_size, time_step, *)
        current_states = self.extract_ft(conv_net, current_states)
        next_states = self.extract_ft(conv_net, next_states)

        # convert all to torch tensors
        actions = torch.from_numpy(np.array(actions)).long().to(args.device)
        rewards = torch.from_numpy(np.array(rewards)).float().to(args.device)

        hidden_batch, cell_batch = policy_net.init_hidden_states(
            args.batch_size, args.device)

        Q_s, _ = policy_net.forward(current_states,
                                    batch_size=args.batch_size,
                                    time_step=args.time_step,
                                    hidden_state=hidden_batch,
                                    cell_state=cell_batch)
        Q_s_a = Q_s.gather(dim=1,
                           index=actions[:, args.time_step -
                                         1].unsqueeze(dim=1)).squeeze(dim=1)

        Q_next, _ = policy_net.forward(next_states,
                                       batch_size=args.batch_size,
                                       time_step=args.time_step,
                                       hidden_state=hidden_batch,
                                       cell_state=cell_batch)
        Q_next_max = Q_next.detach().max(dim=1)[0]
        return Q_s_a, Q_next_max

    def optimize(self):
        args = self.args
        if len(self.memory) < (args.batch_size):
            return

        memory_subset = self.sample_memory()
        _, _, rewards, _ = memory_subset
        rewards = torch.from_numpy(np.array(rewards)).float().to(args.device)

        Q_s_a_1, Q_next_max_1 = self.bootstrap(self.policy_net_1,
                                               self.conv_net_1, memory_subset)
        Q_s_a_2, Q_next_max_2 = self.bootstrap(self.policy_net_2,
                                               self.conv_net_2, memory_subset)

        Q_next_max = torch.min(Q_next_max_1, Q_next_max_2)

        # Compute the expected Q values
        target_values = rewards[:,
                                args.time_step - 1] + (args.gamma * Q_next_max)

        # Compute Huber loss
        loss_1 = F.smooth_l1_loss(Q_s_a_1, target_values)
        loss_2 = F.smooth_l1_loss(Q_s_a_2, target_values)

        # Optimize the model
        self.policy_optimizer_1.zero_grad()
        self.policy_optimizer_2.zero_grad()
        self.conv_optimizer_1.zero_grad()
        self.conv_optimizer_2.zero_grad()
        loss_1.backward()
        loss_2.backward()
        for param in self.policy_net_1.parameters():
            param.grad.data.clamp_(-1, 1)
        for param in self.policy_net_2.parameters():
            param.grad.data.clamp_(-1, 1)
        for param in self.conv_net_1.parameters():
            param.grad.data.clamp_(-1, 1)
        for param in self.conv_net_2.parameters():
            param.grad.data.clamp_(-1, 1)

        self.policy_optimizer_1.step()
        self.policy_optimizer_2.step()
        self.conv_optimizer_1.step()
        self.conv_optimizer_2.step()

    def train_POMDP(self):
        args = self.args
        # Create the output directory if it does not exist
        if not os.path.isdir(args.output_dir):
            os.makedirs(args.output_dir)

        # Create our policy net and a target net
        self.policy_net_1 = DRQN(args.indim, args.outdim).to(args.device)
        self.policy_net_2 = DRQN(args.indim, args.outdim).to(args.device)
        self.conv_net_1 = ConvNet(args.ftdim, args.depth).to(args.device)
        self.conv_net_2 = ConvNet(args.ftdim, args.depth).to(args.device)

        # Set up the optimizer
        self.policy_optimizer_1 = optim.RMSprop(self.policy_net_1.parameters(),
                                                lr=args.lr)
        self.policy_optimizer_2 = optim.RMSprop(self.policy_net_2.parameters(),
                                                lr=args.lr)
        self.conv_optimizer_1 = optim.RMSprop(self.conv_net_1.parameters(),
                                              lr=1e-5)
        self.conv_optimizer_2 = optim.RMSprop(self.conv_net_2.parameters(),
                                              lr=1e-5)
        self.memory = RecurrentMemory(70)
        self.steps_done = 0

        # Setup the state normalizer
        normalizer = Multimodal_Normalizer(num_inputs=args.indim - args.ftdim,
                                           device=args.device)

        print_variables = {'durations': [], 'rewards': [], 'loss': []}
        start_episode = 0
        if args.checkpoint_file:
            if os.path.exists(args.checkpoint_file):
                checkpoint = torch.load(args.checkpoint_file)
                self.policy_net_1.load_state_dict(checkpoint['policy_net_1'])
                self.policy_net_2.load_state_dict(checkpoint['policy_net_2'])
                self.conv_net_1.load_state_dict(checkpoint['conv_net_1'])
                self.conv_net_2.load_state_dict(checkpoint['conv_net_2'])
                self.policy_optimizer_1.load_state_dict(
                    checkpoint['policy_optimizer_1'])
                self.policy_optimizer_2.load_state_dict(
                    checkpoint['policy_optimizer_2'])
                self.conv_optimizer_1.load_state_dict(
                    checkpoint['conv_optimizer_1'])
                self.conv_optimizer_2.load_state_dict(
                    checkpoint['conv_optimizer_2'])
                start_episode = checkpoint['epochs']
                self.steps_done = checkpoint['steps_done']
                with open(
                        os.path.join(os.path.dirname(args.checkpoint_file),
                                     'results_pomdp.pkl'), 'rb') as file:
                    plot_dict = pickle.load(file)
                    print_variables['durations'] = plot_dict['durations']
                    print_variables['rewards'] = plot_dict['rewards']

        if args.normalizer_file:
            if os.path.exists(args.normalizer_file):
                normalizer.restore_state(args.normalizer_file)

        if args.memory:
            if os.path.exists(args.memory):
                self.memory.load(args.memory)

        if args.weight_conv:
            checkpoint = torch.load(args.weight_conv)
            self.conv_net_1.load_state_dict(checkpoint['conv_net'])
            self.conv_optimizer_1.load_state_dict(checkpoint['conv_optimizer'])
            self.conv_net_2.load_state_dict(checkpoint['conv_net'])
            self.conv_optimizer_2.load_state_dict(checkpoint['conv_optimizer'])

        action_space = ActionSpace(dp=0.06, df=10)

        # Create robot, reset simulation and grasp handle
        model = load_model_from_path(args.model_path)
        sim = MjSim(model)
        sim_param = SimParameter(sim)
        sim.step()
        if args.render:
            viewer = MjViewer(sim)
        else:
            viewer = None

        robot = RobotSim(sim, viewer, sim_param, args.render,
                         self.break_threshold)
        tactile_obs_space = TactileObs(
            robot.get_gripper_jpos(),  # 6
            robot.get_all_touch_buffer(args.hap_sample))  # 30 x 12

        # Main training loop
        for ii in range(start_episode, args.epochs):
            start_time = time.time()
            act_sequence = []
            velcro_params = init_model(robot.mj_sim)
            robot.reset_simulation()
            ret = robot.grasp_handle()
            if not ret:
                continue

            # Local memory for current episode
            localMemory = []

            # Get current observation
            hidden_state_1, cell_state_1 = self.policy_net_1.init_hidden_states(
                batch_size=1, device=args.device)
            hidden_state_2, cell_state_2 = self.policy_net_2.init_hidden_states(
                batch_size=1, device=args.device)

            broken_so_far = 0

            for t in count():
                if not args.quiet and t % 50 == 0:
                    print("Running training episode: {}, iteration: {}".format(
                        ii, t))

                # Select action
                tactile_obs = tactile_obs_space.get_state()
                normalizer.observe(tactile_obs)
                tactile_obs = normalizer.normalize(tactile_obs)
                # Get image and normalize it
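                # (when args.depth is set, normalized RGB and depth are stacked
                # into a 4-channel array of shape (4, img_w, img_h))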
                img = robot.get_img(args.img_w, args.img_h, 'c1', args.depth)
                if args.depth:
                    depth = norm_depth(img[1])
                    img = norm_img(img[0])
                    img_norm = np.empty((4, args.img_w, args.img_h))
                    img_norm[:3, :, :] = img
                    img_norm[3, :, :] = depth
                else:
                    img_norm = norm_img(img)

                observation = [tactile_obs, img_norm]
                action, hidden_state_1, cell_state_1 = self.select_action(
                    observation, hidden_state_1, cell_state_1)

                # record actions in this epoch
                act_sequence.append(action)

                # Perform action
                delta = action_space.get_action(
                    self.ACTIONS[action])['delta'][:3]
                target_position = np.add(robot.get_gripper_jpos()[:3],
                                         np.array(delta))
                target_pose = np.hstack(
                    (target_position, robot.get_gripper_jpos()[3:]))
                robot.move_joint(target_pose,
                                 True,
                                 self.gripping_force,
                                 hap_sample=args.hap_sample)

                # Get reward
                done, num = robot.update_tendons()
                failure = robot.check_slippage()
                if num > broken_so_far:
                    reward = num - broken_so_far
                    broken_so_far = num
                else:
                    reward = 0

                # Observe new state
                tactile_obs_space.update(
                    robot.get_gripper_jpos(),  # 6
                    robot.get_all_touch_buffer(args.hap_sample))  # 30x12

                # Set max number of iterations
                if t >= self.max_iter:
                    done = True

                # Check if done
                if not done and not failure:
                    next_tactile_obs = tactile_obs_space.get_state()
                    normalizer.observe(next_tactile_obs)
                    next_tactile_obs = normalizer.normalize(next_tactile_obs)
                    # Get image and normalize it
                    next_img = robot.get_img(args.img_w, args.img_h, 'c1',
                                             args.depth)
                    if args.depth:
                        next_depth = norm_depth(next_img[1])
                        next_img = norm_img(next_img[0])
                        next_img_norm = np.empty((4, args.img_w, args.img_h))
                        next_img_norm[:3, :, :] = next_img
                        next_img_norm[3, :, :] = next_depth
                    else:
                        next_img_norm = norm_img(next_img)
                    next_state = [next_tactile_obs, next_img_norm]
                else:
                    next_state = None

                # Push new Transition into memory
                localMemory.append(
                    Transition(observation, action, next_state, reward))

                # Optimize the model
                if self.steps_done % 10 == 0:
                    self.optimize()

                # If we are done, reset the model
                if done or failure:
                    self.memory.push(localMemory)
                    if failure:
                        print_variables['durations'].append(self.max_iter)
                    else:
                        print_variables['durations'].append(t)
                    print_variables['rewards'].append(broken_so_far)
                    plot_variables(self.figure, print_variables,
                                   "Training POMDP")
                    print("Model parameters: {}".format(velcro_params))
                    print("Actions in this epoch are: {}".format(act_sequence))
                    print("Epoch {} took {}s, total number broken: {}\n\n".
                          format(ii,
                                 time.time() - start_time, broken_so_far))

                    break

            # Save checkpoints every few iterations
            if ii % args.save_freq == 0:
                save_path = os.path.join(
                    args.output_dir, 'checkpoint_model_' + str(ii) + '.pth')
                torch.save(
                    {
                        'epochs': ii,
                        'steps_done': self.steps_done,
                        'conv_net_1': self.conv_net_1.state_dict(),
                        'conv_net_2': self.conv_net_2.state_dict(),
                        'policy_net_1': self.policy_net_1.state_dict(),
                        'policy_net_2': self.policy_net_2.state_dict(),
                        'conv_optimizer_1': self.conv_optimizer_1.state_dict(),
                        'conv_optimizer_2': self.conv_optimizer_2.state_dict(),
                        'policy_optimizer_1':
                        self.policy_optimizer_1.state_dict(),
                        'policy_optimizer_2':
                        self.policy_optimizer_2.state_dict(),
                    }, save_path)

        # Save normalizer state for inference
        normalizer.save_state(
            os.path.join(args.output_dir, 'normalizer_state.pickle'))

        self.memory.save_memory(os.path.join(args.output_dir, 'memory.pickle'))

        if args.savefig_path:
            now = dt.datetime.now()
            self.figure[0].savefig(
                args.savefig_path +
                '{}_{}_{}.png'.format(now.month, now.day, now.hour),
                format='png')

        print('Training done')
        plt.show()
        return print_variables


class POMDP:
    def __init__(self, args):
        self.args = args
        self.ACTIONS = ['left', 'right', 'forward', 'backward', 'up',
                        'down']  # 'open', 'close']
        self.P_START = 0.999
        self.P_END = 0.05
        self.P_DECAY = 500
        self.max_iter = args.max_iter
        self.gripping_force = args.grip_force
        self.break_threshold = args.break_thresh

        # Prepare the drawing figure
        fig, (ax1, ax2) = plt.subplots(1, 2)
        self.figure = (fig, ax1, ax2)

    # Function to select an action from our policy or a random one
    def select_action(self, observation, hidden_state, cell_state):
        args = self.args
        sample = random.random()
        p_threshold = self.P_END + (self.P_START - self.P_END) * math.exp(
            -1. * self.steps_done / self.P_DECAY)

        with torch.no_grad():
            # Run the tactile net and policy net, then pick the action with the
            # largest predicted Q-value (greedy); with probability p_threshold a
            # random action is returned instead (epsilon-greedy exploration).
            self.tactile_net_1.eval()
            self.policy_net_1.eval()
            torch_obs = torch.from_numpy(observation).float().to(
                args.device).unsqueeze(0)
            h_tac, c_tac = self.tactile_net_1.init_hidden_states(args.device)
            tactile_ft = self.tactile_net_1.forward(torch_obs,
                                                    hidden_state=h_tac,
                                                    cell_state=c_tac)

            model_out = self.policy_net_1(tactile_ft.unsqueeze(1),
                                          batch_size=1,
                                          time_step=1,
                                          hidden_state=hidden_state,
                                          cell_state=cell_state)
            out = model_out[0]
            hidden_state = model_out[1][0]
            cell_state = model_out[1][1]
            self.tactile_net_1.train()
            self.policy_net_1.train()

            if sample > p_threshold:
                action = int(torch.argmax(out[0]))
                return action, hidden_state, cell_state
            else:
                return random.randrange(0,
                                        args.outdim), hidden_state, cell_state

    def sample_memory(self):
        batch = self.memory.sample(self.args.batch_size, self.args.time_step)
        if not batch:
            return

        current_states, actions, rewards, next_states = [], [], [], []
        for b in batch:
            cs, ac, rw, ns = [], [], [], []
            for element in b:
                cs.append(element.state)
                ac.append(element.action)
                rw.append(element.reward)
                ns.append(element.next_state)
            current_states.append(cs)
            actions.append(ac)
            rewards.append(rw)
            next_states.append(ns)
        return current_states, actions, rewards, next_states

    def extract_ft(self, tactile_net, batch):
        args = self.args
        h_tac, c_tac = tactile_net.init_hidden_states(args.device)
        result = []
        for b in batch:
            obs_sequence = []
            for item in b:
                torch_obs = torch.from_numpy(item).float().to(
                    args.device).unsqueeze(0)
                tactile_ft = tactile_net.forward(torch_obs,
                                                 hidden_state=h_tac,
                                                 cell_state=c_tac)
                obs_sequence.append(tactile_ft)
            torch_obs_sequence = torch.stack(obs_sequence)
            result.append(torch_obs_sequence)
        return torch.stack(result)

    def bootstrap(self, policy_net, tactile_net, memory_subset):
        args = self.args
        if len(self.memory) < (args.batch_size):
            return

        current_states, actions, rewards, next_states = memory_subset

        # padded_current_states, current_lengths = self.pad_batch(current_states)
        # padded_next_states, next_lengths = self.pad_batch(next_states)

        # process observation (tactile_obs, img) to 1d tensor of (tactile_obs, feature)
        # and then stack them to corresponding dimension: (batch_size, time_step, *)
        current_features = self.extract_ft(tactile_net, current_states)
        next_features = self.extract_ft(tactile_net, next_states)

        # convert all to torch tensors
        actions = torch.from_numpy(np.array(actions)).long().to(args.device)
        rewards = torch.from_numpy(np.array(rewards)).float().to(args.device)

        hidden_batch, cell_batch = policy_net.init_hidden_states(
            args.batch_size, args.device)

        Q_s, _ = policy_net.forward(current_features,
                                    batch_size=args.batch_size,
                                    time_step=args.time_step,
                                    hidden_state=hidden_batch,
                                    cell_state=cell_batch)
        Q_s_a = Q_s.gather(dim=1,
                           index=actions[:, args.time_step -
                                         1].unsqueeze(dim=1)).squeeze(dim=1)

        Q_next, _ = policy_net.forward(next_features,
                                       batch_size=args.batch_size,
                                       time_step=args.time_step,
                                       hidden_state=hidden_batch,
                                       cell_state=cell_batch)
        Q_next_max = Q_next.detach().max(dim=1)[0]
        return Q_s_a, Q_next_max

    def optimize(self):
        args = self.args
        if len(self.memory) < (args.batch_size):
            return

        memory_subset = self.sample_memory()
        _, _, rewards, _ = memory_subset
        rewards = torch.from_numpy(np.array(rewards)).float().to(args.device)

        Q_s_a_1, Q_next_max_1 = self.bootstrap(self.policy_net_1,
                                               self.tactile_net_1,
                                               memory_subset)
        Q_s_a_2, Q_next_max_2 = self.bootstrap(self.policy_net_2,
                                               self.tactile_net_2,
                                               memory_subset)

        Q_next_max = torch.min(Q_next_max_1, Q_next_max_2)

        # Compute the expected Q values
        target_values = rewards[:,
                                args.time_step - 1] + (args.gamma * Q_next_max)

        # Compute Huber loss
        loss_1 = F.smooth_l1_loss(Q_s_a_1, target_values)
        loss_2 = F.smooth_l1_loss(Q_s_a_2, target_values)

        # Optimize the model
        self.policy_optimizer_1.zero_grad()
        self.policy_optimizer_2.zero_grad()
        self.tactile_optimizer_1.zero_grad()
        self.tactile_optimizer_2.zero_grad()
        loss_1.backward()
        loss_2.backward()
        for param in self.policy_net_1.parameters():
            param.grad.data.clamp_(-1, 1)
        for param in self.policy_net_2.parameters():
            param.grad.data.clamp_(-1, 1)
        for param in self.tactile_net_1.parameters():
            param.grad.data.clamp_(-1, 1)
        for param in self.tactile_net_2.parameters():
            param.grad.data.clamp_(-1, 1)

        self.policy_optimizer_1.step()
        self.policy_optimizer_2.step()
        self.tactile_optimizer_1.step()
        self.tactile_optimizer_2.step()

    def train_POMDP(self):
        args = self.args
        ROOT_DIR = os.path.dirname(os.path.dirname(
            os.path.abspath(__file__)))  # corl2019
        PARENT_DIR = os.path.dirname(ROOT_DIR)  # research
        # Create the output directory if it does not exist
        output_dir = os.path.join(PARENT_DIR, 'multistep_pomdp',
                                  args.output_dir)
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        # write args to file
        with open(os.path.join(output_dir, 'args.txt'), 'w+') as f:
            json.dump(args.__dict__, f, indent=2)

        # Create our policy net and a target net
        self.policy_net_1 = DRQN(args.ftdim, args.outdim).to(args.device)
        self.policy_net_2 = DRQN(args.ftdim, args.outdim).to(args.device)
        if args.position:
            self.tactile_net_1 = TactileNet(args.indim - 6,
                                            args.ftdim).to(args.device)
            self.tactile_net_2 = TactileNet(args.indim - 6,
                                            args.ftdim).to(args.device)
        elif args.force:
            self.tactile_net_1 = TactileNet(args.indim - 390,
                                            args.ftdim).to(args.device)
            self.tactile_net_2 = TactileNet(args.indim - 390,
                                            args.ftdim).to(args.device)
        else:
            self.tactile_net_1 = TactileNet(args.indim,
                                            args.ftdim).to(args.device)
            self.tactile_net_2 = TactileNet(args.indim,
                                            args.ftdim).to(args.device)

        # Set up the optimizer
        self.policy_optimizer_1 = optim.RMSprop(self.policy_net_1.parameters(),
                                                lr=args.lr)
        self.policy_optimizer_2 = optim.RMSprop(self.policy_net_2.parameters(),
                                                lr=args.lr)
        self.tactile_optimizer_1 = optim.RMSprop(
            self.tactile_net_1.parameters(), lr=args.lr)
        self.tactile_optimizer_2 = optim.RMSprop(
            self.tactile_net_2.parameters(), lr=args.lr)
        self.memory = RecurrentMemory(800)
        self.steps_done = 0

        # Setup the state normalizer
        normalizer = Multimodal_Normalizer(num_inputs=args.indim,
                                           device=args.device)

        print_variables = {'durations': [], 'rewards': [], 'loss': []}
        start_episode = 0
        if args.weight_policy:
            if os.path.exists(args.weight_policy):
                checkpoint = torch.load(args.weight_policy)
                self.policy_net_1.load_state_dict(checkpoint['policy_net_1'])
                self.policy_net_2.load_state_dict(checkpoint['policy_net_2'])
                self.policy_optimizer_1.load_state_dict(
                    checkpoint['policy_optimizer_1'])
                self.policy_optimizer_2.load_state_dict(
                    checkpoint['policy_optimizer_2'])
                start_episode = checkpoint['epochs']
                self.steps_done = checkpoint['steps_done']
                with open(
                        os.path.join(os.path.dirname(args.weight_policy),
                                     'results_pomdp.pkl'), 'rb') as file:
                    plot_dict = pickle.load(file)
                    print_variables['durations'] = plot_dict['durations']
                    print_variables['rewards'] = plot_dict['rewards']

        if args.normalizer_file:
            if os.path.exists(args.normalizer_file):
                normalizer.restore_state(args.normalizer_file)

        if args.memory:
            if os.path.exists(args.memory):
                self.memory.load(args.memory)

        if args.weight_tactile:
            checkpoint = torch.load(args.weight_tactile)
            self.tactile_net_1.load_state_dict(checkpoint['tactile_net_1'])
            self.tactile_optimizer_1.load_state_dict(
                checkpoint['tactile_optimizer_1'])
            self.tactile_net_2.load_state_dict(checkpoint['tactile_net_2'])
            self.tactile_optimizer_2.load_state_dict(
                checkpoint['tactile_optimizer_2'])

        action_space = ActionSpace(dp=0.06, df=10)

        # Create robot, reset simulation and grasp handle
        model = load_model_from_path(args.model_path)
        sim = MjSim(model)
        sim_param = SimParameter(sim)
        sim.step()
        if args.render:
            viewer = MjViewer(sim)
        else:
            viewer = None

        robot = RobotSim(sim, viewer, sim_param, args.render,
                         self.break_threshold)

        tactile_obs_space = TactileObs(
            robot.get_gripper_xpos(),  # 24
            robot.get_all_touch_buffer(args.hap_sample))  # 30 x 6

        # Main training loop
        for ii in range(start_episode, args.epochs):
            self.steps_done += 1
            start_time = time.time()
            act_sequence = []
            act_length = []
            velcro_params = init_model(robot.mj_sim)
            robot.reset_simulation()
            ret = robot.grasp_handle()
            if not ret:
                continue

            # Local memory for current episode
            localMemory = []

            # Get current observation
            hidden_state_1, cell_state_1 = self.policy_net_1.init_hidden_states(
                batch_size=1, device=args.device)
            hidden_state_2, cell_state_2 = self.policy_net_2.init_hidden_states(
                batch_size=1, device=args.device)

            broken_so_far = 0

            # pick a random action initially
            action = random.randrange(0, 5)
            current_state = None
            next_state = None

            t = 0

            while t < args.max_iter:
                if not args.quiet and t == 0:
                    print("Running training episode: {}".format(ii, t))

                if args.position:
                    multistep_obs = np.empty((0, args.indim - 6))
                elif args.force:
                    multistep_obs = np.empty((0, args.indim - 390))
                else:
                    multistep_obs = np.empty((0, args.indim))

                prev_action = action

                for k in range(args.len_ub):
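                    # Repeat the chosen primitive action for up to len_ub steps,
                    # stacking the tactile observations into one multi-step input
                    # for the tactile net; break early if the gripper barely moves.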
                    # Observe tactile features and stack them
                    tactile_obs = tactile_obs_space.get_state()
                    normalizer.observe(tactile_obs)
                    tactile_obs = normalizer.normalize(tactile_obs)

                    if args.position:
                        tactile_obs = tactile_obs[6:]
                    elif args.force:
                        tactile_obs = tactile_obs[:6]

                    multistep_obs = np.vstack((multistep_obs, tactile_obs))

                    # current jpos
                    current_pos = robot.get_gripper_jpos()[:3]

                    # Perform action
                    delta = action_space.get_action(
                        self.ACTIONS[action])['delta'][:3]
                    target_position = np.add(robot.get_gripper_jpos()[:3],
                                             np.array(delta))
                    target_pose = np.hstack(
                        (target_position, robot.get_gripper_jpos()[3:]))
                    robot.move_joint(target_pose,
                                     True,
                                     self.gripping_force,
                                     hap_sample=args.hap_sample)

                    # Observe new state
                    tactile_obs_space.update(
                        robot.get_gripper_xpos(),  # 24
                        robot.get_all_touch_buffer(args.hap_sample))  # 30x6

                    displacement = la.norm(robot.get_gripper_jpos()[:3] -
                                           current_pos)
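                    # Stop repeating the action if the gripper moved less than
                    # 70% of the commanded step (dp = 0.06), i.e. it is likely
                    # blocked or stuck.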

                    if displacement / 0.06 < 0.7:
                        break

                # input stiched multi-step tactile observation into tactile-net to generate tactile feature
                action, hidden_state_1, cell_state_1 = self.select_action(
                    multistep_obs, hidden_state_1, cell_state_1)

                if t == 0:
                    next_state = multistep_obs.copy()
                else:
                    current_state = next_state.copy()
                    next_state = multistep_obs.copy()

                # record actions in this epoch
                act_sequence.append(prev_action)
                act_length.append(k)

                # Get reward
                done, num = robot.update_tendons()
                failure = robot.check_slippage()
                if num > broken_so_far:
                    reward = num - broken_so_far
                    broken_so_far = num
                else:
                    if failure:
                        reward = -20
                    else:
                        reward = 0

                t += k + 1
                # Set max number of iterations
                if t >= self.max_iter:
                    done = True

                if done or failure:
                    next_state = None

                # Push new Transition into memory
                if t > k + 1:
                    localMemory.append(
                        Transition(current_state, prev_action, next_state,
                                   reward))

                # Optimize the model
                if self.steps_done % 10 == 0:
                    self.optimize()

                # If we are done, reset the model
                if done or failure:
                    self.memory.push(localMemory)
                    if failure:
                        print_variables['durations'].append(self.max_iter)
                    else:
                        print_variables['durations'].append(t)
                    print_variables['rewards'].append(broken_so_far)
                    plot_variables(self.figure, print_variables,
                                   "Training POMDP")
                    print("Model parameters: {}".format(velcro_params))
                    print(
                        "{} actions in this epoch: {} \n Action lengths: {}"
                        .format(len(act_sequence), act_sequence, act_length))
                    print("Epoch {} took {}s, total number broken: {}\n\n".
                          format(ii,
                                 time.time() - start_time, broken_so_far))

                    break

            # Save checkpoints every few iterations
            if ii % args.save_freq == 0:
                save_path = os.path.join(output_dir,
                                         'policy_' + str(ii) + '.pth')
                torch.save(
                    {
                        'epochs': ii,
                        'steps_done': self.steps_done,
                        'policy_net_1': self.policy_net_1.state_dict(),
                        'policy_net_2': self.policy_net_2.state_dict(),
                        'policy_optimizer_1':
                        self.policy_optimizer_1.state_dict(),
                        'policy_optimizer_2':
                        self.policy_optimizer_2.state_dict(),
                    }, save_path)
                save_path = os.path.join(output_dir,
                                         'tactile_' + str(ii) + '.pth')
                torch.save(
                    {
                        'tactile_net_1':
                        self.tactile_net_1.state_dict(),
                        'tactile_net_2':
                        self.tactile_net_2.state_dict(),
                        'tactile_optimizer_1':
                        self.tactile_optimizer_1.state_dict(),
                        'tactile_optimizer_2':
                        self.tactile_optimizer_2.state_dict(),
                    }, save_path)

                write_results(os.path.join(output_dir, 'results_pomdp.pkl'),
                              print_variables)

                self.memory.save_memory(
                    os.path.join(output_dir, 'memory.pickle'))

        if args.savefig_path:
            now = dt.datetime.now()
            self.figure[0].savefig(
                args.savefig_path +
                '{}_{}_{}.png'.format(now.month, now.day, now.hour),
                format='png')

        print('Training done')
        plt.show()
        return print_variables