def main(args):
    if not os.path.isdir(args.result_dir):
        os.makedirs(args.result_dir)

    conv_net = ConvNet(args.outdim, args.depth).to(args.device)
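    # Restore pretrained ConvNet weights if a checkpoint is provided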
    if os.path.exists(args.weight_convnet):
        checkpoint = torch.load(args.weight_convnet)
        conv_net.load_state_dict(checkpoint['conv_net'])

    # Create robot, reset simulation and grasp handle
    model = load_model_from_path(args.model_path)
    sim = MjSim(model)
    sim_param = SimParameter(sim)
    sim.step()
    if args.render:
        viewer = MjViewer(sim)
    else:
        viewer = None
    robot = RobotSim(sim, viewer, sim_param, args.render, args.break_thresh)
    # load all velcro parameters
    # model_dir = os.path.dirname(args.model_path)
    # param_path = os.path.join(model_dir, 'uniform_sample.pkl')
    param_path = '/home/jc/research/corl2019_learningHaptics/tests/test_xmls/case_{}.pickle'.format(
        args.case)
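    # Each test case is a tuple of (geom_type, origin_offset, euler, radius)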
    velcro_params = pickle.load(open(param_path, 'rb'))
    if args.shuffle:
        random.shuffle(velcro_params)

    velcro_util = VelcroUtil(robot, sim_param)
    state_space = Observation(
        robot.get_gripper_jpos(),  # 6
        velcro_util.break_center(),  # 6
        velcro_util.break_norm())
    action_space = ActionSpace(dp=0.06, df=10)
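    # Aggregate test metrics: episode length, success flag, number of broken tendons,
    # and a histogram of broken-tendon ratios in 0.2-wide bins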
    performance = {
        'time': [],
        'success': [],
        'num_broken': [],
        'tendon_hist': [0, 0, 0, 0, 0]
    }

    for n, item in enumerate(velcro_params):
        geom_type, origin_offset, euler, radius = item
        print('\n\nTest {} Velcro parameters are: {}, {}, {}, {}'.format(
            n, geom_type, origin_offset, euler, radius))
        change_sim(robot.mj_sim, geom_type, origin_offset, euler, radius)
        robot.reset_simulation()
        ret = robot.grasp_handle()
        broken_so_far = 0

        for t in range(args.max_iter):
            # Take an image, normalize it (RGB, or RGB-D stacked into 4 channels),
            # and predict the normal directions with the ConvNet
            img = robot.get_img(args.img_w, args.img_h, 'c1', args.depth)
            if args.depth:
                depth = norm_depth(img[1])
                img = norm_img(img[0])
                img_norm = np.empty((4, args.img_w, args.img_h))
                img_norm[:3, :, :] = img
                img_norm[3, :, :] = depth
            else:
                img_norm = norm_img(img)

            torch_img = torch.from_numpy(img_norm).float().to(
                args.device).unsqueeze(0)
            pred = conv_net(torch_img).detach().cpu()
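            # pred packs several 3-vectors; indices 3:6 and 6:9 hold the predicted
            # normal and break-direction vectors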
            fl_norm = pred[0][3:6].numpy()
            break_dir_norm = pred[0][6:9].numpy()
            # normalize these vectors
            fl_norm = fl_norm / la.norm(fl_norm)
            break_dir_norm = break_dir_norm / la.norm(break_dir_norm)

            ################ choose action and get action direction vector ################
            action_direction = args.act_mag * (-0.5 * fl_norm +
                                               0.5 * break_dir_norm)
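            # Pick the discrete action whose candidate direction (a row of action_vec)
            # best aligns with the desired direction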
            action_key = (action_vec @ action_direction).argmax()
            action_direction = action_space.get_action(
                ACTIONS[action_key])['delta'][:3]

            gripper_pose = robot.get_gripper_jpos()[:3]

            # Perform action
            target_position = np.add(robot.get_gripper_jpos()[:3],
                                     action_direction)
            target_pose = np.hstack(
                (target_position, robot.get_gripper_jpos()[3:]))
            robot.move_joint(target_pose, True, 300, hap_sample=30)

            # check tendons and slippage
            done, num = robot.update_tendons()
            failure = robot.check_slippage()
            if num > broken_so_far:
                broken_so_far = num

            # Episode ends when all tendons are broken (done) or the gripper slips (failure)
            if done or failure:
                ratio_broken = float(num) / float(NUM_TENDON)
                if ratio_broken < 0.2:
                    performance['tendon_hist'][0] += 1
                elif 0.2 <= ratio_broken < 0.4:
                    performance['tendon_hist'][1] += 1
                elif 0.4 <= ratio_broken < 0.6:
                    performance['tendon_hist'][2] += 1
                elif 0.6 <= ratio_broken < 0.8:
                    performance['tendon_hist'][3] += 1
                else:
                    performance['tendon_hist'][4] += 1
                performance['num_broken'].append(num)
                performance['success'].append(1 if done else 0)
                performance['time'].append(t + 1)
                break

            if t == args.max_iter - 1:

                ################## exceed max iterations ####################
                performance['success'].append(0)
                performance['time'].append(args.max_iter)
                ratio_broken = float(num) / float(NUM_TENDON)
                performance['num_broken'].append(num)
                if ratio_broken < 0.2:
                    performance['tendon_hist'][0] += 1
                elif 0.2 <= ratio_broken < 0.4:
                    performance['tendon_hist'][1] += 1
                elif 0.4 <= ratio_broken < 0.6:
                    performance['tendon_hist'][2] += 1
                elif 0.6 <= ratio_broken < 0.8:
                    performance['tendon_hist'][3] += 1
                else:
                    performance['tendon_hist'][4] += 1

        # print episode performance
        print('Success: {}, time: {}, num_broken: {} '.format(
            performance['success'][-1], performance['time'][-1],
            performance['num_broken'][-1]))

    print('Finished opening velcro with vision test\n')
    success = np.array(performance['success'])
    time = np.array(performance['time'])
    print('Successfully opened the velcro in: {}% of cases'.format(
        100 * np.sum(success) / len(performance['success'])))
    print('Average time to open: {}'.format(np.average(time[success > 0])))

    out_fname = 'vision_case{}.txt'.format(args.case)
    with open(os.path.join(args.result_dir, out_fname), 'w+') as f:
        f.write('Time: {}\n'.format(performance['time']))
        f.write('Success: {}\n'.format(performance['success']))
        f.write('Successfully opened the velcro in: {}% of cases\n'.format(
            100 * np.sum(success) / len(performance['success'])))
        f.write('Average time to open: {}\n'.format(
            np.average(time[success > 0])))
        f.write('Num_broken: {}\n'.format(performance['num_broken']))
        f.write('Tendon histogram: {}\n'.format(performance['tendon_hist']))
    def train_POMDP(self):
        args = self.args
        # Create the output directory if it does not exist
        if not os.path.isdir(args.output_dir):
            os.makedirs(args.output_dir)

        # Create two recurrent policy networks and two convolutional feature extractors
        self.policy_net_1 = DRQN(args.indim, args.outdim).to(args.device)
        self.policy_net_2 = DRQN(args.indim, args.outdim).to(args.device)
        self.conv_net_1 = ConvNet(args.ftdim, args.depth).to(args.device)
        self.conv_net_2 = ConvNet(args.ftdim, args.depth).to(args.device)

        # Set up the optimizer
        self.policy_optimizer_1 = optim.RMSprop(self.policy_net_1.parameters(),
                                                lr=args.lr)
        self.policy_optimizer_2 = optim.RMSprop(self.policy_net_2.parameters(),
                                                lr=args.lr)
        self.conv_optimizer_1 = optim.RMSprop(self.conv_net_1.parameters(),
                                              lr=1e-5)
        self.conv_optimizer_2 = optim.RMSprop(self.conv_net_2.parameters(),
                                              lr=1e-5)
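        # Episode-level replay memory (capacity 70) for the recurrent Q-networks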
        self.memory = RecurrentMemory(70)
        self.steps_done = 0

        # Setup the state normalizer
        normalizer = Multimodal_Normalizer(num_inputs=args.indim - args.ftdim,
                                           device=args.device)

        print_variables = {'durations': [], 'rewards': [], 'loss': []}
        start_episode = 0
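        # Optionally resume training: restore networks, optimizers, counters, and previous training curves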
        if args.checkpoint_file:
            if os.path.exists(args.checkpoint_file):
                checkpoint = torch.load(args.checkpoint_file)
                self.policy_net_1.load_state_dict(checkpoint['policy_net_1'])
                self.policy_net_2.load_state_dict(checkpoint['policy_net_2'])
                self.conv_net_1.load_state_dict(checkpoint['conv_net_1'])
                self.conv_net_2.load_state_dict(checkpoint['conv_net_2'])
                self.policy_optimizer_1.load_state_dict(
                    checkpoint['policy_optimizer_1'])
                self.policy_optimizer_2.load_state_dict(
                    checkpoint['policy_optimizer_2'])
                self.conv_optimizer_1.load_state_dict(
                    checkpoint['conv_optimizer_1'])
                self.conv_optimizer_2.load_state_dict(
                    checkpoint['conv_optimizer_2'])
                start_episode = checkpoint['epochs']
                self.steps_done = checkpoint['steps_done']
                with open(
                        os.path.join(os.path.dirname(args.checkpoint_file),
                                     'results_pomdp.pkl'), 'rb') as file:
                    plot_dict = pickle.load(file)
                    print_variables['durations'] = plot_dict['durations']
                    print_variables['rewards'] = plot_dict['rewards']

        if args.normalizer_file:
            if os.path.exists(args.normalizer_file):
                normalizer.restore_state(args.normalizer_file)

        if args.memory:
            if os.path.exists(args.memory):
                self.memory.load(args.memory)

        if args.weight_conv:
            checkpoint = torch.load(args.weight_conv)
            self.conv_net_1.load_state_dict(checkpoint['conv_net'])
            self.conv_optimizer_1.load_state_dict(checkpoint['conv_optimizer'])
            self.conv_net_2.load_state_dict(checkpoint['conv_net'])
            self.conv_optimizer_2.load_state_dict(checkpoint['conv_optimizer'])

        action_space = ActionSpace(dp=0.06, df=10)

        # Create robot, reset simulation and grasp handle
        model = load_model_from_path(args.model_path)
        sim = MjSim(model)
        sim_param = SimParameter(sim)
        sim.step()
        if args.render:
            viewer = MjViewer(sim)
        else:
            viewer = None

        robot = RobotSim(sim, viewer, sim_param, args.render,
                         self.break_threshold)
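        # Tactile observation: 6-D gripper joint pose plus a buffer of touch-sensor readings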
        tactile_obs_space = TactileObs(
            robot.get_gripper_jpos(),  # 6
            robot.get_all_touch_buffer(args.hap_sample))  # 30 x 12

        # Main training loop
        for ii in range(start_episode, args.epochs):
            start_time = time.time()
            act_sequence = []
            velcro_params = init_model(robot.mj_sim)
            robot.reset_simulation()
            ret = robot.grasp_handle()
            if not ret:
                continue

            # Local memory for current episode
            localMemory = []

            # Get current observation
            hidden_state_1, cell_state_1 = self.policy_net_1.init_hidden_states(
                batch_size=1, device=args.device)
            hidden_state_2, cell_state_2 = self.policy_net_2.init_hidden_states(
                batch_size=1, device=args.device)

            broken_so_far = 0
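            # Roll out one episode: act with the recurrent policy until all tendons
            # break, the gripper slips, or max_iter is reached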

            for t in count():
                if not args.quiet and t % 50 == 0:
                    print("Running training episode: {}, iteration: {}".format(
                        ii, t))

                # Select action
                tactile_obs = tactile_obs_space.get_state()
                normalizer.observe(tactile_obs)
                tactile_obs = normalizer.normalize(tactile_obs)
                # Get image and normalize it
                img = robot.get_img(args.img_w, args.img_h, 'c1', args.depth)
                if args.depth:
                    depth = norm_depth(img[1])
                    img = norm_img(img[0])
                    img_norm = np.empty((4, args.img_w, args.img_h))
                    img_norm[:3, :, :] = img
                    img_norm[3, :, :] = depth
                else:
                    img_norm = norm_img(img)

                observation = [tactile_obs, img_norm]
                action, hidden_state_1, cell_state_1 = self.select_action(
                    observation, hidden_state_1, cell_state_1)

                # record actions in this epoch
                act_sequence.append(action)

                # Perform action
                delta = action_space.get_action(
                    self.ACTIONS[action])['delta'][:3]
                target_position = np.add(robot.get_gripper_jpos()[:3],
                                         np.array(delta))
                target_pose = np.hstack(
                    (target_position, robot.get_gripper_jpos()[3:]))
                robot.move_joint(target_pose,
                                 True,
                                 self.gripping_force,
                                 hap_sample=args.hap_sample)

                # Reward is the number of tendons newly broken by this action
                done, num = robot.update_tendons()
                failure = robot.check_slippage()
                if num > broken_so_far:
                    reward = num - broken_so_far
                    broken_so_far = num
                else:
                    reward = 0

                # Observe new state
                tactile_obs_space.update(
                    robot.get_gripper_jpos(),  # 6
                    robot.get_all_touch_buffer(args.hap_sample))  # 30x12

                # Set max number of iterations
                if t >= self.max_iter:
                    done = True

                # Check if done
                if not done and not failure:
                    next_tactile_obs = tactile_obs_space.get_state()
                    normalizer.observe(next_tactile_obs)
                    next_tactile_obs = normalizer.normalize(next_tactile_obs)
                    # Get image and normalize it
                    next_img = robot.get_img(args.img_w, args.img_h, 'c1',
                                             args.depth)
                    if args.depth:
                        next_depth = norm_depth(next_img[1])
                        next_img = norm_img(next_img[0])
                        next_img_norm = np.empty((4, args.img_w, args.img_h))
                        next_img_norm[:3, :, :] = next_img
                        next_img_norm[3, :, :] = next_depth
                    else:
                        next_img_norm = norm_img(next_img)
                    next_state = [next_tactile_obs, next_img_norm]
                else:
                    next_state = None

                # Store the transition in this episode's local memory
                # (next_state is None at terminal steps)
                localMemory.append(
                    Transition(observation, action, next_state, reward))

                # Optimize the model
                if self.steps_done % 10 == 0:
                    self.optimize()

                # If we are done, reset the model
                if done or failure:
                    self.memory.push(localMemory)
                    if failure:
                        print_variables['durations'].append(self.max_iter)
                    else:
                        print_variables['durations'].append(t)
                    print_variables['rewards'].append(broken_so_far)
                    plot_variables(self.figure, print_variables,
                                   "Training POMDP")
                    print("Model parameters: {}".format(velcro_params))
                    print("Actions in this epoch are: {}".format(act_sequence))
                    print("Epoch {} took {}s, total number broken: {}\n\n".
                          format(ii,
                                 time.time() - start_time, broken_so_far))

                    break

            # Save a checkpoint every args.save_freq episodes
            if ii % args.save_freq == 0:
                save_path = os.path.join(
                    args.output_dir, 'checkpoint_model_' + str(ii) + '.pth')
                torch.save(
                    {
                        'epochs': ii,
                        'steps_done': self.steps_done,
                        'conv_net_1': self.conv_net_1.state_dict(),
                        'conv_net_2': self.conv_net_2.state_dict(),
                        'policy_net_1': self.policy_net_1.state_dict(),
                        'policy_net_2': self.policy_net_2.state_dict(),
                        'conv_optimizer_1': self.conv_optimizer_1.state_dict(),
                        'conv_optimizer_2': self.conv_optimizer_2.state_dict(),
                        'policy_optimizer_1':
                        self.policy_optimizer_1.state_dict(),
                        'policy_optimizer_2':
                        self.policy_optimizer_2.state_dict(),
                    }, save_path)

        # Save normalizer state for inference
        normalizer.save_state(
            os.path.join(args.output_dir, 'normalizer_state.pickle'))

        self.memory.save_memory(os.path.join(args.output_dir, 'memory.pickle'))

        if args.savefig_path:
            now = dt.datetime.now()
            self.figure[0].savefig(
                args.savefig_path +
                '{}_{}_{}.png'.format(now.month, now.day, now.hour),
                format='png')

        print('Training done')
        plt.show()
        return print_variables