Example #1
class ImitationLearning(object):
    def __init__(
        self,
        args,
    ):
        self.args = args

        utils.seed(self.args.seed)

        # args.multi_env is a list of environments when training on multiple environments
        if getattr(args, 'multi_env', None):
            self.env = [gym.make(item) for item in args.multi_env]

            self.train_demos = []
            for demos, episodes in zip(args.multi_demos, args.multi_episodes):
                demos_path = utils.get_demos_path(demos,
                                                  None,
                                                  None,
                                                  valid=False)
                logger.info('loading {} of {} demos'.format(episodes, demos))
                train_demos = utils.load_demos(demos_path)
                logger.info('loaded demos')
                if episodes > len(train_demos):
                    raise ValueError(
                        "there are only {} train demos in {}".format(
                            len(train_demos), demos))
                self.train_demos.extend(train_demos[:episodes])
                logger.info('So far, {} demos loaded'.format(
                    len(self.train_demos)))

            self.val_demos = []
            for demos, episodes in zip(args.multi_demos, [args.val_episodes] *
                                       len(args.multi_demos)):
                demos_path_valid = utils.get_demos_path(demos,
                                                        None,
                                                        None,
                                                        valid=True)
                logger.info('loading {} of {} valid demos'.format(
                    episodes, demos))
                valid_demos = utils.load_demos(demos_path_valid)
                logger.info('loaded demos')
                if episodes > len(valid_demos):
                    logger.info(
                        'Using all the available {} demos to evaluate valid. accuracy'
                        .format(len(valid_demos)))
                self.val_demos.extend(valid_demos[:episodes])
                logger.info('So far, {} valid demos loaded'.format(
                    len(self.val_demos)))

            logger.info('Loaded all demos')

            observation_space = self.env[0].observation_space
            action_space = self.env[0].action_space

        else:
            self.env = gym.make(self.args.env)

            demos_path = utils.get_demos_path(args.demos,
                                              args.env,
                                              args.demos_origin,
                                              valid=False)
            demos_path_valid = utils.get_demos_path(args.demos,
                                                    args.env,
                                                    args.demos_origin,
                                                    valid=True)
            print("else")
            logger.info('loading demos')
            self.train_demos = utils.load_demos(demos_path)
            print(len(self.train_demos))
            print(self.train_demos[0])
            logger.info('loaded demos')
            if args.episodes:
                if args.episodes > len(self.train_demos):
                    raise ValueError("there are only {} train demos".format(
                        len(self.train_demos)))
                self.train_demos = self.train_demos[:args.episodes]

            self.val_demos = utils.load_demos(demos_path_valid)
            if args.val_episodes > len(self.val_demos):
                logger.info(
                    'Using all the available {} demos to evaluate valid. accuracy'
                    .format(len(self.val_demos)))
            self.val_demos = self.val_demos[:self.args.val_episodes]

            observation_space = self.env.observation_space
            action_space = self.env.action_space

            print("else")
        print(args.model)
        self.obss_preprocessor = utils.ObssPreprocessor(
            args.model, observation_space,
            getattr(self.args, 'pretrained_model', None))

        # Define actor-critic model
        self.acmodel = utils.load_model(args.model, raise_not_found=False)
        if self.acmodel is None:
            if getattr(self.args, 'pretrained_model', None):
                self.acmodel = utils.load_model(args.pretrained_model,
                                                raise_not_found=True)
            else:
                self.acmodel = ACModel(self.obss_preprocessor.obs_space,
                                       action_space, args.image_dim,
                                       args.memory_dim, args.instr_dim,
                                       not self.args.no_instr,
                                       self.args.instr_arch,
                                       not self.args.no_mem, self.args.arch)
        self.obss_preprocessor.vocab.save()
        utils.save_model(self.acmodel, args.model)

        self.acmodel.train()
        if torch.cuda.is_available():
            self.acmodel.cuda()

        self.optimizer = torch.optim.Adam(self.acmodel.parameters(),
                                          self.args.lr,
                                          eps=self.args.optim_eps)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                         step_size=100,
                                                         gamma=0.9)

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
Example #2
logger = logging.getLogger(__name__)

# Define obss preprocessor
if 'emb' in args.arch:
    obss_preprocessor = utils.IntObssPreprocessor(args.model, envs[0].observation_space, args.pretrained_model)
else:
    obss_preprocessor = utils.ObssPreprocessor(args.model, envs[0].observation_space, args.pretrained_model)

# Define actor-critic model
acmodel = utils.load_model(args.model, raise_not_found=False)
if acmodel is None:
    if args.pretrained_model:
        acmodel = utils.load_model(args.pretrained_model, raise_not_found=True)
    else:
        acmodel = ACModel(obss_preprocessor.obs_space, envs[0].action_space,
                          args.image_dim, args.memory_dim, args.instr_dim,
                          not args.no_instr, args.instr_arch, not args.no_mem, args.arch)

obss_preprocessor.vocab.save()
utils.save_model(acmodel, args.model)

if torch.cuda.is_available():
    acmodel.cuda()

# Define actor-critic algo

reshape_reward = lambda _0, _1, reward, _2: args.reward_scale * reward
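# The four positional slots are presumably (obs, action, reward, done); only the reward
# is used, e.g. reshape_reward(None, None, 0.5, None) == 0.5 * args.reward_scale.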
if args.algo == "ppo":
    algo = babyai.rl.PPOAlgo(envs, acmodel, args.frames_per_proc, args.discount, args.lr, args.beta1, args.beta2,
                             args.gae_lambda,
                             args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
                             args.optim_eps, args.clip_eps, args.ppo_epochs, args.batch_size,
                             obss_preprocessor, reshape_reward)

Example #3
def run_experiment(**config):
    set_seed(config['seed'])
    original_saved_path = config['saved_path']
    if original_saved_path is not None:
        saved_model = joblib.load(config['saved_path'])
        if 'config' in saved_model:
            if not config['override_old_config']:
                config = saved_model['config']
    arguments = {
        "start_loc": 'all',
        "include_holdout_obj": False,
        "persist_goal": config['persist_goal'],
        "persist_objs": config['persist_objs'],
        "persist_agent": config['persist_agent'],
        "feedback_type": config["feedback_type"],
        "feedback_always": config["feedback_always"],
        "feedback_freq": config["feedback_freq"],
        "cartesian_steps": config["cartesian_steps"],
        "num_meta_tasks": config["rollouts_per_meta_task"],
        "intermediate_reward": config["intermediate_reward"],
    }
    advice_start_index = 160
    if original_saved_path is not None:
        set_seed(config['seed'])
        policy = saved_model['policy']
        optimizer = saved_model['optimizer']
        policy.device = torch.device("cuda" if torch.cuda.is_available() else
                                     "cpu")  # TODO: is this necessary?
        policy.hidden_state = None
        baseline = saved_model['baseline']
        curriculum_step = saved_model['curriculum_step']
        env = rl2env(normalize(
            Curriculum(config['advance_curriculum_func'],
                       start_index=curriculum_step,
                       **arguments)),
                     ceil_reward=config['ceil_reward'])
        start_itr = saved_model['itr']
        reward_predictor = saved_model['reward_predictor']
        reward_predictor.hidden_state = None
        if 'supervised_model' in saved_model:
            supervised_model = saved_model['supervised_model']
        else:
            supervised_model = None

        teacher_train_dict = {}
        for teacher_name in config['feedback_type']:
            teacher_train_dict[teacher_name] = True

    else:

        teacher_train_dict = {}
        for teacher_name in config['feedback_type']:
            teacher_train_dict[teacher_name] = True

        optimizer = None
        baseline = None
        env = rl2env(normalize(
            Curriculum(config['advance_curriculum_func'],
                       start_index=config['level'],
                       **arguments)),
                     ceil_reward=config['ceil_reward'])
        obs = env.reset()
        obs_dim = 100  # TODO: consider changing this with 'additional' and adding it!
        advice_size = sum(
            [np.prod(obs[adv_k].shape) for adv_k in teacher_train_dict.keys()])

        image_dim = 128
        memory_dim = config['memory_dim']
        instr_dim = config['instr_dim']
        use_instr = True
        instr_arch = 'bigru'
        use_mem = True
        arch = 'bow_endpool_res'
        advice_dim = 128  # TODO: move this to the config
        policy = ACModel(obs_space=obs_dim,
                         action_space=env.action_space,
                         env=env,
                         image_dim=image_dim,
                         memory_dim=memory_dim,
                         instr_dim=instr_dim,
                         lang_model=instr_arch,
                         use_instr=use_instr,
                         use_memory=use_mem,
                         arch=arch,
                         advice_dim=advice_dim,
                         advice_size=advice_size,
                         num_modules=config['num_modules'])

        reward_predictor = ACModel(
            obs_space=obs_dim - 1,  # TODO: change into Discrete(3) and do 3-way classification
            action_space=spaces.Discrete(2),
            env=env,
            image_dim=image_dim,
            memory_dim=memory_dim,
            instr_dim=instr_dim,
            lang_model=instr_arch,
            use_instr=use_instr,
            use_memory=use_mem,
            arch=arch,
            advice_dim=advice_dim,
            advice_size=advice_size,
            num_modules=config['num_modules'])
        if config['self_distill'] and not config['distill_same_model']:
            obs_dim = env.reset()['obs'].shape[0]
            image_dim = 128
            memory_dim = config['memory_dim']
            instr_dim = config['instr_dim']
            use_instr = True
            instr_arch = 'bigru'
            use_mem = True
            arch = 'bow_endpool_res'
            supervised_model = ACModel(obs_space=obs_dim - 1,
                                       action_space=env.action_space,
                                       env=env,
                                       image_dim=image_dim,
                                       memory_dim=memory_dim,
                                       instr_dim=instr_dim,
                                       lang_model=instr_arch,
                                       use_instr=use_instr,
                                       use_memory=use_mem,
                                       arch=arch,
                                       advice_dim=advice_dim,
                                       advice_size=advice_size,
                                       num_modules=config['num_modules'])
        elif config['self_distill']:
            supervised_model = policy
        else:
            supervised_model = None
        start_itr = 0
        curriculum_step = env.index
    parser = ArgumentParser()
    args = parser.parse_args([])
    args.entropy_coef = config['entropy_bonus']
    args.model = 'default_il'
    args.lr = config['learning_rate']
    args.recurrence = config['backprop_steps']
    args.clip_eps = config['clip_eps']
    if supervised_model is not None:
        il_trainer = ImitationLearning(
            supervised_model,
            env,
            args,
            distill_with_teacher=config['distill_with_teacher'])
    else:
        il_trainer = None
    rp_trainer = ImitationLearning(reward_predictor,
                                   env,
                                   args,
                                   distill_with_teacher=True,
                                   reward_predictor=True)

    teacher_null_dict = env.teacher.null_feedback()
    obs_preprocessor = make_obs_preprocessor(teacher_null_dict)

    sampler = MetaSampler(
        env=env,
        policy=policy,
        rollouts_per_meta_task=config['rollouts_per_meta_task'],
        meta_batch_size=config['meta_batch_size'],
        max_path_length=config['max_path_length'],
        parallel=config['parallel'],
        envs_per_task=1,
        reward_predictor=reward_predictor,
        supervised_model=supervised_model,
        obs_preprocessor=obs_preprocessor,
    )

    sample_processor = RL2SampleProcessor(
        baseline=baseline,
        discount=config['discount'],
        gae_lambda=config['gae_lambda'],
        normalize_adv=config['normalize_adv'],
        positive_adv=config['positive_adv'],
    )

    envs = [copy.deepcopy(env) for _ in range(20)]  # twenty independent copies of the environment
    algo = PPOAlgo(policy,
                   envs,
                   config['frames_per_proc'],
                   config['discount'],
                   args.lr,
                   args.beta1,
                   args.beta2,
                   config['gae_lambda'],
                   args.entropy_coef,
                   config['value_loss_coef'],
                   config['max_grad_norm'],
                   args.recurrence,
                   args.optim_eps,
                   config['clip_eps'],
                   config['epochs'],
                   config['meta_batch_size'],
                   parallel=config['parallel'],
                   rollouts_per_meta_task=config['rollouts_per_meta_task'],
                   obs_preprocessor=obs_preprocessor)

    if optimizer is not None:
        algo.optimizer.load_state_dict(optimizer)

    EXP_NAME = get_exp_name(config)
    exp_dir = os.getcwd() + '/data/' + EXP_NAME + "_" + str(config['seed'])
    if original_saved_path is None:
        if os.path.isdir(exp_dir):
            shutil.rmtree(exp_dir)
    log_formats = ['stdout', 'log', 'csv']
    is_debug = config['prefix'] == 'DEBUG'

    if not is_debug:
        log_formats.append('tensorboard')
        log_formats.append('wandb')
    logger.configure(dir=exp_dir,
                     format_strs=log_formats,
                     snapshot_mode=config['save_option'],
                     snapshot_gap=50,
                     step=start_itr,
                     name=config['prefix'] + str(config['seed']),
                     config=config)
    json.dump(config,
              open(exp_dir + '/params.json', 'w'),
              indent=2,
              sort_keys=True,
              cls=ClassEncoder)

    advice_end_index, advice_dim = 161, 1
    if config['distill_with_teacher']:  # TODO: generalize this for multiple feedback types at once!
        teacher_info = []
    else:
        null_val = np.zeros(advice_end_index - advice_start_index)
        if len(null_val) > 0:
            null_val[-1] = 1
        teacher_info = [{
            "indices": np.arange(advice_start_index, advice_end_index),
            "null": null_val,
        }]
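        # Worked example with the constants above: advice_start_index=160 and
        # advice_end_index=161 give indices == np.arange(160, 161) and
        # null_val == np.array([1.]) (a single slot whose last entry is set to 1).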

    trainer = Trainer(
        algo=algo,
        policy=policy,
        env=deepcopy(env),
        sampler=sampler,
        sample_processor=sample_processor,
        n_itr=config['n_itr'],
        start_itr=start_itr,
        success_threshold=config['success_threshold'],
        accuracy_threshold=config['accuracy_threshold'],
        exp_name=exp_dir,
        curriculum_step=curriculum_step,
        config=config,
        advance_without_teacher=True,
        teacher_info=teacher_info,
        sparse_rewards=not config['intermediate_reward'],
        distill_only=config['distill_only'],
        il_trainer=il_trainer,
        source=config['source'],
        batch_size=config['meta_batch_size'],
        train_with_teacher=config['feedback_type'] is not None,
        distill_with_teacher=config['distill_with_teacher'],
        supervised_model=supervised_model,
        reward_predictor=reward_predictor,
        rp_trainer=rp_trainer,
        advance_levels=config['advance_levels'],
        is_debug=is_debug,
        teacher_train_dict=teacher_train_dict,
        obs_preprocessor=obs_preprocessor,
    )
    trainer.train()
Example #4
class ImitationLearning(object):
    def __init__(
        self,
        args,
    ):
        self.args = args

        utils.seed(self.args.seed)

        # args.multi_env is a list of environments when training on multiple environments
        if getattr(args, 'multi_env', None):
            self.env = [gym.make(item) for item in args.multi_env]

            self.train_demos = []
            for demos, episodes in zip(args.multi_demos, args.multi_episodes):
                demos_path = utils.get_demos_path(demos,
                                                  None,
                                                  None,
                                                  valid=False)
                logger.info('loading {} of {} demos'.format(episodes, demos))
                train_demos = utils.load_demos(demos_path)
                logger.info('loaded demos')
                if episodes > len(train_demos):
                    raise ValueError(
                        "there are only {} train demos in {}".format(
                            len(train_demos), demos))
                self.train_demos.extend(train_demos[:episodes])
                logger.info('So far, {} demos loaded'.format(
                    len(self.train_demos)))

            self.val_demos = []
            for demos, episodes in zip(args.multi_demos, [args.val_episodes] *
                                       len(args.multi_demos)):
                demos_path_valid = utils.get_demos_path(demos,
                                                        None,
                                                        None,
                                                        valid=True)
                logger.info('loading {} of {} valid demos'.format(
                    episodes, demos))
                valid_demos = utils.load_demos(demos_path_valid)
                logger.info('loaded demos')
                if episodes > len(valid_demos):
                    logger.info(
                        'Using all the available {} demos to evaluate valid. accuracy'
                        .format(len(valid_demos)))
                self.val_demos.extend(valid_demos[:episodes])
                logger.info('So far, {} valid demos loaded'.format(
                    len(self.val_demos)))

            logger.info('Loaded all demos')

            observation_space = self.env[0].observation_space
            action_space = self.env[0].action_space

        else:
            self.env = gym.make(self.args.env)

            demos_path = utils.get_demos_path(args.demos,
                                              args.env,
                                              args.demos_origin,
                                              valid=False)
            demos_path_valid = utils.get_demos_path(args.demos,
                                                    args.env,
                                                    args.demos_origin,
                                                    valid=True)
            print("else")
            logger.info('loading demos')
            self.train_demos = utils.load_demos(demos_path)
            print(len(self.train_demos))
            print(self.train_demos[0])
            logger.info('loaded demos')
            if args.episodes:
                if args.episodes > len(self.train_demos):
                    raise ValueError("there are only {} train demos".format(
                        len(self.train_demos)))
                self.train_demos = self.train_demos[:args.episodes]

            self.val_demos = utils.load_demos(demos_path_valid)
            if args.val_episodes > len(self.val_demos):
                logger.info(
                    'Using all the available {} demos to evaluate valid. accuracy'
                    .format(len(self.val_demos)))
            self.val_demos = self.val_demos[:self.args.val_episodes]

            observation_space = self.env.observation_space
            action_space = self.env.action_space

            print("else")
        print(args.model)
        self.obss_preprocessor = utils.ObssPreprocessor(
            args.model, observation_space,
            getattr(self.args, 'pretrained_model', None))

        # Define actor-critic model
        self.acmodel = utils.load_model(args.model, raise_not_found=False)
        if self.acmodel is None:
            if getattr(self.args, 'pretrained_model', None):
                self.acmodel = utils.load_model(args.pretrained_model,
                                                raise_not_found=True)
            else:
                self.acmodel = ACModel(self.obss_preprocessor.obs_space,
                                       action_space, args.image_dim,
                                       args.memory_dim, args.instr_dim,
                                       not self.args.no_instr,
                                       self.args.instr_arch,
                                       not self.args.no_mem, self.args.arch)
        self.obss_preprocessor.vocab.save()
        utils.save_model(self.acmodel, args.model)

        self.acmodel.train()
        if torch.cuda.is_available():
            self.acmodel.cuda()

        self.optimizer = torch.optim.Adam(self.acmodel.parameters(),
                                          self.args.lr,
                                          eps=self.args.optim_eps)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                         step_size=100,
                                                         gamma=0.9)

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def default_model_name(args):
        if getattr(args, 'multi_env', None):
            # It's better to specify one's own model name for this scenario
            named_envs = '-'.join(args.multi_env)
        else:
            named_envs = args.env

        # Define model name
        suffix = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
        instr = args.instr_arch if args.instr_arch else "noinstr"
        model_name_parts = {
            'envs': named_envs,
            'arch': args.arch,
            'instr': instr,
            'seed': args.seed,
            'suffix': suffix
        }
        default_model_name = "{envs}_IL_{arch}_{instr}_seed{seed}_{suffix}".format(
            **model_name_parts)
        if getattr(args, 'pretrained_model', None):
            default_model_name = args.pretrained_model + '_pretrained_' + default_model_name
        return default_model_name

    def starting_indexes(self, num_frames):
        if num_frames % self.args.recurrence == 0:
            return np.arange(0, num_frames, self.args.recurrence)
        else:
            return np.arange(0, num_frames, self.args.recurrence)[:-1]
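    # Illustrative example: with num_frames=10 and recurrence=4, np.arange gives
    # [0, 4, 8]; since 10 % 4 != 0 the last start is dropped and [0, 4] is returned,
    # so only frames that fit in a complete recurrence window are used.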

    def run_epoch_recurrence(self, demos, is_training=False):
        indices = list(range(len(demos)))
        if is_training:
            np.random.shuffle(indices)
        batch_size = min(self.args.batch_size, len(demos))
        offset = 0

        if not is_training:
            self.acmodel.eval()

        # Log dictionary
        log = {"entropy": [], "policy_loss": [], "accuracy": []}

        start_time = time.time()
        frames = 0
        for batch_index in range(len(indices) // batch_size):
            logger.info("batch {}, FPS so far {}".format(
                batch_index,
                frames / (time.time() - start_time) if frames else 0))
            batch = [demos[i] for i in indices[offset:offset + batch_size]]
            frames += sum([len(demo[3]) for demo in batch])

            _log = self.run_epoch_recurrence_one_batch(batch,
                                                       is_training=is_training)

            log["entropy"].append(_log["entropy"])
            log["policy_loss"].append(_log["policy_loss"])
            log["accuracy"].append(_log["accuracy"])

            offset += batch_size

        if not is_training:
            self.acmodel.train()

        return log

    def run_epoch_recurrence_one_batch(self, batch, is_training=False):
        batch = utils.demos.transform_demos(batch)
        batch.sort(key=len, reverse=True)
        # Constructing flat batch and indices pointing to start of each demonstration
        flat_batch = []
        inds = [0]

        for demo in batch:
            flat_batch += demo
            inds.append(inds[-1] + len(demo))

        flat_batch = np.array(flat_batch)
        inds = inds[:-1]
        num_frames = len(flat_batch)

        mask = np.ones([len(flat_batch)], dtype=np.float64)
        mask[inds] = 0
        mask = torch.tensor(mask, device=self.device,
                            dtype=torch.float).unsqueeze(1)
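        # mask is zero at the first frame of every demonstration; multiplying the
        # recurrent memory by it further below resets the memory at episode boundaries.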

        # Observations, true actions and done flags for each stored demonstration
        obss, action_true, done = flat_batch[:, 0], flat_batch[:, 1], flat_batch[:, 2]
        action_true = torch.tensor([action for action in action_true],
                                   device=self.device,
                                   dtype=torch.long)

        # Memory to be stored
        memories = torch.zeros([len(flat_batch), self.acmodel.memory_size],
                               device=self.device)
        episode_ids = np.zeros(len(flat_batch))
        memory = torch.zeros([len(batch), self.acmodel.memory_size],
                             device=self.device)

        preprocessed_first_obs = self.obss_preprocessor(obss[inds],
                                                        device=self.device)
        instr_embedding = self.acmodel._get_instr_embedding(
            preprocessed_first_obs.instr)

        # Loop terminates when every observation in the flat_batch has been handled
        while True:
            # taking observations and done located at inds
            obs = obss[inds]
            done_step = done[inds]
            preprocessed_obs = self.obss_preprocessor(obs, device=self.device)
            with torch.no_grad():
                # taking the memory till len(inds), as demos beyond that have already finished
                new_memory = self.acmodel(
                    preprocessed_obs, memory[:len(inds), :],
                    instr_embedding[:len(inds)])['memory']

            memories[inds, :] = memory[:len(inds), :]
            memory[:len(inds), :] = new_memory
            episode_ids[inds] = range(len(inds))

            # Update inds by dropping the indices whose demonstrations have finished
            inds = inds[:len(inds) - sum(done_step)]
            if len(inds) == 0:
                break

            # Incrementing the remaining indices
            inds = [index + 1 for index in inds]

        # Here, the actual backprop over args.recurrence steps happens
        final_loss = 0
        final_entropy, final_policy_loss, final_value_loss = 0, 0, 0

        indexes = self.starting_indexes(num_frames)
        memory = memories[indexes]
        accuracy = 0
        total_frames = len(indexes) * self.args.recurrence
        for _ in range(self.args.recurrence):
            obs = obss[indexes]
            preprocessed_obs = self.obss_preprocessor(obs, device=self.device)
            action_step = action_true[indexes]
            mask_step = mask[indexes]
            model_results = self.acmodel(preprocessed_obs, memory * mask_step,
                                         instr_embedding[episode_ids[indexes]])
            dist = model_results['dist']
            memory = model_results['memory']

            entropy = dist.entropy().mean()
            policy_loss = -dist.log_prob(action_step).mean()
            loss = policy_loss - self.args.entropy_coef * entropy
            action_pred = dist.probs.max(1, keepdim=True)[1]
            accuracy += float(
                (action_pred == action_step.unsqueeze(1)).sum()) / total_frames
            final_loss += loss
            final_entropy += entropy
            final_policy_loss += policy_loss
            indexes += 1

        final_loss /= self.args.recurrence

        if is_training:
            self.optimizer.zero_grad()
            final_loss.backward()
            self.optimizer.step()

        log = {}
        log["entropy"] = float(final_entropy / self.args.recurrence)
        log["policy_loss"] = float(final_policy_loss / self.args.recurrence)
        log["accuracy"] = float(accuracy)

        return log

    def validate(self, episodes, verbose=True):
        # Seed needs to be reset for each validation, to ensure consistency
        utils.seed(self.args.val_seed)

        if verbose:
            logger.info("Validating the model")
        if getattr(self.args, 'multi_env', None):
            agent = utils.load_agent(self.env[0],
                                     model_name=self.args.model,
                                     argmax=True)
        else:
            agent = utils.load_agent(self.env,
                                     model_name=self.args.model,
                                     argmax=True)

        # Setting the agent model to the current model
        agent.model = self.acmodel

        agent.model.eval()
        logs = []

        for env_name in ([self.args.env]
                         if not getattr(self.args, 'multi_env', None) else
                         self.args.multi_env):
            logs += [
                batch_evaluate(agent, env_name, self.args.val_seed, episodes)
            ]
        agent.model.train()

        return logs

    def collect_returns(self):
        logs = self.validate(episodes=self.args.eval_episodes, verbose=False)
        mean_return = {
            tid: np.mean(log["return_per_episode"])
            for tid, log in enumerate(logs)
        }
        return mean_return

    def train(self,
              train_demos,
              writer,
              csv_writer,
              status_path,
              header,
              reset_status=False):
        # Load the status
        def initial_status():
            return {'i': 0, 'num_frames': 0, 'patience': 0}

        status = initial_status()
        if os.path.exists(status_path) and not reset_status:
            with open(status_path, 'r') as src:
                status = json.load(src)
        elif not os.path.exists(os.path.dirname(status_path)):
            # Ensure that the status directory exists
            os.makedirs(os.path.dirname(status_path))

        # If the batch size is larger than the number of demos, we need to lower the batch size
        if self.args.batch_size > len(train_demos):
            self.args.batch_size = len(train_demos)
            logger.info(
                "Batch size too high. Setting it to the number of train demos ({})"
                .format(len(train_demos)))

        # Model saved initially to avoid "Model not found Exception" during first validation step
        utils.save_model(self.acmodel, self.args.model)

        # best mean return to keep track of performance on validation set
        best_success_rate, patience, i = 0, 0, 0
        total_start_time = time.time()

        while status['i'] < getattr(self.args, 'epochs', int(1e9)):
            if 'patience' not in status:  # in case an RL-pretrained agent is being fine-tuned with IL
                status['patience'] = 0
            # Do not learn if using a pre-trained model that already lost patience
            if status['patience'] > self.args.patience:
                break
            if status['num_frames'] > self.args.frames:
                break

            status['i'] += 1
            i = status['i']
            update_start_time = time.time()

            # Learning rate scheduler
            self.scheduler.step()

            log = self.run_epoch_recurrence(train_demos, is_training=True)
            total_len = sum([len(item[3]) for item in train_demos])
            status['num_frames'] += total_len

            update_end_time = time.time()

            # Print logs
            if status['i'] % self.args.log_interval == 0:
                total_ellapsed_time = int(time.time() - total_start_time)

                fps = total_len / (update_end_time - update_start_time)
                duration = datetime.timedelta(seconds=total_ellapsed_time)

                for key in log:
                    log[key] = np.mean(log[key])

                train_data = [
                    status['i'], status['num_frames'], fps,
                    total_ellapsed_time, log["entropy"], log["policy_loss"],
                    log["accuracy"]
                ]

                logger.info(
                    "U {} | F {:06} | FPS {:04.0f} | D {} | H {:.3f} | pL {: .3f} | A {: .3f}"
                    .format(*train_data))

                # Log the gathered data only when the validation metrics are not evaluated; they are logged
                # afterwards anyway when status['i'] % self.args.val_interval == 0
                if status['i'] % self.args.val_interval != 0:
                    # instantiate a validation_log with empty strings when no validation is done
                    validation_data = [''] * len(
                        [key for key in header if 'valid' in key])
                    assert len(header) == len(train_data + validation_data)
                    if self.args.tb:
                        for key, value in zip(header, train_data):
                            writer.add_scalar(key, float(value),
                                              status['num_frames'])
                    csv_writer.writerow(train_data + validation_data)

            if status['i'] % self.args.val_interval == 0:

                valid_log = self.validate(self.args.val_episodes)
                mean_return = [
                    np.mean(log['return_per_episode']) for log in valid_log
                ]
                success_rate = [
                    np.mean(
                        [1 if r > 0 else 0 for r in log['return_per_episode']])
                    for log in valid_log
                ]

                val_log = self.run_epoch_recurrence(self.val_demos)
                validation_accuracy = np.mean(val_log["accuracy"])

                if status['i'] % self.args.log_interval == 0:
                    validation_data = [validation_accuracy] + mean_return + success_rate
                    logger.info(("Validation: A {: .3f} " +
                                 ("| R {: .3f} " * len(mean_return) +
                                  "| S {: .3f} " * len(success_rate))).format(
                                      *validation_data))

                    assert len(header) == len(train_data + validation_data)
                    if self.args.tb:
                        for key, value in zip(header,
                                              train_data + validation_data):
                            writer.add_scalar(key, float(value),
                                              status['num_frames'])
                    csv_writer.writerow(train_data + validation_data)

                # In case of a multi-env, the update condition would be "better mean success rate" !
                if np.mean(success_rate) > best_success_rate:
                    best_success_rate = np.mean(success_rate)
                    status['patience'] = 0
                    with open(status_path, 'w') as dst:
                        json.dump(status, dst)
                    # Saving the model
                    logger.info("Saving best model")

                    if torch.cuda.is_available():
                        self.acmodel.cpu()
                    utils.save_model(self.acmodel, self.args.model + "_best")
                    self.obss_preprocessor.vocab.save(
                        utils.get_vocab_path(self.args.model + "_best"))
                    if torch.cuda.is_available():
                        self.acmodel.cuda()
                else:
                    status['patience'] += 1
                    logger.info(
                        "Losing patience, new value={}, limit={}".format(
                            status['patience'], self.args.patience))

                if torch.cuda.is_available():
                    self.acmodel.cpu()
                utils.save_model(self.acmodel, self.args.model)
                if torch.cuda.is_available():
                    self.acmodel.cuda()
                with open(status_path, 'w') as dst:
                    json.dump(status, dst)
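A hypothetical way of driving the train method above, assuming il is an ImitationLearning instance whose args also carries the fields the loop reads (batch_size, recurrence, entropy_coef, patience, frames, epochs, log_interval, val_interval, val_seed, tb); the header names and file paths are made up for illustration, chosen so that len(header) equals the seven training columns plus the three validation columns written for a single environment:

import csv
import os

header = (["update", "frames", "FPS", "duration",
           "entropy", "policy_loss", "train_accuracy"]
          + ["validation_accuracy", "validation_return", "validation_success_rate"])

il.args.tb = False                                   # no TensorBoard writer in this sketch
status_path = os.path.join('status', il.args.model, 'status.json')
with open('log.csv', 'w', buffering=1, newline='') as csv_file:
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(header)
    il.train(il.train_demos, writer=None, csv_writer=csv_writer,
             status_path=status_path, header=header)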
Example #5
        load_vocab_from=args.pretrained_model)

# Define actor-critic model
acmodel = utils.load_model(args.model, raise_not_found=False)
if acmodel is None:
    if args.pretrained_model:
        acmodel = utils.load_model(args.pretrained_model, raise_not_found=True)
    elif 'gie' in args.instr_arch:
        acmodel = ACModel(obss_preprocessor.obs_space,
                          envs[0].action_space,
                          args.image_dim,
                          args.memory_dim,
                          args.instr_dim,
                          not args.no_instr,
                          args.instr_arch,
                          not args.no_mem,
                          args.arch,
                          gie_pretrained_emb=args.gie_pretrained_emb,
                          gie_freeze_emb=args.gie_freeze_emb,
                          gie_aggr_method=args.gie_aggr_method,
                          gie_message_rounds=args.gie_message_rounds,
                          gie_two_layers=args.gie_two_layers,
                          gie_heads=args.gie_heads)
    else:
        acmodel = ACModel(obss_preprocessor.obs_space, envs[0].action_space,
                          args.image_dim, args.memory_dim, args.instr_dim,
                          not args.no_instr, args.instr_arch, not args.no_mem,
                          args.arch)

obss_preprocessor.vocab.save()
utils.save_model(acmodel, args.model)
Example #6
def main(exp, argv):
    os.environ["BABYAI_STORAGE"] = exp.results_directory()

    # Parse arguments
    parser = ArgumentParser()
    parser.add_argument("--algo",
                        default='ppo',
                        help="algorithm to use (default: ppo)")
    parser.add_argument("--discount",
                        type=float,
                        default=0.99,
                        help="discount factor (default: 0.99)")
    parser.add_argument("--reward-scale",
                        type=float,
                        default=20.,
                        help="Reward scale multiplier")
    parser.add_argument(
        "--gae-lambda",
        type=float,
        default=0.99,
        help="lambda coefficient in GAE formula (default: 0.99, 1 means no gae)"
    )
    parser.add_argument("--value-loss-coef",
                        type=float,
                        default=0.5,
                        help="value loss term coefficient (default: 0.5)")
    parser.add_argument("--max-grad-norm",
                        type=float,
                        default=0.5,
                        help="maximum norm of gradient (default: 0.5)")
    parser.add_argument("--clip-eps",
                        type=float,
                        default=0.2,
                        help="clipping epsilon for PPO (default: 0.2)")
    parser.add_argument("--ppo-epochs",
                        type=int,
                        default=4,
                        help="number of epochs for PPO (default: 4)")
    parser.add_argument(
        "--save-interval",
        type=int,
        default=50,
        help=
        "number of updates between two saves (default: 50, 0 means no saving)")
    parser.add_argument("--workers",
                        type=int,
                        default=8,
                        help="number of workers for PyTorch (default: 8)")
    parser.add_argument("--max-count",
                        type=int,
                        default=1000,
                        help="maximum number of frames to run for")
    parser.add_argument("--sample_duration",
                        type=float,
                        default=0.5,
                        help="sampling duration")
    parser.add_argument("--cuda",
                        action="store_true",
                        default=False,
                        help="whether to use cuda")
    args = parser.parse_args(argv)

    utils.seed(args.seed)

    torch_settings = init_torch(
        seed=args.seed,
        cuda=args.cuda,
        workers=args.workers,
    )

    # Generate environments
    envs = []
    for i in range(args.procs):
        env = gym.make(args.env)
        env.seed(100 * args.seed + i)
        envs.append(env)

    # Define model name
    suffix = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    instr = args.instr_arch if args.instr_arch else "noinstr"
    mem = "mem" if not args.no_mem else "nomem"
    model_name_parts = {
        'env': args.env,
        'algo': args.algo,
        'arch': args.arch,
        'instr': instr,
        'mem': mem,
        'seed': args.seed,
        'info': '',
        'coef': '',
        'suffix': suffix
    }
    default_model_name = "{env}_{algo}_{arch}_{instr}_{mem}_seed{seed}{info}{coef}_{suffix}".format(
        **model_name_parts)
    if args.pretrained_model:
        default_model_name = args.pretrained_model + '_pretrained_' + default_model_name
    args.model = args.model.format(
        **model_name_parts) if args.model else default_model_name

    utils.configure_logging(args.model)
    logger = logging.getLogger(__name__)

    # Define obss preprocessor
    if 'emb' in args.arch:
        obss_preprocessor = utils.IntObssPreprocessor(
            args.model, envs[0].observation_space, args.pretrained_model)
    else:
        obss_preprocessor = utils.ObssPreprocessor(args.model,
                                                   envs[0].observation_space,
                                                   args.pretrained_model)

    # Define actor-critic model
    # acmodel = utils.load_model(args.model, raise_not_found=False)
    acmodel = None
    if acmodel is None:
        if args.pretrained_model:
            acmodel = utils.load_model(args.pretrained_model,
                                       raise_not_found=True)
        else:
            acmodel = ACModel(obss_preprocessor.obs_space,
                              envs[0].action_space, args.image_dim,
                              args.memory_dim, args.instr_dim,
                              not args.no_instr, args.instr_arch,
                              not args.no_mem, args.arch)

    obss_preprocessor.vocab.save()
    # utils.save_model(acmodel, args.model)

    if torch_settings.cuda:
        acmodel.cuda()

    # Define actor-critic algo

    reshape_reward = lambda _0, _1, reward, _2: args.reward_scale * reward
    if args.algo == "ppo":
        algo = babyai.rl.PPOAlgo(
            envs, acmodel, args.frames_per_proc, args.discount, args.lr,
            args.beta1, args.beta2, args.gae_lambda, args.entropy_coef,
            args.value_loss_coef, args.max_grad_norm, args.recurrence,
            args.optim_eps, args.clip_eps, args.ppo_epochs, args.batch_size,
            obss_preprocessor, reshape_reward)
    else:
        raise ValueError("Incorrect algorithm name: {}".format(args.algo))

    # When using extra binary information, more tensors (model params) are initialized compared to when we don't use that.
    # Thus, there starts to be a difference in the random state. If we want to avoid it, in order to make sure that
    # the results of supervised-loss-coef=0. and extra-binary-info=0 match, we need to reseed here.

    utils.seed(args.seed)

    # Restore training status

    status_path = os.path.join(utils.get_log_dir(args.model), 'status.json')
    if os.path.exists(status_path):
        with open(status_path, 'r') as src:
            status = json.load(src)
    else:
        status = {'i': 0, 'num_episodes': 0, 'num_frames': 0}

    # # Define logger and Tensorboard writer and CSV writer

    # header = (["update", "episodes", "frames", "FPS", "duration"]
    #         + ["return_" + stat for stat in ['mean', 'std', 'min', 'max']]
    #         + ["success_rate"]
    #         + ["num_frames_" + stat for stat in ['mean', 'std', 'min', 'max']]
    #         + ["entropy", "value", "policy_loss", "value_loss", "loss", "grad_norm"])
    # if args.tb:
    #     from tensorboardX import SummaryWriter

    #     writer = SummaryWriter(utils.get_log_dir(args.model))
    # csv_path = os.path.join(utils.get_log_dir(args.model), 'log.csv')
    # first_created = not os.path.exists(csv_path)
    # # we don't buffer data going in the csv log, cause we assume
    # # that one update will take much longer that one write to the log
    # csv_writer = csv.writer(open(csv_path, 'a', 1))
    # if first_created:
    #     csv_writer.writerow(header)

    # Log code state, command, availability of CUDA and model

    babyai_code = list(babyai.__path__)[0]
    try:
        last_commit = subprocess.check_output(
            'cd {}; git log -n1'.format(babyai_code),
            shell=True).decode('utf-8')
        logger.info('LAST COMMIT INFO:')
        logger.info(last_commit)
    except subprocess.CalledProcessError:
        logger.info('Could not figure out the last commit')
    try:
        diff = subprocess.check_output('cd {}; git diff'.format(babyai_code),
                                       shell=True).decode('utf-8')
        if diff:
            logger.info('GIT DIFF:')
            logger.info(diff)
    except subprocess.CalledProcessError:
        logger.info('Could not figure out the git diff')
    logger.info('COMMAND LINE ARGS:')
    logger.info(args)
    logger.info("CUDA available: {}".format(torch.cuda.is_available()))
    logger.info(acmodel)

    # Train model

    total_start_time = time.time()
    best_success_rate = 0
    best_mean_return = 0
    test_env_name = args.env

    wrapper = iteration_wrapper(
        exp,
        sync=torch_settings.sync,
        max_count=args.max_count,
        sample_duration=args.sample_duration,
    )

    # while status['num_frames'] < args.frames:
    while True:
        with wrapper() as it:
            # Update parameters
            if wrapper.done():
                break

            update_start_time = time.time()
            logs = algo.update_parameters()
            update_end_time = time.time()

            it.set_count(logs["num_frames"])
            it.log(loss=logs["loss"])

Example #7
                                                         instr,
                                                         mem,
                                                         args.seed,
                                                         fakerewardtxt,
                                                         suffix)
model_name = args.model or default_model_name

# Define obss preprocessor

obss_preprocessor = utils.ObssPreprocessor(model_name, envs[0].observation_space)

# Define actor-critic model

acmodel = utils.load_model(model_name, raise_not_found=False)
if acmodel is None:
    acmodel = ACModel(obss_preprocessor.obs_space, envs[0].action_space,
                      args.instr_model, not args.no_mem, args.arch)
if torch.cuda.is_available():
    acmodel.cuda()

# Define actor-critic algo

if args.algo == "a2c":
    algo = torch_rl.A2CAlgo(envs, acmodel, args.frames_per_proc, args.discount, args.lr, args.gae_tau,
                            args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
                            args.optim_alpha, args.optim_eps, obss_preprocessor, utils.reshape_reward)
elif args.algo == "ppo":
    algo = torch_rl.PPOAlgo(envs, acmodel, args.frames_per_proc, args.discount, args.lr, args.gae_tau,
                            args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
                            args.optim_eps, args.clip_eps, args.epochs, args.batch_size, obss_preprocessor,
                            utils.reshape_reward)
else:
Example #8
                                               args.pretrained_model)

# Define actor-critic model
acmodel = utils.load_model(args.model, raise_not_found=False)
if acmodel is None:
    if args.pretrained_model:
        acmodel = utils.load_model(args.pretrained_model, raise_not_found=True)
    else:
        advice_start_index = 160
        advice_end_index = advice_start_index + envs[0].action_space.n + 1
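        # e.g. with a 7-action space this spans indices 160..167: one slot per action
        # plus one extra slot (presumably a "no advice" value).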
        acmodel = ACModel(obss_preprocessor.obs_space,
                          envs[0].action_space,
                          envs[0],
                          args.image_dim,
                          args.memory_dim,
                          args.instr_dim,
                          not args.no_instr,
                          args.instr_arch,
                          not args.no_mem,
                          advice_dim=128,
                          advice_start_index=advice_start_index,
                          advice_end_index=advice_end_index)

obss_preprocessor.vocab.save()
utils.save_model(acmodel, args.model)

if torch.cuda.is_available():
    acmodel.cuda()

# Define actor-critic algo

reshape_reward = lambda _0, _1, reward, _2: args.reward_scale * reward
Example #9
class MetaLearner(nn.Module):
    def __init__(self, args):
        """

        :param args:
        """
        super(MetaLearner, self).__init__()

        self.update_lr = args.update_lr
        self.meta_lr = args.meta_lr
        self.task_num = args.task_num
        self.args = args

        utils.seed(self.args.seed)

        self.env = gym.make(self.args.env)

        demos_path = utils.get_demos_path(args.demos,
                                          args.env,
                                          args.demos_origin,
                                          valid=False)
        demos_path_valid = utils.get_demos_path(args.demos,
                                                args.env,
                                                args.demos_origin,
                                                valid=True)

        logger.info('loading demos')
        self.train_demos = utils.load_demos(demos_path)
        logger.info('loaded demos')
        # if args.episodes:
        #     if args.episodes > len(self.train_demos):
        #         raise ValueError("there are only {} train demos".format(len(self.train_demos)))
        # self.train_demos = self.train_demos[:args.episodes]

        self.val_demos = utils.load_demos(demos_path_valid)
        # if args.val_episodes > len(self.val_demos):
        #     logger.info('Using all the available {} demos to evaluate valid. accuracy'.format(len(self.val_demos)))
        self.val_demos = self.val_demos[:self.args.val_episodes]

        observation_space = self.env.observation_space
        action_space = self.env.action_space

        print(args.model)
        self.obss_preprocessor = utils.ObssPreprocessor(
            args.model, observation_space,
            getattr(self.args, 'pretrained_model', None))

        # Define actor-critic model
        # self.net = utils.load_model(args.model, raise_not_found=False)
        # if self.net is None:
        #     if getattr(self.args, 'pretrained_model', None):
        #         self.net = utils.load_model(args.pretrained_model, raise_not_found=True)
        #     else:
        self.net = ACModel(self.obss_preprocessor.obs_space, action_space,
                           args.image_dim, args.memory_dim, args.instr_dim,
                           not self.args.no_instr, self.args.instr_arch,
                           not self.args.no_mem, self.args.arch)
        self.obss_preprocessor.vocab.save()
        # utils.save_model(self.net, args.model)
        self.fast_net = copy.deepcopy(self.net)
        self.net.train()
        self.fast_net.train()

        if torch.cuda.is_available():
            self.net.cuda()
            self.fast_net.cuda()

        self.optimizer = torch.optim.SGD(self.fast_net.parameters(),
                                         lr=self.args.update_lr)
        # self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.9)

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)
Example #10
class MetaLearner(nn.Module):
    """
    Meta Learner
    """
    def __init__(self, args):
        """

        :param args:
        """
        super(MetaLearner, self).__init__()

        self.update_lr = args.update_lr
        self.meta_lr = args.meta_lr
        self.task_num = args.task_num
        self.args = args

        utils.seed(self.args.seed)

        self.env = gym.make(self.args.env)

        demos_path = utils.get_demos_path(args.demos,
                                          args.env,
                                          args.demos_origin,
                                          valid=False)
        demos_path_valid = utils.get_demos_path(args.demos,
                                                args.env,
                                                args.demos_origin,
                                                valid=True)

        logger.info('loading demos')
        self.train_demos = utils.load_demos(demos_path)
        logger.info('loaded demos')
        # if args.episodes:
        #     if args.episodes > len(self.train_demos):
        #         raise ValueError("there are only {} train demos".format(len(self.train_demos)))
        # self.train_demos = self.train_demos[:args.episodes]

        self.val_demos = utils.load_demos(demos_path_valid)
        # if args.val_episodes > len(self.val_demos):
        #     logger.info('Using all the available {} demos to evaluate valid. accuracy'.format(len(self.val_demos)))
        self.val_demos = self.val_demos[:self.args.val_episodes]

        observation_space = self.env.observation_space
        action_space = self.env.action_space

        print(args.model)
        self.obss_preprocessor = utils.ObssPreprocessor(
            args.model, observation_space,
            getattr(self.args, 'pretrained_model', None))

        # Define actor-critic model
        # self.net = utils.load_model(args.model, raise_not_found=False)
        # if self.net is None:
        #     if getattr(self.args, 'pretrained_model', None):
        #         self.net = utils.load_model(args.pretrained_model, raise_not_found=True)
        #     else:
        self.net = ACModel(self.obss_preprocessor.obs_space, action_space,
                           args.image_dim, args.memory_dim, args.instr_dim,
                           not self.args.no_instr, self.args.instr_arch,
                           not self.args.no_mem, self.args.arch)
        self.obss_preprocessor.vocab.save()
        # utils.save_model(self.net, args.model)
        self.fast_net = copy.deepcopy(self.net)
        self.net.train()
        self.fast_net.train()

        if torch.cuda.is_available():
            self.net.cuda()
            self.fast_net.cuda()

        self.optimizer = torch.optim.SGD(self.fast_net.parameters(),
                                         lr=self.args.update_lr)
        # self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.9)

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)

    def clip_grad_by_norm_(self, grad, max_norm):
        """
        in-place gradient clipping.
        :param grad: list of gradients
        :param max_norm: maximum norm allowable
        :return:
        """

        total_norm = 0
        counter = 0
        for g in grad:
            param_norm = g.data.norm(2)
            total_norm += param_norm.item()**2
            counter += 1
        total_norm = total_norm**(1. / 2)

        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for g in grad:
                g.data.mul_(clip_coef)

        return total_norm / counter
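
    # Illustrative usage (names as defined in this class):
    #   self.clip_grad_by_norm_([p.grad for p in self.fast_net.parameters()], 10.0)
    # rescales each gradient in place when their combined 2-norm exceeds 10 and
    # returns the pre-clipping norm averaged over the number of gradient tensors.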

    def starting_indexes(self, num_frames):
        if num_frames % self.args.recurrence == 0:
            return np.arange(0, num_frames, self.args.recurrence)
        else:
            return np.arange(0, num_frames, self.args.recurrence)[:-1]
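
    # Example: with args.recurrence == 4, starting_indexes(12) -> [0, 4, 8],
    # while starting_indexes(10) -> [0, 4]; the trailing partial segment of
    # fewer than `recurrence` frames is dropped.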

    def forward_batch(self, batch, task, net='fast', is_training=True):
        if net == 'fast':
            acmodel = self.fast_net
        else:
            acmodel = self.net

        batch = utils.demos.induce_grammar(batch, task)

        batch = utils.demos.transform_demos(batch)
        batch.sort(key=len, reverse=True)
        # Constructing flat batch and indices pointing to start of each demonstration
        flat_batch = []
        inds = [0]

        for demo in batch:
            flat_batch += demo
            inds.append(inds[-1] + len(demo))

        flat_batch = np.array(flat_batch)
        inds = inds[:-1]
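        # Example: for two demos of lengths 3 and 2, flat_batch holds the 5
        # frames back to back and inds == [0, 3], i.e. the start offset of
        # each demonstration (the final cumulative length was just dropped).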
        num_frames = len(flat_batch)

        mask = np.ones([len(flat_batch)], dtype=np.float64)
        mask[inds] = 0
        mask = torch.tensor(mask, device=self.device,
                            dtype=torch.float).unsqueeze(1)

        # Observations, true actions and done flags for each stored demonstration
        obss, action_true, done = flat_batch[:, 0], flat_batch[:, 1], flat_batch[:, 2]
        action_true = torch.tensor([action for action in action_true],
                                   device=self.device,
                                   dtype=torch.long)

        # Memory to be stored
        memories = torch.zeros([len(flat_batch), acmodel.memory_size],
                               device=self.device)
        episode_ids = np.zeros(len(flat_batch), dtype=np.int64)
        memory = torch.zeros([len(batch), acmodel.memory_size],
                             device=self.device)

        preprocessed_first_obs = self.obss_preprocessor(obss[inds],
                                                        device=self.device)
        instr_embedding = acmodel._get_instr_embedding(
            preprocessed_first_obs.instr)

        # Loop terminates when every observation in the flat_batch has been handled
        while True:
            # taking observations and done located at inds
            obs = obss[inds]
            done_step = done[inds]
            preprocessed_obs = self.obss_preprocessor(obs, device=self.device)
            with torch.no_grad():
                # take the memory up to len(inds); the demos beyond that have already finished
                new_memory = acmodel(preprocessed_obs, memory[:len(inds), :],
                                     instr_embedding[:len(inds)])['memory']

            memories[inds, :] = memory[:len(inds), :]
            memory[:len(inds), :] = new_memory
            episode_ids[inds] = range(len(inds))

            # Update inds by removing the indices whose demonstrations have finished
            inds = inds[:len(inds) - sum(done_step)]
            if len(inds) == 0:
                break

            # Incrementing the remaining indices
            inds = [index + 1 for index in inds]
        # Here, the actual backprop through args.recurrence steps happens
        final_loss = 0
        final_entropy, final_policy_loss, final_value_loss = 0, 0, 0

        indexes = self.starting_indexes(num_frames)
        memory = memories[indexes]
        accuracy = 0
        total_frames = len(indexes) * self.args.recurrence
        for _ in range(self.args.recurrence):
            obs = obss[indexes]
            preprocessed_obs = self.obss_preprocessor(obs, device=self.device)
            action_step = action_true[indexes]
            mask_step = mask[indexes]
            model_results = acmodel(preprocessed_obs, memory * mask_step,
                                    instr_embedding[episode_ids[indexes]])
            dist = model_results['dist']
            memory = model_results['memory']

            entropy = dist.entropy().mean()
            policy_loss = -dist.log_prob(action_step).mean()
            loss = policy_loss - self.args.entropy_coef * entropy
            action_pred = dist.probs.max(1, keepdim=True)[1]
            accuracy += float(
                (action_pred == action_step.unsqueeze(1)).sum()) / total_frames
            final_loss += loss
            final_entropy += entropy
            final_policy_loss += policy_loss
            indexes += 1

        final_loss /= self.args.recurrence

        # if is_training:
        #     self.optimizer.zero_grad()
        #     final_loss.backward()
        #     self.optimizer.step()

        log = {}
        log["entropy"] = float(final_entropy / self.args.recurrence)
        log["policy_loss"] = float(final_policy_loss / self.args.recurrence)
        log["accuracy"] = float(accuracy)
        return final_loss, log

    # def forward(self, x_spt, y_spt, x_qry, y_qry):
    def forward(self, demo):
        task_num = self.args.task_num

        losses = []  # losses[i] is the loss for task i
        logs = []
        grads = []
        self.optimizer.zero_grad()

        for i in range(task_num):

            # start from a fresh copy of the base net
            self.fast_net = copy.deepcopy(self.net)
            for p in self.fast_net.parameters():
                p.retain_grad()
            self.fast_net.zero_grad()

            # compute the fast net's loss and gradients on instances of task i
            loss_task, log = self.forward_batch(demo, i, 'fast')
            # grad = torch.autograd.grad(loss_task, self.fast_net.parameters(),allow_unused = True)
            loss_task.backward()
            grad = [x.grad for x in self.fast_net.parameters()]
            # print (grad)
            grads.append(grad)
            # self.optimizer.step()
            # loss_task, log = self.forward_batch(demo, i, 'fast')
            # losses.append(loss_task)
            logs.append(log)

        self.meta_update(demo, grads)
        # end of all tasks
        # sum over all losses on query set across all tasks
        # loss_q = sum(losses) / task_num
        # # optimize theta parameters
        # self.meta_optim.zero_grad()

        # grad = torch.autograd.grad(loss_q, self.net.parameters(), allow_unused=True)
        # print (grad)
        # # loss_q.backward()
        # for g,p in zip(grad,self.net.parameters()):
        #     p.grad = g
        # # print('meta update')
        # # for p in self.net.parameters()[:5]:
        # # (torch.norm(p).item())
        # self.meta_optim.step()

        return logs
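
    # In short: forward() adapts a fresh copy of self.net (fast_net) once per
    # task and collects each task's gradients; meta_update() below folds their
    # sum back into self.net through the Adam meta-optimizer.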

    def meta_update(self, demo, grads):
        print('\n Meta update \n')
        # We use a dummy forward / backward pass to get the correct grads into self.net
        loss, _ = self.forward_batch(demo, 0, 'net')
        gradients = []
        for p in self.net.parameters():
            # zeros on the same device and dtype as the corresponding parameter
            gradients.append(torch.zeros_like(p.data))
        # Sum the per-parameter gradients across the per-task gradient lists
        for i in range(len(grads[0])):
            for grad in grads:
                if grad[i] is not None:
                    gradients[i] = gradients[i] + grad[i]
        # gradients = [sum(grad[i][0] for grad in grads) for i in range(len(grads[0]))]
        # gradients = {k: sum(d[k] for d in ls) for k in ls[0].keys()}
        # Register a hook on each parameter in the net that replaces the current dummy grad
        # with our grads accumulated across the meta-batch
        hooks = []
        for i, p in enumerate(self.net.parameters()):

            def get_closure():
                it = i

                def replace_grad(grad):
                    # swap in the gradient accumulated across the meta-batch,
                    # which is already a tensor on the parameter's device
                    return gradients[it]

                return replace_grad

            try:
                hooks.append(p.register_hook(get_closure()))
            except RuntimeError:
                print(p)
        # Compute grads for current step, replace with summed gradients as defined by hook
        self.meta_optim.zero_grad()
        loss.backward()
        # Update the net parameters with the accumulated gradient according to optimizer
        self.meta_optim.step()
        # Remove the hooks before next training phase
        for h in hooks:
            h.remove()

    def validate(self, demo):
        val_task_num = self.args.task_num

        losses = []  # losses[i] is the loss for task i
        logs = []
        val_logs = []
        for i in range(19):
            self.fast_net = copy.deepcopy(self.net)
            self.fast_net.zero_grad()

            # adapt the fast net with a few gradient steps on instances of task 119 - i
            for k in range(5):
                loss_task, log = self.forward_batch(demo[32 * k:32 * k + 32],
                                                    119 - i, 'fast')

                self.optimizer.zero_grad()
                loss_task.backward()
                self.optimizer.step()
                # loss_task, log = self.forward_batch(demo, i, 'fast')
                # losses.append(loss_task)
                logs.append(log)
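            # evaluate the adapted fast_net on the last adaptation chunk
            # (no optimizer step here)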
            loss_task, log = self.forward_batch(demo[32 * k:32 * k + 32],
                                                119 - i, 'fast')
            val_logs.append(log)

        return val_logs
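

# A minimal, hypothetical driver for the MetaLearner above (not part of the
# original listing). It assumes the same module-level names the class already
# relies on (gym, torch, utils, logger, ...) and that `args` carries the fields
# referenced in __init__; the `epochs` and `batch_size` fields used below are
# assumptions added for illustration only.
import random


def train_meta(args):
    learner = MetaLearner(args)
    batch_size = getattr(args, 'batch_size', 32)      # assumed field
    for epoch in range(getattr(args, 'epochs', 10)):  # assumed field
        random.shuffle(learner.train_demos)
        for start in range(0, len(learner.train_demos), batch_size):
            demo = learner.train_demos[start:start + batch_size]
            # per-task inner-loop adaptation followed by one meta update
            logs = learner.forward(demo)
            mean_acc = sum(log['accuracy'] for log in logs) / len(logs)
            logger.info('epoch {} batch {}: accuracy {:.3f}'.format(
                epoch, start // batch_size, mean_acc))
        # validate() re-adapts a copy of the net on held-out tasks
        val_logs = learner.validate(learner.val_demos)
        val_acc = sum(log['accuracy'] for log in val_logs) / len(val_logs)
        logger.info('epoch {}: validation accuracy {:.3f}'.format(epoch, val_acc))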
Example #11
0
# Define obss preprocessor
if 'emb' in args.arch:
    obss_preprocessor = utils.IntObssPreprocessor(args.model, envs0[0].observation_space, args.pretrained_model)
else:
    obss_preprocessor = utils.ObssPreprocessor(args.model, envs0[0].observation_space, args.pretrained_model)

# Define actor-critic model
acmodel0 = utils.load_model(args.model, 0, raise_not_found=False)
acmodel1 = utils.load_model(args.model, 1, raise_not_found=False)
if acmodel0 is None:
    if args.pretrained_model:
        acmodel0 = utils.load_model(args.pretrained_model, 0, raise_not_found=True)
    else:
        #torch.manual_seed(args.seed)
        acmodel0 = ACModel(obss_preprocessor.obs_space, envs0[0].action_space,
                           args.image_dim, args.memory_dim, args.instr_dim, args.enc_dim, args.dec_dim,
                           not args.no_instr, args.instr_arch, not args.no_mem, args.arch,
                           args.len_message, args.num_symbols)
if acmodel1 is None:
    if args.pretrained_model:
        acmodel1 = utils.load_model(args.pretrained_model, 1, raise_not_found=True)
    else:
        #torch.manual_seed(args.seed)
        acmodel1 = ACModel(obss_preprocessor.obs_space, envs1[0].action_space,
                           args.image_dim, args.memory_dim, args.instr_dim, args.enc_dim, args.dec_dim,
                           not args.no_instr, args.instr_arch, not args.no_mem, args.arch,
                           args.len_message, args.num_symbols)

obss_preprocessor.vocab.save()
utils.save_model(acmodel0, args.model, 0)
utils.save_model(acmodel1, args.model, 1)