    def train_POMDP(self):
        args = self.args
        ROOT_DIR = os.path.dirname(os.path.dirname(
            os.path.abspath(__file__)))  # corl2019
        PARENT_DIR = os.path.dirname(ROOT_DIR)  # research
        # Create the output directory if it does not exist
        output_dir = os.path.join(PARENT_DIR, 'multistep_pomdp',
                                  args.output_dir)
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        # write args to file
        with open(os.path.join(output_dir, 'args.txt'), 'w+') as f:
            json.dump(args.__dict__, f, indent=2)

        # Create the policy networks
        self.policy_net_1 = DRQN(args.indim, args.outdim).to(args.device)
        self.policy_net_2 = DRQN(args.indim, args.outdim).to(args.device)
        self.policy_net_3 = DRQN(args.indim, args.outdim).to(args.device)

        # Set up the optimizers
        self.optimizer_1 = optim.RMSprop(self.policy_net_1.parameters())
        self.optimizer_2 = optim.RMSprop(self.policy_net_2.parameters())
        self.optimizer_3 = optim.RMSprop(self.policy_net_3.parameters())
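        # Recurrent replay memory; based on how whole episodes are pushed below,
        # the capacity of 800 presumably counts stored episodes rather than steps.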
        self.memory = RecurrentMemory(800)
        self.steps_done = 0

        # Set up the state normalizer
        normalizer = Multimodal_Normalizer(num_inputs=args.indim,
                                           device=args.device)

        print_variables = {'durations': [], 'rewards': [], 'loss': []}
        start_episode = 0
        if args.checkpoint_file:
            if os.path.exists(args.checkpoint_file):
                checkpoint = torch.load(args.checkpoint_file)
                self.policy_net_1.load_state_dict(checkpoint['policy_net_1'])
                self.policy_net_2.load_state_dict(checkpoint['policy_net_2'])
                self.policy_net_3.load_state_dict(checkpoint['policy_net_3'])
                self.optimizer_1.load_state_dict(checkpoint['optimizer_1'])
                self.optimizer_2.load_state_dict(checkpoint['optimizer_2'])
                self.optimizer_3.load_state_dict(checkpoint['optimizer_3'])
                start_episode = checkpoint['epochs']
                self.steps_done = checkpoint['steps_done']
                with open(
                        os.path.join(os.path.dirname(args.checkpoint_file),
                                     'results_pomdp.pkl'), 'rb') as file:
                    plot_dict = pickle.load(file)
                    print_variables['durations'] = plot_dict['durations']
                    print_variables['rewards'] = plot_dict['rewards']

        if args.normalizer_file:
            if os.path.exists(args.normalizer_file):
                normalizer.restore_state(args.normalizer_file)

        if args.memory:
            if os.path.exists(args.memory):
                self.memory.load(args.memory)

        action_space = ActionSpace(dp=0.06, df=10)
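        # Discrete action set; dp is the per-action position step (0.06) and df
        # the force increment (10), as configured for ActionSpace here.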

        # Create robot, reset simulation and grasp handle
        model = load_model_from_path(args.model_path)
        sim = MjSim(model)
        sim_param = SimParameter(sim)
        sim.step()
        if args.render:
            viewer = MjViewer(sim)
        else:
            viewer = None

        robot = RobotSim(sim, viewer, sim_param, args.render,
                         self.break_threshold)

        # Main training loop
        for ii in range(start_episode, args.epochs):
            start_time = time.time()
            self.steps_done += 1
            act_sequence = []
            if args.sim:
                sim_params = init_model(robot.mj_sim)
                robot.reset_simulation()
                ret = robot.grasp_handle()
                if not ret:
                    continue

                # Local memory for current episode
                localMemory = []

                # Get current observation
                hidden_state_1, cell_state_1 = self.policy_net_1.init_hidden_states(
                    batch_size=1, device=args.device)
                hidden_state_2, cell_state_2 = self.policy_net_2.init_hidden_states(
                    batch_size=1, device=args.device)
                hidden_state_3, cell_state_3 = self.policy_net_3.init_hidden_states(
                    batch_size=1, device=args.device)
                observation_space = TactileObs(
                    robot.get_gripper_xpos(),  # 24
                    robot.get_all_touch_buffer(args.hap_sample))  # 30 x 7
                broken_so_far = 0

            for t in count():
                if not args.quiet and t % 50 == 0:
                    print("Running training episode: {}, iteration: {}".format(
                        ii, t))

                # Select action
                observation = observation_space.get_state()
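                # Modality ablations: --position drops the first 6 (gripper pose)
                # entries, --shear masks indices 6:166, and --force keeps only the
                # first 166 entries (sensor layout assumed from this slicing).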
                if args.position:
                    observation = observation[6:]
                if args.shear:
                    indices = np.ones(len(observation), dtype=bool)
                    indices[6:166] = False
                    observation = observation[indices]
                if args.force:
                    observation = observation[:166]
                normalizer.observe(observation)
                observation = normalizer.normalize(observation)
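                # select_action is assumed to run an epsilon-greedy recurrent policy,
                # threading the LSTM hidden/cell state of policy_net_1 through time.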
                action, hidden_state_1, cell_state_1 = self.select_action(
                    observation, hidden_state_1, cell_state_1)

                # record actions in this epoch
                act_sequence.append(action)

                # Perform action
                delta = action_space.get_action(
                    self.ACTIONS[action])['delta'][:3]
                target_position = np.add(robot.get_gripper_jpos()[:3],
                                         np.array(delta))
                target_pose = np.hstack(
                    (target_position, robot.get_gripper_jpos()[3:]))

                if args.sim:
                    robot.move_joint(target_pose,
                                     True,
                                     self.gripping_force,
                                     hap_sample=args.hap_sample)

                    # Get reward
                    done, num = robot.update_tendons()
                    failure = robot.check_slippage()
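                    # Reward equals the number of tendons newly broken this step.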
                    if num > broken_so_far:
                        reward = num - broken_so_far
                        broken_so_far = num
                    else:
                        reward = 0

                    # # Add a movement reward
                    # reward -= 0.05 * np.linalg.norm(target_position - robot.get_gripper_jpos()[:3]) / np.linalg.norm(delta)

                    # Observe new state
                    observation_space.update(
                        robot.get_gripper_xpos(),  # 24
                        robot.get_all_touch_buffer(args.hap_sample))  # 30x7

                # Terminate when the iteration limit is reached
                if t >= self.max_iter:
                    done = True

                # Check if done
                if not done and not failure:
                    next_state = observation_space.get_state()
                    if args.position:
                        next_state = next_state[6:]
                    if args.shear:
                        indices = np.ones(len(next_state), dtype=bool)
                        indices[6:166] = False
                        next_state = next_state[indices]
                    if args.force:
                        next_state = next_state[:166]
                    normalizer.observe(next_state)
                    next_state = normalizer.normalize(next_state)
                else:
                    next_state = None

                # Push new Transition into memory
                localMemory.append(
                    Transition(observation, action, next_state, reward))

                # Optimize the model
                if t % 10 == 0:
                    loss = self.optimize_model()
                    # if loss:
                    #     print_variables['loss'].append(loss.item())

                # If we are done, reset the model
                if done or failure:
                    self.memory.push(localMemory)
                    if failure:
                        print_variables['durations'].append(self.max_iter)
                    else:
                        print_variables['durations'].append(t)
                    print_variables['rewards'].append(broken_so_far)
                    plot_variables(self.figure, print_variables,
                                   "Training POMDP")
                    print("Model parameters: {}".format(sim_params))
                    print("Actions in this epoch are: {}".format(act_sequence))
                    print("Epoch {} took {}s, total number broken: {}\n\n".
                          format(ii,
                                 time.time() - start_time, broken_so_far))

                    break

            # Save a checkpoint every args.save_freq episodes
            if ii % args.save_freq == 0:
                save_path = os.path.join(
                    output_dir, 'checkpoint_model_' + str(ii) + '.pth')
                torch.save(
                    {
                        'epochs': ii,
                        'steps_done': self.steps_done,
                        'policy_net_1': self.policy_net_1.state_dict(),
                        'policy_net_2': self.policy_net_2.state_dict(),
                        'policy_net_3': self.policy_net_3.state_dict(),
                        'optimizer_1': self.optimizer_1.state_dict(),
                        'optimizer_2': self.optimizer_2.state_dict(),
                        'optimizer_3': self.optimizer_3.state_dict(),
                    }, save_path)

            self.memory.save_memory(os.path.join(output_dir, 'memory.pickle'))

        if args.savefig_path:
            now = dt.datetime.now()
            self.figure[0].savefig(
                args.savefig_path +
                '{}_{}_{}.png'.format(now.month, now.day, now.hour),
                format='png')

        print('Training done')
        plt.show()
        return print_variables


def main(args):
    if not os.path.isdir(args.result_dir):
        os.makedirs(args.result_dir)
    parent = os.path.dirname(os.path.abspath(__file__))
    # Load the test case parameters
    test_file = os.path.join(parent, 'tests/test_xmls/temp_1_{}.pickle'.format(args.case))
    with open(test_file, 'rb') as f:
        params = pickle.load(f)
    # params = params[:6]
    if args.shuffle:
        random.shuffle(params)

    num_test = len(params)
    print('                    ++++++++++++++++++++++++++')
    print('                    +++ Now running case {} +++'.format(args.case))
    print('                    ++++++++++++++++++++++++++\n\n')

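    # Note: only the first policy head ('policy_net_1') from the checkpoint is
    # used for evaluation here.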
    policy_net = DRQN(args.indim, args.outdim)
    policy_net.load_state_dict(torch.load(args.weight_path)['policy_net_1'])
    policy_net.eval()

    # Set up the state normalizer
    normalizer = Multimodal_Normalizer(num_inputs=args.indim, device=args.device)
    if args.normalizer_file:
        if os.path.exists(args.normalizer_file):
            normalizer.restore_state(args.normalizer_file)

    # Create robot, reset simulation and grasp handle
    model = load_model_from_path(args.model_path)
    sim = MjSim(model)
    sim_param = SimParameter(sim)
    sim.step()
    if args.render:
        viewer = MjViewer(sim)
    else:
        viewer = None

    robot = RobotSim(sim, viewer, sim_param, args.render, args.break_thresh)

    tactile_obs_space = TactileObs(robot.get_gripper_xpos(),            # 24
                         robot.get_all_touch_buffer(args.hap_sample))   # 30 x 6

    performance = {'time': [], 'success': [], 'num_broken': [],
                   'tendon_hist': [0, 0, 0, 0, 0], 'collision': [],
                   'action_hist': [0, 0, 0, 0, 0, 0]}
    
    for i in range(num_test):
        velcro_params = params[i]
        geom, origin_offset, euler, radius = velcro_params
        print('\n\nTest {} Velcro parameters are: {}, {}, {}, {}'.format(i, geom, origin_offset, euler, radius))
        change_sim(robot.mj_sim, geom, origin_offset, euler, radius)
        robot.reset_simulation()
        ret = robot.grasp_handle()
        performance = test_network(args, policy_net, normalizer, robot, tactile_obs_space, performance)
        print('Success: {}, time: {}, num_broken: {}, collision:{} '.format(
                performance['success'][-1], performance['time'][-1], performance['num_broken'][-1], performance['collision'][-1]))

    print('Finished opening velcro with haptics test \n')
    success = np.array(performance['success'])
    times = np.array(performance['time'])
    print('Successfully opened the velcro in: {}% of cases'.format(100 * np.sum(success) / len(performance['success'])))
    print('Average time to open: {}'.format(np.average(times[success > 0])))
    print('Action histogram for the test is: {}'.format(performance['action_hist']))

    out_fname = 'case{}.txt'.format(args.case)
    with open(os.path.join(args.result_dir, out_fname), 'w+') as f:
        f.write('Time: {}\n'.format(performance['time']))
        f.write('Success: {}\n'.format(performance['success']))
        f.write('Successfully opened the velcro in: {}% of cases\n'.format(100 * np.sum(success) / len(performance['success'])))
        f.write('Average time to open: {}\n'.format(np.average(times[success > 0])))
        f.write('Num_broken: {}\n'.format(performance['num_broken']))
        f.write('Tendon histogram: {}\n'.format(performance['tendon_hist']))
        f.write('collision: {}\n'.format(performance['collision']))
        # f.write('high_success: {} low_success: {} '.format(high_success, low_success))
    def train_POMDP(self):
        args = self.args
        ROOT_DIR = os.path.dirname(os.path.dirname(
            os.path.abspath(__file__)))  # corl2019
        PARENT_DIR = os.path.dirname(ROOT_DIR)  # research
        # Create the output directory if it does not exist
        output_dir = os.path.join(PARENT_DIR, 'multistep_pomdp',
                                  args.output_dir)
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        # write args to file
        with open(os.path.join(output_dir, 'args.txt'), 'w+') as f:
            json.dump(args.__dict__, f, indent=2)

        # Create the policy networks and tactile feature networks
        self.policy_net_1 = DRQN(args.ftdim, args.outdim).to(args.device)
        self.policy_net_2 = DRQN(args.ftdim, args.outdim).to(args.device)
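        # TactileNet input size mirrors the masking in the training loop below,
        # assuming args.indim is composed of 6 gripper-pose entries plus 390
        # touch entries: --position drops the pose part, --force drops the touch part.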
        if args.position:
            self.tactile_net_1 = TactileNet(args.indim - 6,
                                            args.ftdim).to(args.device)
            self.tactile_net_2 = TactileNet(args.indim - 6,
                                            args.ftdim).to(args.device)
        elif args.force:
            self.tactile_net_1 = TactileNet(args.indim - 390,
                                            args.ftdim).to(args.device)
            self.tactile_net_2 = TactileNet(args.indim - 390,
                                            args.ftdim).to(args.device)
        else:
            self.tactile_net_1 = TactileNet(args.indim,
                                            args.ftdim).to(args.device)
            self.tactile_net_2 = TactileNet(args.indim,
                                            args.ftdim).to(args.device)

        # Set up the optimizers
        self.policy_optimizer_1 = optim.RMSprop(self.policy_net_1.parameters(),
                                                lr=args.lr)
        self.policy_optimizer_2 = optim.RMSprop(self.policy_net_2.parameters(),
                                                lr=args.lr)
        self.tactile_optimizer_1 = optim.RMSprop(
            self.tactile_net_1.parameters(), lr=args.lr)
        self.tactile_optimizer_2 = optim.RMSprop(
            self.tactile_net_2.parameters(), lr=args.lr)
        self.memory = RecurrentMemory(800)
        self.steps_done = 0

        # Set up the state normalizer
        normalizer = Multimodal_Normalizer(num_inputs=args.indim,
                                           device=args.device)

        print_variables = {'durations': [], 'rewards': [], 'loss': []}
        start_episode = 0
        if args.weight_policy:
            if os.path.exists(args.weight_policy):
                checkpoint = torch.load(args.weight_policy)
                self.policy_net_1.load_state_dict(checkpoint['policy_net_1'])
                self.policy_net_2.load_state_dict(checkpoint['policy_net_2'])
                self.policy_optimizer_1.load_state_dict(
                    checkpoint['policy_optimizer_1'])
                self.policy_optimizer_2.load_state_dict(
                    checkpoint['policy_optimizer_2'])
                start_episode = checkpoint['epochs']
                self.steps_done = checkpoint['steps_done']
                with open(
                        os.path.join(os.path.dirname(args.weight_policy),
                                     'results_pomdp.pkl'), 'rb') as file:
                    plot_dict = pickle.load(file)
                    print_variables['durations'] = plot_dict['durations']
                    print_variables['rewards'] = plot_dict['rewards']

        if args.normalizer_file:
            if os.path.exists(args.normalizer_file):
                normalizer.restore_state(args.normalizer_file)

        if args.memory:
            if os.path.exists(args.memory):
                self.memory.load(args.memory)

        if args.weight_tactile:
            checkpoint = torch.load(args.weight_tactile)
            self.tactile_net_1.load_state_dict(checkpoint['tactile_net_1'])
            self.tactile_optimizer_1.load_state_dict(
                checkpoint['tactile_optimizer_1'])
            self.tactile_net_2.load_state_dict(checkpoint['tactile_net_2'])
            self.tactile_optimizer_2.load_state_dict(
                checkpoint['tactile_optimizer_2'])

        action_space = ActionSpace(dp=0.06, df=10)

        # Create robot, reset simulation and grasp handle
        model = load_model_from_path(args.model_path)
        sim = MjSim(model)
        sim_param = SimParameter(sim)
        sim.step()
        if args.render:
            viewer = MjViewer(sim)
        else:
            viewer = None

        robot = RobotSim(sim, viewer, sim_param, args.render,
                         self.break_threshold)

        tactile_obs_space = TactileObs(
            robot.get_gripper_xpos(),  # 24
            robot.get_all_touch_buffer(args.hap_sample))  # 30 x 6

        # Main training loop
        for ii in range(start_episode, args.epochs):
            self.steps_done += 1
            start_time = time.time()
            act_sequence = []
            act_length = []
            velcro_params = init_model(robot.mj_sim)
            robot.reset_simulation()
            ret = robot.grasp_handle()
            if not ret:
                continue

            # Local memory for current episode
            localMemory = []

            # Get current observation
            hidden_state_1, cell_state_1 = self.policy_net_1.init_hidden_states(
                batch_size=1, device=args.device)
            hidden_state_2, cell_state_2 = self.policy_net_2.init_hidden_states(
                batch_size=1, device=args.device)

            broken_so_far = 0

            # pick a random action initially
            action = random.randrange(0, 5)
            current_state = None
            next_state = None

            t = 0

            while t < args.max_iter:
                if not args.quiet and t == 0:
                    print("Running training episode: {}".format(ii, t))

                if args.position:
                    multistep_obs = np.empty((0, args.indim - 6))
                elif args.force:
                    multistep_obs = np.empty((0, args.indim - 390))
                else:
                    multistep_obs = np.empty((0, args.indim))

                prev_action = action

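                # Execute the previously selected action for up to args.len_ub
                # low-level steps, stacking the tactile observation from each step;
                # the stacked sequence forms one macro-step for the policy.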
                for k in range(args.len_ub):
                    # Observe tactile features and stack them
                    tactile_obs = tactile_obs_space.get_state()
                    normalizer.observe(tactile_obs)
                    tactile_obs = normalizer.normalize(tactile_obs)

                    if args.position:
                        tactile_obs = tactile_obs[6:]
                    elif args.force:
                        tactile_obs = tactile_obs[:6]

                    multistep_obs = np.vstack((multistep_obs, tactile_obs))

                    # current jpos
                    current_pos = robot.get_gripper_jpos()[:3]

                    # Perform action
                    delta = action_space.get_action(
                        self.ACTIONS[action])['delta'][:3]
                    target_position = np.add(robot.get_gripper_jpos()[:3],
                                             np.array(delta))
                    target_pose = np.hstack(
                        (target_position, robot.get_gripper_jpos()[3:]))
                    robot.move_joint(target_pose,
                                     True,
                                     self.gripping_force,
                                     hap_sample=args.hap_sample)

                    # Observe new state
                    tactile_obs_space.update(
                        robot.get_gripper_xpos(),  # 24
                        robot.get_all_touch_buffer(args.hap_sample))  # 30x6

                    displacement = la.norm(robot.get_gripper_jpos()[:3] -
                                           current_pos)

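                    # Stop repeating early when the gripper covers less than ~70% of
                    # the commanded 0.06 step, which suggests the motion is blocked.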
                    if displacement / 0.06 < 0.7:
                        break

                # Feed the stitched multi-step tactile observation into the tactile net to generate a tactile feature
                action, hidden_state_1, cell_state_1 = self.select_action(
                    multistep_obs, hidden_state_1, cell_state_1)

                if t == 0:
                    next_state = multistep_obs.copy()
                else:
                    current_state = next_state.copy()
                    next_state = multistep_obs.copy()

                # record actions in this epoch
                act_sequence.append(prev_action)
                act_length.append(k)

                # Get reward
                done, num = robot.update_tendons()
                failure = robot.check_slippage()
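                # Reward: number of tendons newly broken during this macro-step,
                # -20 if the gripper slips off the handle, 0 otherwise.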
                if num > broken_so_far:
                    reward = num - broken_so_far
                    broken_so_far = num
                else:
                    if failure:
                        reward = -20
                    else:
                        reward = 0

                t += k + 1
                # Terminate when the iteration limit is reached
                if t >= self.max_iter:
                    done = True

                if done or failure:
                    next_state = None

                # Push new Transition into memory
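                # t > k + 1 means at least one earlier macro-step has completed,
                # so current_state holds the previous stacked observation.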
                if t > k + 1:
                    localMemory.append(
                        Transition(current_state, prev_action, next_state,
                                   reward))

                # Optimize the model
                if self.steps_done % 10 == 0:
                    self.optimize()

                # If we are done, reset the model
                if done or failure:
                    self.memory.push(localMemory)
                    if failure:
                        print_variables['durations'].append(self.max_iter)
                    else:
                        print_variables['durations'].append(t)
                    print_variables['rewards'].append(broken_so_far)
                    plot_variables(self.figure, print_variables,
                                   "Training POMDP")
                    print("Model parameters: {}".format(velcro_params))
                    print(
                        "{} actions in this epoch: {}\nAction lengths: {}".format(
                            len(act_sequence), act_sequence, act_length))
                    print("Epoch {} took {}s, total number broken: {}\n\n".
                          format(ii,
                                 time.time() - start_time, broken_so_far))

                    break

            # Save a checkpoint every args.save_freq episodes
            if ii % args.save_freq == 0:
                save_path = os.path.join(output_dir,
                                         'policy_' + str(ii) + '.pth')
                torch.save(
                    {
                        'epochs': ii,
                        'steps_done': self.steps_done,
                        'policy_net_1': self.policy_net_1.state_dict(),
                        'policy_net_2': self.policy_net_2.state_dict(),
                        'policy_optimizer_1':
                        self.policy_optimizer_1.state_dict(),
                        'policy_optimizer_2':
                        self.policy_optimizer_2.state_dict(),
                    }, save_path)
                save_path = os.path.join(output_dir,
                                         'tactile_' + str(ii) + '.pth')
                torch.save(
                    {
                        'tactile_net_1':
                        self.tactile_net_1.state_dict(),
                        'tactile_net_2':
                        self.tactile_net_2.state_dict(),
                        'tactile_optimizer_1':
                        self.tactile_optimizer_1.state_dict(),
                        'tactile_optimizer_2':
                        self.tactile_optimizer_2.state_dict(),
                    }, save_path)

                write_results(os.path.join(output_dir, 'results_pomdp.pkl'),
                              print_variables)

                self.memory.save_memory(
                    os.path.join(output_dir, 'memory.pickle'))

        if args.savefig_path:
            now = dt.datetime.now()
            self.figure[0].savefig(
                args.savefig_path +
                '{}_{}_{}.png'.format(now.month, now.day, now.hour),
                format='png')

        print('Training done')
        plt.show()
        return print_variables