Example #1
0
File: adam.py Project: yulkang/vel
    def instantiate(self, model: Model) -> torch.optim.Adam:
        """ Create a ``torch.optim.Adam`` optimizer for ``model``.

        If ``self.layer_groups`` is set, the model's layer groups are converted into parameter
        groups, and ``self.lr`` / ``self.weight_decay`` may each be a scalar or a sequence holding
        one value per group (the first element of a sequence also serves as the optimizer-wide
        default). Otherwise only parameters with ``requires_grad`` set are optimized, using the
        scalar hyperparameters unchanged.
        """
        # ``collections.Sequence`` was a deprecated alias removed in Python 3.10; the ABC lives
        # in ``collections.abc``.
        from collections.abc import Sequence

        def _per_group(value) -> bool:
            # str/bytes are Sequences too, but such a value denotes a single hyperparameter
            # (it is handled by the scalar branch, e.g. ``float(self.lr)``), not one per group.
            return isinstance(value, Sequence) and not isinstance(value, (str, bytes))

        if self.layer_groups:
            # presumably a list with one dict per layer group - TODO confirm against mu helper
            parameters = mu.to_parameter_groups(model.get_layer_groups())

            if _per_group(self.lr):
                for idx, lr in enumerate(self.lr):
                    parameters[idx]['lr'] = lr

                default_lr = self.lr[0]
            else:
                default_lr = float(self.lr)

            if _per_group(self.weight_decay):
                for idx, weight_decay in enumerate(self.weight_decay):
                    parameters[idx]['weight_decay'] = weight_decay

                default_weight_decay = self.weight_decay[0]
            else:
                default_weight_decay = self.weight_decay

            return torch.optim.Adam(parameters,
                                    lr=default_lr,
                                    betas=self.betas,
                                    eps=self.eps,
                                    weight_decay=default_weight_decay,
                                    amsgrad=self.amsgrad)
        else:
            # Only optimize parameters that actually receive gradients
            parameters = filter(lambda p: p.requires_grad, model.parameters())

            return torch.optim.Adam(parameters,
                                    lr=self.lr,
                                    betas=self.betas,
                                    eps=self.eps,
                                    weight_decay=self.weight_decay,
                                    amsgrad=self.amsgrad)
Example #2
0
    def checkpoint(self, epoch_info: EpochInfo, model: Model):
        """ Persist model weights and training state once an epoch has finished.

        Saves the latest weights and hidden (training) state for the current epoch. The
        checkpoint strategy decides whether the previous epoch's files are removed and whether
        the "best" checkpoint is replaced. Finally, the epoch result is handed to the backend.
        """
        epoch_idx = epoch_info.global_epoch_idx

        self.clean(epoch_idx - 1)
        self._make_sure_dir_exists()

        strategy = self.checkpoint_strategy

        # Latest model weights for this epoch
        torch.save(model.state_dict(), self.checkpoint_filename(epoch_idx))

        # The strategy is given the state dict before it is written; presumably it adds its own
        # entries to it - TODO confirm in write_state_dict
        training_state = epoch_info.state_dict()
        strategy.write_state_dict(training_state)
        torch.save(training_state, self.checkpoint_hidden_filename(epoch_idx))

        drop_previous = epoch_idx > 1 and strategy.should_delete_previous_checkpoint(epoch_idx)
        if drop_previous:
            previous_idx = epoch_idx - 1
            # Weights file first, then the hidden-state file
            for filename_for in (self.checkpoint_filename, self.checkpoint_hidden_filename):
                os.remove(filename_for(previous_idx))

        if strategy.should_store_best_checkpoint(epoch_idx, epoch_info.result):
            previous_best_idx = strategy.current_best_checkpoint_idx

            if previous_best_idx is not None:
                os.remove(self.checkpoint_best_filename(previous_best_idx))

            torch.save(model.state_dict(), self.checkpoint_best_filename(epoch_idx))
            strategy.store_best_checkpoint_idx(epoch_idx)

        self.backend.store(epoch_info.result)
Example #3
0
    def rollout(self, batch_info: BatchInfo, model: Model, number_of_steps: int) -> Rollout:
        """ Step the environment ``number_of_steps`` times with ``model`` and gather the transitions.

        Every tensor returned by ``model.step`` is accumulated on the CPU, together with
        observations, dones and rewards. For recurrent models the hidden state persists across
        calls: it is lazily created with ``model.zero_state`` and zeroed for environments that
        report done. The returned trajectories also carry the hidden state at the start of the
        rollout and the value estimate for the final observation.
        """
        accumulator = TensorAccumulator()
        infos_per_step = []  # the environment's info object for each step

        if self.hidden_state is None and model.is_recurrent:
            batch_size = self.last_observation.size(0)
            self.hidden_state = model.zero_state(batch_size).to(self.device)

        # Hidden state at rollout start; it is returned so training can reuse it
        rollout_start_state = self.hidden_state

        for _ in range(number_of_steps):
            observation_on_device = self.last_observation.to(self.device)

            if model.is_recurrent:
                step = model.step(observation_on_device, state=self.hidden_state)
                self.hidden_state = step['state']
            else:
                step = model.step(observation_on_device)

            for key, value in step.items():
                accumulator.add(key, value.cpu())

            accumulator.add('observations', self.last_observation)

            chosen_actions = step['actions'].detach().cpu().numpy()
            next_obs, rewards, dones, infos = self.environment.step(chosen_actions)

            # A done flag means the episode ended AND the returned observation already belongs
            # to the following episode
            dones_tensor = torch.from_numpy(dones.astype(np.float32)).clone()
            accumulator.add('dones', dones_tensor)

            self.last_observation = torch.from_numpy(next_obs).clone()

            if model.is_recurrent:
                # 0 for finished environments, 1 otherwise: resets their hidden state
                keep_mask = (1.0 - dones_tensor.unsqueeze(-1)).to(self.device)
                self.hidden_state = self.hidden_state * keep_mask

            accumulator.add('rewards', torch.from_numpy(rewards.astype(np.float32)).clone())

            infos_per_step.append(infos)

        final_observation = self.last_observation.to(self.device)

        if model.is_recurrent:
            final_values = model.value(final_observation, state=self.hidden_state).cpu()
        else:
            final_values = model.value(final_observation).cpu()

        tensors = accumulator.result()
        observations = tensors['observations']

        return Trajectories(
            num_steps=observations.size(0),
            num_envs=observations.size(1),
            environment_information=infos_per_step,
            transition_tensors=tensors,
            rollout_tensors={
                'initial_hidden_state': None if rollout_start_state is None else rollout_start_state.cpu(),
                'final_values': final_values
            }
        )
Example #4
0
    def instantiate(self, model: Model) -> torch.optim.SGD:
        """ Create a ``torch.optim.SGD`` optimizer for ``model``.

        With ``self.layer_groups`` set, the model's layer groups are converted to parameter
        groups; otherwise only parameters that require gradients are optimized. All
        hyperparameters are taken from the instance attributes.
        """
        if not self.layer_groups:
            parameters = (param for param in model.parameters() if param.requires_grad)
        else:
            parameters = mu.to_parameter_groups(model.get_layer_groups())

        hyperparameters = dict(
            lr=self.lr,
            momentum=self.momentum,
            dampening=self.dampening,
            weight_decay=self.weight_decay,
            nesterov=self.nesterov,
        )

        return torch.optim.SGD(parameters, **hyperparameters)
Example #5
0
 def instantiate(self, model: Model) -> torch.optim.Adadelta:
     """ Create a ``torch.optim.Adadelta`` optimizer over the trainable parameters of ``model``.

     Only parameters with ``requires_grad`` set are handed to the optimizer.
     """
     trainable = (param for param in model.parameters() if param.requires_grad)

     return torch.optim.Adadelta(
         trainable,
         lr=self.lr,
         rho=self.rho,
         eps=self.eps,
         weight_decay=self.weight_decay,
     )
Example #6
0
 def instantiate(self, model: Model) -> RMSpropTF:
     """ Create an ``RMSpropTF`` optimizer over the trainable parameters of ``model``.

     Only parameters with ``requires_grad`` set are handed to the optimizer.
     """
     trainable = (param for param in model.parameters() if param.requires_grad)

     return RMSpropTF(
         trainable,
         lr=self.lr,
         alpha=self.alpha,
         eps=self.eps,
         weight_decay=self.weight_decay,
         momentum=self.momentum,
         centered=self.centered,
     )
    def __init__(self, device: torch.device, settings: BufferedOffPolicyIterationReinforcerSettings,
                 environment: VecEnv, model: Model, algo: AlgoBase, env_roller: ReplayEnvRollerBase):
        """ Store the collaborators of the buffered off-policy reinforcer.

        The model is moved to ``device`` via ``model.to`` and the result is kept as the
        trained model.
        """
        self.device = device
        self.settings, self.environment = settings, environment

        self._trained_model = model.to(self.device)
        self.algo = algo

        self.env_roller = env_roller
    def __init__(self, device: torch.device,
                 settings: OnPolicyIterationReinforcerSettings, model: Model,
                 algo: AlgoBase, env_roller: EnvRollerBase) -> None:
        """ Store the collaborators of the on-policy reinforcer.

        The model is moved to ``device`` via ``model.to`` and the result is kept as the
        trained model.
        """
        self.device, self.settings = device, settings

        self._trained_model = model.to(self.device)

        self.env_roller, self.algo = env_roller, algo