Example #1
    def _run(self, env_kwargs):
        # this is to control the number of CPUs that torch is allowed to use.
        # By default it will use all CPUs, even with GPU acceleration
        th.set_num_threads(1)
        self.model = loadSRLModel(env_kwargs.get("srl_model_path", None), th.cuda.is_available(), self.state_dim,
                                  env_object=None)
        # run until the end of the caller thread
        while True:
            # pop an item, get state, and return to sender.
            env_id, var = self.pipe[0].get()
            self.pipe[1][env_id].put(self.model.getState(var, env_id=env_id))
def env_thread(args, thread_num, partition=True):
    """
    Run a session of an environment
    :param args: (ArgumentParser object)
    :param thread_num: (int) The thread ID of the environment session
    :param partition: (bool) If the output should be in multiple parts (default=True)
    """
    env_kwargs = {
        "max_distance": args.max_distance,
        "random_target": args.random_target,
        "force_down": True,
        "is_discrete": not args.continuous_actions,
        "renders": thread_num == 0 and args.display,
        "record_data": not args.no_record_data,
        "multi_view": args.multi_view,
        "save_path": args.save_path,
        "shape_reward": args.shape_reward,
        "simple_continual_target": args.simple_continual,
        "circular_continual_move": args.circular_continual,
        "square_continual_move": args.square_continual,
        "short_episodes": args.short_episodes
    }

    if partition:
        env_kwargs["name"] = args.name + "_part-" + str(thread_num)
    else:
        env_kwargs["name"] = args.name

    load_path, train_args, algo_name, algo_class = None, None, None, None
    model = None
    srl_model = None
    srl_state_dim = 0
    generated_obs = None
    env_norm = None

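    # When replaying a pretrained policy ("walker" or "custom"), load its training configuration
    # and reuse the environment settings it was trained with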
    if args.run_policy in ["walker", "custom"]:
        if args.latest:
            args.log_dir = latestPath(args.log_custom_policy)
        else:
            args.log_dir = args.log_custom_policy
        args.render = args.display
        args.plotting, args.action_proba = False, False

        train_args, load_path, algo_name, algo_class, _, env_kwargs_extra = loadConfigAndSetup(
            args)
        env_kwargs["srl_model"] = env_kwargs_extra["srl_model"]
        env_kwargs["random_target"] = env_kwargs_extra.get(
            "random_target", False)
        env_kwargs["use_srl"] = env_kwargs_extra.get("use_srl", False)

        # TODO REFACTOR
        env_kwargs["simple_continual_target"] = env_kwargs_extra.get(
            "simple_continual_target", False)
        env_kwargs["circular_continual_move"] = env_kwargs_extra.get(
            "circular_continual_move", False)
        env_kwargs["square_continual_move"] = env_kwargs_extra.get(
            "square_continual_move", False)
        env_kwargs["eight_continual_move"] = env_kwargs_extra.get(
            "eight_continual_move", False)

        eps = 0.2
        env_kwargs["state_init_override"] = np.array([MIN_X + eps, MAX_X - eps]) \
            if args.run_policy == 'walker' else None
        if env_kwargs["use_srl"]:
            env_kwargs["srl_model_path"] = env_kwargs_extra.get(
                "srl_model_path", None)
            env_kwargs["state_dim"] = getSRLDim(
                env_kwargs_extra.get("srl_model_path", None))
            srl_model = MultiprocessSRLModel(num_cpu=args.num_cpu,
                                             env_id=args.env,
                                             env_kwargs=env_kwargs)
            env_kwargs["srl_pipe"] = srl_model.pipe

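    # Instantiate the environment for this data-collection thread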
    env_class = registered_env[args.env][0]
    env = env_class(**env_kwargs)

    if env_kwargs.get('srl_model', None) not in ["raw_pixels", None]:
        # TODO: Remove env duplication
        # This is a dirty trick to normalize the observations:
        # since we override the SRL environment's step/reset for on-policy generation & generative replay,
        # stable-baselines' normalisation wrappers (which also wrap step & reset) break...
        env_norm = [
            makeEnv(args.env,
                    args.seed,
                    i,
                    args.log_dir,
                    allow_early_resets=False,
                    env_kwargs=env_kwargs) for i in range(args.num_cpu)
        ]
        env_norm = DummyVecEnv(env_norm)
        env_norm = VecNormalize(env_norm, norm_obs=True, norm_reward=False)
        env_norm = loadRunningAverage(
            env_norm, load_path_normalise=args.log_custom_policy)
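    # True only when collecting data from a real OmniRobot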
    using_real_omnibot = args.env == "OmnirobotEnv-v0" and USING_OMNIROBOT

    walker_path = None
    action_walker = None
    state_init_for_walker = None
    kwargs_reset, kwargs_step = {}, {}

    if args.run_policy in ['custom', 'ppo2', 'walker']:
        # Additional env when using a trained agent to generate data
        train_env = vecEnv(env_kwargs, env_class)

        if args.run_policy == 'ppo2':
            model = PPO2(CnnPolicy, train_env).learn(args.ppo2_timesteps)
        else:
            _, _, algo_args = createEnv(args, train_args, algo_name,
                                        algo_class, env_kwargs)
            tf.reset_default_graph()
            set_global_seeds(args.seed % 2**32)
            printYellow("Compiling Policy function....")
            model = algo_class.load(load_path, args=algo_args)
            if args.run_policy == 'walker':
                walker_path = walkerPath()

    if len(args.replay_generative_model) > 0:
        srl_model = loadSRLModel(args.log_generative_model,
                                 th.cuda.is_available())
        srl_state_dim = srl_model.state_dim
        srl_model = srl_model.model.model

    frames = 0
    start_time = time.time()

    # divide episodes evenly across threads, then give one extra episode to the first (num_episode % num_cpu) threads
    for i_episode in range(args.num_episode // args.num_cpu + 1 *
                           (args.num_episode % args.num_cpu > thread_num)):

        # seed = base seed + position within this thread's slice + slice offset (with remainder if partitions are uneven)
        seed = args.seed + i_episode + args.num_episode // args.num_cpu * thread_num + \
               (thread_num if thread_num <= args.num_episode % args.num_cpu else args.num_episode % args.num_cpu)
        seed = seed % 2**32
        if args.run_policy not in ['custom', 'walker']:
            env.seed(seed)
            # also seed the action space, for the sample() function from gym.spaces
            env.action_space.seed(seed)

        if len(args.replay_generative_model) > 0:

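            # Sample a random latent vector, decode it into a synthetic observation
            # and pass it to env.reset()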
            sample = Variable(th.randn(1, srl_state_dim))
            if th.cuda.is_available():
                sample = sample.cuda()

            generated_obs = srl_model.decode(sample)
            generated_obs = generated_obs[0].detach().cpu().numpy()
            generated_obs = deNormalize(generated_obs)

            kwargs_reset['generated_observation'] = generated_obs
        obs = env.reset(**kwargs_reset)
        done = False
        action_proba = None
        t = 0
        episode_toward_target_on = False

        while not done:

            env.render()

            # Policy to run on the fly - to be trained before generation
            if args.run_policy == 'ppo2':
                action, _ = model.predict([obs])

            # Custom pre-trained Policy (SRL or End-to-End)
            elif args.run_policy in ['custom', 'walker']:
                obs = env_norm._normalize_observation(obs)
                action = [model.getAction(obs, done)]
                action_proba = model.getActionProba(obs, done)
                if args.run_policy == 'walker':
                    action_walker = np.array(walker_path[t])
            # Random Policy
            else:
                # Using a target reaching policy (untrained, from camera) when collecting data from real OmniRobot
                if episode_toward_target_on and np.random.rand() < args.toward_target_timesteps_proportion and \
                        using_real_omnibot:
                    action = [env.actionPolicyTowardTarget()]
                else:
                    action = [env.action_space.sample()]

            # Generative replay: decode either the current on-policy latent state or a random latent sample
            if len(args.replay_generative_model) > 0:

                if args.run_policy == 'custom':
                    obs = obs.reshape(1, srl_state_dim)
                    obs = th.from_numpy(obs.astype(np.float32)).cuda()
                    z = obs
                    generated_obs = srl_model.decode(z)
                else:
                    sample = Variable(th.randn(1, srl_state_dim))

                    if th.cuda.is_available():
                        sample = sample.cuda()

                    generated_obs = srl_model.decode(sample)
                generated_obs = generated_obs[0].detach().cpu().numpy()
                generated_obs = deNormalize(generated_obs)

            action_to_step = action[0]
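            # Only forward the optional step arguments that are actually set
            # (generated observation, action probabilities, grid-walker action)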
            kwargs_step = {
                k: v
                for (k, v) in [("generated_observation",
                                generated_obs), ("action_proba", action_proba),
                               ("action_grid_walker", action_walker)]
                if v is not None
            }

            obs, _, done, _ = env.step(action_to_step, **kwargs_step)

            frames += 1
            t += 1
            if done:
                episode_toward_target_on = (np.random.rand() < args.toward_target_timesteps_proportion
                                            and using_real_omnibot)
                print("Episode finished after {} timesteps".format(t + 1))

        if thread_num == 0:
            print("{:.2f} FPS".format(frames * args.num_cpu /
                                      (time.time() - start_time)))
def main():
    parser = argparse.ArgumentParser(
        description='Deterministic dataset generator for SRL training ' +
        '(can be used for environment testing)')
    parser.add_argument(
        '--save-path',
        type=str,
        default='data/',
        help='Folder where the environments will save the output')
    parser.add_argument('--name',
                        type=str,
                        default='generated_reaching_on_policy',
                        help='Folder name for the output')
    parser.add_argument('--no-record-data', action='store_true', default=False)
    parser.add_argument(
        '--log-custom-policy',
        type=str,
        default='',
        help='Logs of the custom pretrained policy to run for data collection')
    parser.add_argument(
        '--log-generative-model',
        type=str,
        default='',
        help='Logs of the trained generative model used to generate observations')
    parser.add_argument(
        '--short-episodes',
        action='store_true',
        default=False,
        help='Generate short episodes (only 2 contacts with the target allowed).'
    )  # we can change the number of contact in omnirobot_env.py
    parser.add_argument('--display', action='store_true', default=False)
    parser.add_argument(
        '--ngsa',
        '--num-generating-samples-per-action',
        type=int,
        default=2000,
        help='The number of generated observations for each of the 4 action classes')
    parser.add_argument(
        '--shape-reward',
        action='store_true',
        default=False,
        help='Shape the reward (reward = - distance) instead of a sparse reward'
    )
    parser.add_argument('--seed', type=int, default=0, help='the seed')
    parser.add_argument('--task',
                        type=str,
                        default=None,
                        choices=['sc', 'cc'],
                        help='choose task for data set generation')
    parser.add_argument('--grid-walker',
                        action='store_true',
                        default=False,
                        help='Generate data by moving the robot on a grid (grid walker).')
    parser.add_argument('--gw-step',
                        type=float,
                        default=0.1,
                        help='step size of the grid walker')
    parser.add_argument('--gw-episode',
                        type=int,
                        default=100,
                        help='number of episodes for the grid walker')

    args = parser.parse_args()

    # sanity checks
    assert not (args.log_generative_model == '' and args.replay_generative_model == 'custom'), \
        "If replaying a generative model, please specify a valid log folder for loading it."
    assert args.task is not None, "You must choose a task ('sc' or 'cc') with --task"

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    # File exists, need to deal with it
    assert not os.path.exists(
        args.save_path +
        args.name), "Error: save directory '{}' already exists".format(
            args.save_path + args.name)
    os.mkdir(args.save_path + args.name)

    print("Using seed = ", args.seed)
    np.random.seed(args.seed)

    # load generative model
    generative_model, generative_model_losses, only_action = loadSRLModel(
        args.log_generative_model, th.cuda.is_available())
    generative_model_state_dim = generative_model.state_dim
    generative_model = generative_model.model.model

    # load the configuration of the RL model
    args.log_dir = args.log_custom_policy
    args.render = args.display
    args.simple_continual, args.circular_continual, args.square_continual = False, False, False
    args.num_cpu = 1

    train_args, load_path, algo_name, algo_class, srl_model_path, env_kwargs_extra = loadConfigAndSetup(
        args)

    # assert that the RL policy was not trained on ground_truth, since the generative model does not
    # produce ground_truth observations to feed to a policy that expects them
    assert train_args['srl_model'] != "ground_truth", \
        "cannot use an RL model trained with ground_truth"

    # load rl model
    model = algo_class.load(load_path)

    # check if the RL model was trained on top of an SRL model
    if srl_model_path is not None:
        srl_model, _, _ = loadSRLModel(srl_model_path, th.cuda.is_available())
        srl_model = srl_model.model.model
    # determine which kind of generative model was trained (conditional model vs. cvae_new)
    using_conditional_model = "cgan" in generative_model_losses or "cvae" in generative_model_losses \
        or "cgan_new" in generative_model_losses
    use_cvae_new = "cvae_new" in generative_model_losses

    # generate equal numbers of each action (discrete actions for the 4 movements)
    if not args.grid_walker and not use_cvae_new:
        actions_0 = 0 * np.ones(args.ngsa)
        actions_1 = 1 * np.ones(args.ngsa)
        actions_2 = 2 * np.ones(args.ngsa)
        actions_3 = 3 * np.ones(args.ngsa)
        actions = np.concatenate(
            (actions_0, actions_1, actions_2, actions_3)).astype(int)
        np.random.seed(args.seed)
        np.random.shuffle(actions)
    else:
        actions = th.zeros(1)

    # create the minibatch list, the grid list, and the target position list
    grid_list = []
    target_pos_list = []
    if not args.grid_walker:
        minibatchlist = DataLoaderConditional.createTestMinibatchList(
            args.ngsa * 4, MAX_BATCH_SIZE_GPU)
    else:
        # number of grid positions along each axis for the grid walker
        n_grid_x = int((MAX_X - MIN_X) / args.gw_step) - 1
        n_grid_y = int((MAX_Y - MIN_Y) / args.gw_step) - 1
        minibatchlist = DataLoaderConditional.createTestMinibatchList(
            n_grid_x * n_grid_y * args.gw_episode, MAX_BATCH_SIZE_GPU)
        # create the grid list for the grid walker
        for _ in range(args.gw_episode):
            for i in range(n_grid_x):
                for j in range(n_grid_y):
                    x = MAX_X - (i + 1) * args.gw_step
                    y = MAX_Y - (j + 1) * args.gw_step
                    grid_list.append([x, y])
        # create one target position per episode, repeated for every grid position
        for _ in range(args.gw_episode):
            random_init_x = (np.random.random_sample(1).item() * (TARGET_MAX_X - TARGET_MIN_X) +
                             TARGET_MIN_X) if args.task == 'sc' else 0
            random_init_y = (np.random.random_sample(1).item() * (TARGET_MAX_Y - TARGET_MIN_Y) +
                             TARGET_MIN_Y) if args.task == 'sc' else 0
            for _ in range(n_grid_x * n_grid_y):
                target_pos_list.append([random_init_x, random_init_y])
    grid_list = np.asarray(grid_list)
    target_pos_list = np.asarray(target_pos_list)

    # data loader yielding minibatches of (latent code, action class, target position, robot position)
    data_loader = DataLoaderConditional(minibatchlist,
                                        actions,
                                        args.task,
                                        generative_model_state_dim,
                                        TARGET_MAX_X,
                                        TARGET_MIN_X,
                                        TARGET_MAX_Y,
                                        TARGET_MIN_Y,
                                        MAX_X,
                                        MIN_X,
                                        MAX_Y,
                                        MIN_Y,
                                        args.grid_walker,
                                        grid_list,
                                        target_pos_list,
                                        seed=args.seed,
                                        max_queue_len=4)

    # some lists for saving at the end
    imgs_paths_array = []
    actions_array = []
    episode_starts = []

    # number of correct class predictions
    num_correct_class = np.zeros(4)
    pbar = tqdm(total=len(minibatchlist))
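    # For each minibatch: decode the latent codes into images, save them to disk,
    # then query the RL policy on the generated observations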
    for minibatch_num, (z, c, t, r) in enumerate(data_loader):
        if th.cuda.is_available():
            state = z.to('cuda')
            action = c.to('cuda')
            target = t.to('cuda')
            robot_pos = r.to('cuda')
        else:
            # keep the tensors on CPU so the variables are always defined
            state, action, target, robot_pos = z, c, t, r
        if using_conditional_model:
            generated_obs = generative_model.decode(state, action, target,
                                                    only_action)
        elif use_cvae_new:
            generated_obs = generative_model.decode(state, target, robot_pos)
        else:
            generated_obs = generative_model.decode(state)

        # save the generated observations
        # [TODO]: although the folders are named "record", images are not yet separated by episode;
        # we simply save one folder per minibatch_num
        folder_path = os.path.join(args.save_path + args.name,
                                   "record_{:03d}/".format(minibatch_num))
        os.mkdir(folder_path)

        # Append each image's path and its corresponding class (action)
        for i in range(generated_obs.size(0)):

            obs = deNormalize(generated_obs[i].to(
                th.device('cpu')).detach().numpy())
            # reverse the channel order (RGB -> BGR) and scale to [0, 255] for cv2.imwrite
            obs = 255 * obs[..., ::-1]
            if using_conditional_model:
                if only_action:
                    imgs_paths = folder_path + "frame_{:06d}_class_{}.jpg".format(
                        i, int(c[i]))
                else:
                    imgs_paths = folder_path + "frame_{:06d}_class_{}_tp_{:.2f}_{:.2f}.jpg".format(
                        i, int(c[i]), t[i][0], t[i][1])
            elif use_cvae_new:
                imgs_paths = folder_path + "frame_{:06d}_tp_{:.2f}_{:.2f}_rp_{:.2f}_{:.2f}.jpg".format(
                    i, t[i][0], t[i][1], r[i][0], r[i][1])
            else:
                imgs_paths = folder_path + "frame_{:06d}.jpg".format(i)

            cv2.imwrite(imgs_paths, obs.astype(np.uint8))
            imgs_paths_array.append(imgs_paths)
            if i == 0 and minibatch_num == 0:
                episode_starts.append(True)
            else:
                episode_starts.append(False)

        if srl_model_path is not None:
            generated_obs_state = srl_model.forward(generated_obs.cuda()).to(
                th.device('cpu')).detach().numpy()
            on_policy_actions = model.getAction(generated_obs_state)
            actions_proba = model.getActionProba(generated_obs_state)
        else:
            on_policy_actions = model.getAction(
                generated_obs.to(th.device('cpu')).detach().numpy())
            actions_proba = model.getActionProba(
                generated_obs.to(th.device('cpu')).detach().numpy())

        if minibatch_num == 0:
            actions_proba_array = actions_proba
            on_policy_actions_array = on_policy_actions
            z_array = z.detach().numpy()

        else:
            actions_proba_array = np.append(actions_proba_array,
                                            actions_proba,
                                            axis=0)
            on_policy_actions_array = np.append(on_policy_actions_array,
                                                on_policy_actions,
                                                axis=0)
            z_array = np.append(z_array, z.detach().numpy(), axis=0)

        # count the correct predictions (used to evaluate the accuracy of the generative model)
        if using_conditional_model:
            for i in np.arange(on_policy_actions.shape[0]):
                if c.numpy()[i] == on_policy_actions[i]:
                    num_correct_class[int(c[i])] += 1
        pbar.update(1)
    pbar.close()

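    # Report how often the policy's predicted action matches the action class
    # each observation was generated for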
    if using_conditional_model:
        correct_observations = (100 / args.ngsa) * num_correct_class
        print("The generative model is {}% accurate for {} testing samples.".
              format(
                  np.sum(num_correct_class) * 100 / (args.ngsa * 4),
                  args.ngsa * 4))
        print("Correct observations of action class '0' : {}%".format(
            correct_observations[0]))
        print("Correct observations of action class '1' : {}%".format(
            correct_observations[1]))
        print("Correct observations of action class '2' : {}%".format(
            correct_observations[2]))
        print("Correct observations of action class '3' : {}%".format(
            correct_observations[3]))

    # save some data
    # We do not have any information about ground_truth_states and target_positions.
    # They are saved only to avoid errors during data merging, since the merging step also expects these two arrays.

    np.savez(args.save_path + args.name + "/preprocessed_data.npz",
             actions=on_policy_actions_array.tolist(),
             actions_proba=actions_proba_array.tolist(),
             episode_starts=episode_starts,
             rewards=[],
             z_array=z_array.tolist())
    np.savez(args.save_path + args.name + "/ground_truth.npz",
             images_path=imgs_paths_array,
             ground_truth_states=[[]],
             target_positions=[[]])

    # save config files
    copyfile(args.log_dir + "/env_globals.json",
             args.save_path + args.name + '/env_globals.json')
    with open(args.save_path + args.name + '/dataset_config.json', 'w') as f:
        json.dump({"img_shape": train_args.get("img_shape", None)}, f)
    if using_conditional_model:
        with open(args.save_path + args.name + '/class_eval.json', 'w') as f:
            json.dump(
                {
                    "num_correct_class": correct_observations.tolist(),
                    "ngsa_per_class": args.ngsa,
                    "random_seed": args.seed
                }, f)
    else:
        with open(args.save_path + args.name + '/class_eval.json', 'w') as f:
            json.dump({"random_seed": args.seed}, f)
Example #4
    def train(self, args, callback, env_kwargs=None, train_kwargs=None):

        N_EPOCHS = args.epochs_distillation
        self.seed = args.seed
        self.batch_size = BATCH_SIZE
        print("We assumed SRL training already done")

        print('Loading data for distillation ')
        # training_data, ground_truth, true_states, _ = loadData(args.teacher_data_folder)
        training_data, ground_truth, _, _ = loadData(args.teacher_data_folder, with_env=False)

        images_path = ground_truth['images_path']
        episode_starts = training_data['episode_starts']
        actions = training_data['actions']
        actions_proba = training_data['actions_proba']

        if USE_ADAPTIVE_TEMPERATURE:
            cl_labels = training_data[CL_LABEL_KEY]
        else:
            # cl_labels_st stays None when the adaptive temperature is disabled
            cl_labels_st = None

        if args.distillation_training_set_size > 0:
            limit = args.distillation_training_set_size
            actions = actions[:limit]
            images_path = images_path[:limit]
            episode_starts = episode_starts[:limit]

        images_path_copy = [images_path[k] for k in range(images_path.shape[0])]
        images_path = np.array(images_path_copy)

        num_samples = images_path.shape[0] - 1  # number of samples
        if args.img_shape is None:
            self.img_shape = None  # e.g. (3, 224, 224)
        else:
            self.img_shape = tuple(map(int, args.img_shape[1:-1].split(",")))

        # indices for all time steps where the episode continues
        indices = np.array([i for i in range(num_samples) if not episode_starts[i + 1]], dtype='int64')
        np.random.shuffle(indices)

        # split indices into minibatches. minibatchlist is a list of arrays; each
        # array holds the ids of the observations used in that minibatch
        minibatchlist = [np.array(sorted(indices[start_idx:start_idx + self.batch_size]))
                         for start_idx in range(0, len(indices) - self.batch_size + 1, self.batch_size)]
        data_loader = DataLoader(minibatchlist, images_path, self.img_shape, n_workers=N_WORKERS, multi_view=False,
                                 use_triplets=False, is_training=True, absolute_path=False)

        test_minibatchlist = DataLoader.createTestMinibatchList(len(images_path), MAX_BATCH_SIZE_GPU)

        test_data_loader = DataLoader(test_minibatchlist, images_path, self.img_shape, n_workers=N_WORKERS, multi_view=False,
                                      use_triplets=False, max_queue_len=1, is_training=False, absolute_path=False)

        # Number of minibatches used for validation:
        n_val_batches = np.round(VALIDATION_SIZE * len(minibatchlist)).astype(np.int64)
        val_indices = np.random.permutation(len(minibatchlist))[:n_val_batches]
        # Print some info
        print("{} minibatches for training, {} samples".format(len(minibatchlist) - n_val_batches,
                                                               (len(minibatchlist) - n_val_batches) * BATCH_SIZE))
        print("{} minibatches for validation, {} samples".format(n_val_batches, n_val_batches * BATCH_SIZE))
        assert n_val_batches > 0, "Not enough samples to create a validation set"

        # Stats about actions
        if not args.continuous_actions:
            print('Discrete action space:')
            action_set = set(actions)
            n_actions = int(np.max(actions) + 1)
            print("{} unique actions / {} actions".format(len(action_set), n_actions))
            n_obs_per_action = np.zeros(n_actions, dtype=np.int64)
            for i in range(n_actions):
                n_obs_per_action[i] = np.sum(actions == i)

            print("Number of observations per action")
            print(n_obs_per_action)

        else:
            print('Continuous action space:')
            print('Action dimension: {}'.format(self.dim_action))

        # Here the default SRL model is assumed to be raw_pixels
        self.state_dim = self.img_shape[0] * self.img_shape[1] * self.img_shape[2]
        self.srl_model = None
        print("env_kwargs[srl_model] ",env_kwargs["srl_model"])
        # TODO: add sanity checks & test for all possible SRL for distillation
        if env_kwargs["srl_model"] == "raw_pixels":                                                       
            # if the pilicy distillation is used with raw pixel
            self.model = CNNPolicy(n_actions, self.img_shape)
            learnable_params = self.model.parameters()
            learning_rate = 1e-3

        else:
            self.state_dim = getSRLDim(env_kwargs.get("srl_model_path", None))
            self.srl_model = loadSRLModel(env_kwargs.get("srl_model_path", None),
                                          th.cuda.is_available(), self.state_dim, env_object=None)

            self.model = MLPPolicy(output_size=n_actions, input_size=self.state_dim)
            for param in self.model.parameters():
                param.requires_grad = True
            learnable_params = [param for param in self.model.parameters()]

            if FINE_TUNING and self.srl_model is not None:
                for param in self.srl_model.model.parameters():
                    param.requires_grad = True
                learnable_params += [param for param in self.srl_model.model.parameters()]

            learning_rate = 1e-3
        self.device = th.device("cuda" if th.cuda.is_available() else "cpu")

        if th.cuda.is_available():
            self.model.cuda()

        self.optimizer = th.optim.Adam(learnable_params, lr=learning_rate)
        best_error = np.inf
        best_model_path = "{}/{}_model.pkl".format(args.log_dir, args.algo)

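        # Distillation loop: minimise the distillation loss between the student's
        # predictions and the teacher's action probabilities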
        for epoch in range(N_EPOCHS):
            # In each epoch, we do a full pass over the training data:
            epoch_loss, epoch_batches = 0, 0
            val_loss = 0
            pbar = tqdm(total=len(minibatchlist))

            for minibatch_num, (minibatch_idx, obs, _, _, _) in enumerate(data_loader):
                self.optimizer.zero_grad()

                obs = obs.to(self.device)
                validation_mode = minibatch_idx in val_indices
                if validation_mode:
                    self.model.eval()
                    if FINE_TUNING and self.srl_model is not None:
                        self.srl_model.model.eval()
                else:
                    self.model.train()
                    if FINE_TUNING and self.srl_model is not None:
                        self.srl_model.model.train()

                # Actions associated to the observations of the current minibatch
                actions_st = actions[minibatchlist[minibatch_idx]]
                actions_proba_st = actions_proba[minibatchlist[minibatch_idx]]

                if USE_ADAPTIVE_TEMPERATURE:
                    cl_labels_st = cl_labels[minibatchlist[minibatch_idx]]

                if not args.continuous_actions:
                    # Discrete actions: one-hot encode them into (n_minibatch, n_actions) rows,
                    # and keep the teacher's action probabilities as float tensors
                    actions_st = one_hot(th.from_numpy(actions_st)).requires_grad_(False).to(self.device)
                    actions_proba_st = th.from_numpy(actions_proba_st).requires_grad_(False).to(self.device)
                else:
                    # Continuous actions: rearrange to (n_minibatch, dim_action)
                    actions_st = th.from_numpy(actions_st).view(-1, self.dim_action).requires_grad_(False).to(
                        self.device)

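                # Encode the observation with the SRL model when one is used,
                # otherwise feed raw pixels to the student policy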
                if self.srl_model is not None:
                    state = self.srl_model.model.getStates(obs).to(self.device).detach()
                    if "autoencoder" in self.srl_model.model.losses:
                        use_ae = True
                        decoded_obs = self.srl_model.model.model.decode(state).to(self.device).detach()
                else:
                    state = obs.detach()
                pred_action = self.model.forward(state)

                loss = self.loss_fn_kd(pred_action,
                                        actions_proba_st.float(),
                                        labels=cl_labels_st, adaptive_temperature=USE_ADAPTIVE_TEMPERATURE)

                #loss = self.loss_mse(pred_action,  actions_proba_st.float())

                
                if validation_mode:
                    val_loss += loss.item()
                    # We do not optimize on validation data
                    # so optimizer.step() is not called
                else:
                    loss.backward()
                    self.optimizer.step()
                    epoch_loss += loss.item()
                    epoch_batches += 1
                pbar.update(1)

            train_loss = epoch_loss / float(epoch_batches)
            val_loss /= float(n_val_batches)
            pbar.close()
            print("Epoch {:3}/{}, train_loss:{:.6f} val_loss:{:.6f}".format(epoch + 1, N_EPOCHS, train_loss, val_loss))

            # Save best model
            if val_loss < best_error:
                best_error = val_loss
                self.save(best_model_path)