Example #1
    def sample(
            self,
            state: numpy.ndarray,
            *args,
            deterministic: bool = False,
            metric_writer: Writer = MockWriter(),
            **kwargs,
    ) -> Tuple[Any, ...]:
        """

@param state:
@param args:
@param deterministic:
@param metric_writer:
@param kwargs:
@return:
"""
        self._sample_i += 1
        self._sample_i_since_last_update += 1
        action = self._sample(
            state,
            *args,
            deterministic=deterministic,
            metric_writer=metric_writer,
            **kwargs,
        )

        if self._action_clipping.enabled:
            action = numpy.clip(action, self._action_clipping.low,
                                self._action_clipping.high)

        return action
Example #2
def test_invalid_val_type_scalars(tag, val, step):
    with pytest.raises(Exception):  # scalar() must reject invalid value types
        with MockWriter() as w:
            w.scalar(tag, val, step)
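
The tag, val and step arguments are presumably supplied by a pytest parametrisation that the listing omits. A minimal sketch of such a harness, assuming pytest and draugr's MockWriter; the invalid values are illustrative, not from the source:

import pytest

from draugr.writers import MockWriter


@pytest.mark.parametrize(
    "tag, val, step",
    [("tag", "not_a_number", 0), ("tag", None, 1), ("tag", object(), 2)],  # illustrative cases
)
def test_invalid_val_type_scalars_sketch(tag, val, step):
    with pytest.raises(Exception):  # MockWriter.scalar is expected to reject non-numeric values
        with MockWriter() as w:
            w.scalar(tag, val, step)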
Example #3
    def _update(self, metric_writer: Writer = MockWriter()) -> float:
        """

@param metric_writer:
@return:
"""
        transitions = self._prepare_transitions()

        accum_loss = mean_accumulator()
        for ith_inner_update in tqdm(range(self._num_inner_updates),
                                     desc="#Inner updates",
                                     leave=False):
            self.inner_update_i += 1
            loss, early_stop_inner = self.inner_update(
                *transitions, metric_writer=metric_writer)
            accum_loss.send(loss)

            if is_none_or_zero_or_negative_or_mod_zero(
                    self._update_target_interval, self.inner_update_i):
                self._update_targets(self._copy_percentage,
                                     metric_writer=metric_writer)

            if early_stop_inner:
                break

        mean_loss = next(accum_loss)

        if metric_writer:
            metric_writer.scalar("Inner Updates", ith_inner_update)
            metric_writer.scalar("Mean Loss", mean_loss)

        return mean_loss
Example #4
def kl_divergence(mean, log_var,
                  writer: Writer = MockWriter()) -> torch.Tensor:
    """

    Args:
      mean:
      log_var:

    Returns:

    """
    batch_size = mean.size(0)
    assert batch_size != 0

    if mean.data.ndimension() == 4:
        mean = mean.view(mean.size(0), mean.size(1))

    if log_var.data.ndimension() == 4:
        log_var = log_var.view(log_var.size(0), log_var.size(1))

    klds = -0.5 * (1 + log_var - mean.pow(2) - log_var.exp())

    if writer:
        writer.scalar("dimension_wise_kld", klds.mean(0))
        writer.scalar("mean_kld", klds.mean(1).mean(0, True))

    return klds.sum(1).mean(0, True)  # total_kld
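
The expression above is the closed-form KL divergence KLD = -0.5 * (1 + log sigma^2 - mu^2 - sigma^2), summed over latent dimensions. A minimal usage sketch, assuming the kl_divergence above; the shapes are illustrative:

import torch

mean = torch.zeros(8, 4)     # batch of 8 latent means
log_var = torch.zeros(8, 4)  # log-variance 0 => unit variance

# KL(N(0, I) || N(0, I)) is exactly zero
total_kld = kl_divergence(mean, log_var, writer=None)
assert torch.allclose(total_kld, torch.zeros(1))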
Example #5
    def _update(self, *, metric_writer: Writer = MockWriter()) -> float:
        """

@param metric_writer:
@return:
"""

        loss_ = math.inf

        if self.update_i > self._initial_observation_period:
            if is_zero_or_mod_zero(self._learning_frequency, self.update_i):
                if len(self._memory_buffer) > self._batch_size:
                    transitions = self._memory_buffer.sample(self._batch_size)

                    td_error, Q_expected, Q_state = self._td_error(transitions)
                    td_error = td_error.detach().squeeze(-1).cpu().numpy()

                    if self._use_per:
                        self._memory_buffer.update_last_batch(td_error)

                    loss = self._loss_function(Q_state, Q_expected)

                    self._optimiser.zero_grad()
                    loss.backward()
                    self.post_process_gradients(self.value_model.parameters())
                    self._optimiser.step()

                    loss_ = to_scalar(loss)
                    if metric_writer:
                        metric_writer.scalar("td_error", td_error.mean(),
                                             self.update_i)
                        metric_writer.scalar("loss", loss_, self.update_i)

                    if self._scheduler:
                        self._scheduler.step()
                        if metric_writer:
                            for i, param_group in enumerate(
                                    self._optimiser.param_groups):
                                metric_writer.scalar(f"lr{i}",
                                                     param_group["lr"],
                                                     self.update_i)

                else:
                    logging.info(
                        "Batch size is larger than current memory size, skipping update"
                    )

            if self._double_dqn:
                if is_zero_or_mod_zero(self._sync_target_model_frequency,
                                       self.update_i):
                    update_target(
                        target_model=self._target_value_model,
                        source_model=self.value_model,
                        copy_percentage=self._copy_percentage,
                    )
                    if metric_writer:
                        metric_writer.blip("Target Model Synced", self.update_i)

        return loss_
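
The _td_error helper is not shown in the listing. A hedged sketch of the standard DQN target it presumably computes; the names and shapes are assumptions, not taken from the source:

import torch


def td_error_sketch(value_model, target_model, transitions, discount: float = 0.99):
    states, actions, signals, successors, non_terminal = transitions
    q_state = value_model(states).gather(1, actions.long())  # Q(s, a) for the taken actions
    with torch.no_grad():
        q_next = target_model(successors).max(dim=1, keepdim=True)[0]  # max_a' Q_target(s', a')
        q_expected = signals + discount * q_next * non_terminal
    return q_expected - q_state, q_expected, q_state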
Example #6
def test_invalid_tag_scalars(tag, val, step):
    with pytest.raises(Exception) as excinfo:  # scalar() must reject invalid tags
        with MockWriter() as w:
            w.scalar(tag, val, step)
    print(excinfo.value)
Example #7
def test_global_writer():
    with MockWriter() as writer_o:
        mw2 = MockWriter()
        assert writer_o == global_writer()
        assert mw2 != global_writer()
        assert writer_o != mw2

        with MockWriter() as writer_i:
            assert writer_i == global_writer()
            assert writer_o != global_writer()
            assert writer_o != mw2

        assert writer_o == global_writer()
        assert writer_o != mw2

        set_global_writer(mw2)
        assert mw2 == global_writer()
        assert writer_o != global_writer()
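
The test implies that a Writer is a context manager which installs itself as the process-global writer on entry and restores the previous one on exit. A minimal sketch of code relying on that stack; the import location of global_writer is an assumption:

from draugr.writers import MockWriter, global_writer  # assumed import location


def log_metric(tag: str, value, step: int) -> None:
    # write to whichever writer is currently installed as the global one
    w = global_writer()
    if w:
        w.scalar(tag, value, step)


with MockWriter() as w:  # pushes itself as the global writer on __enter__
    log_metric("loss", 0.1, step=0)  # reaches w without threading it through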
Example #8
    def _update(self, *args, metric_writer: Writer = MockWriter(),
                **kwargs) -> Any:
        """

@param args:
@param metric_writer:
@param kwargs:
@return:
"""
        raise NotImplementedError
Example #9
  def __build__(
      self,
      observation_space: ObservationSpace,
      action_space: ActionSpace,
      signal_space: SignalSpace,
      metric_writer: Writer = MockWriter(),
      print_model_repr: bool = True,
      ) -> None:
    """

@param observation_space:
@param action_space:
@param signal_space:
@param metric_writer:
@param print_model_repr:
@return:
"""
    if action_space.is_discrete:
      raise ActionSpaceNotSupported(
          "discrete action space not supported in this implementation"
          )

    self._critic_arch_spec.kwargs["input_shape"] = (
        self._input_shape + self._output_shape
    )
    self._critic_arch_spec.kwargs["output_shape"] = 1

    self.critic_1 = self._critic_arch_spec().to(self._device)
    self.critic_1_target = copy.deepcopy(self.critic_1).to(self._device)
    freeze_model(self.critic_1_target, True, True)

    self.critic_2 = self._critic_arch_spec().to(self._device)
    self.critic_2_target = copy.deepcopy(self.critic_2).to(self._device)
    freeze_model(self.critic_2_target, True, True)

    self.critic_optimiser = self._critic_optimiser_spec(
        itertools.chain(self.critic_1.parameters(), self.critic_2.parameters())
        )

    self._actor_arch_spec.kwargs["input_shape"] = self._input_shape
    self._actor_arch_spec.kwargs["output_shape"] = self._output_shape
    self.actor = self._actor_arch_spec().to(self._device)
    self.actor_optimiser = self._actor_optimiser_spec(self.actor.parameters())

    if self._auto_tune_sac_alpha:
      self._target_entropy = -torch.prod(
          to_tensor(self._output_shape, device=self._device)
          ).item()
      self._log_sac_alpha = nn.Parameter(
          torch.log(to_tensor(self._sac_alpha, device=self._device)),
          requires_grad=True,
          )
      self.sac_alpha_optimiser = self._auto_tune_sac_alpha_optimiser_spec(
          [self._log_sac_alpha]
          )
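
When alpha is auto-tuned, SAC adjusts log alpha so the policy's entropy tracks the target entropy computed above. A hedged sketch of the standard temperature loss that sac_alpha_optimiser presumably minimises; log_prob comes from the current policy and is an assumption here:

import torch


def sac_alpha_loss_sketch(log_sac_alpha: torch.nn.Parameter,
                          log_prob: torch.Tensor,
                          target_entropy: float) -> torch.Tensor:
    # gradient flows only through log_sac_alpha; the policy term is detached
    return -(log_sac_alpha * (log_prob + target_entropy).detach()).mean()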
Example #10
    def build(
        self,
        observation_space: ObservationSpace,
        action_space: ActionSpace,
        signal_space: SignalSpace,
        *,
        metric_writer: Writer = MockWriter(),
        print_model_repr: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        """

@param observation_space:
@param action_space:
@param signal_space:
@param metric_writer:
@param print_model_repr:
@param kwargs:
@return:
        :param verbose:
"""
        super().build(
            observation_space,
            action_space,
            signal_space,
            print_model_repr=print_model_repr,
            metric_writer=metric_writer,
            **kwargs,
        )

        if print_model_repr:
            for k, w in self.models.items():
                sprint(f"{k}: {w}", highlight=True, color="cyan")

                if metric_writer:
                    try:
                        model = copy.deepcopy(w).to("cpu")
                        dummy_input = model.sample_input()
                        sprint(f'{k} input: {dummy_input.shape}')

                        import contextlib

                        with contextlib.redirect_stdout(
                            None
                        ):  # So much useless frame info printed... Suppress it
                            if isinstance(metric_writer, GraphWriterMixin):
                                metric_writer.graph(model, dummy_input, verbose=verbose)  # no naming available at the moment
                    except RuntimeError as ex:
                        sprint(
                            f"Tensorboard(Pytorch) does not support you model! No graph added: {str(ex).splitlines()[0]}",
                            color="red",
                            highlight=True,
                        )
Example #11
    def update(self, *args, metric_writer: Writer = MockWriter(),
               **kwargs) -> float:
        """

@param args:
@param metric_writer:
@param kwargs:
@return:
"""
        self._update_i += 1
        self._sample_i_since_last_update = 0
        return self._update(*args, metric_writer=metric_writer, **kwargs)
Example #12
    def _update(self, *, metric_writer: Writer = MockWriter()) -> float:
        """
Update

:return:
:rtype:
"""
        tensorised = TransitionPoint(*[
            to_tensor(a, device=self._device)
            for a in self._memory_buffer.sample()
        ])

        self._memory_buffer.clear()

        # Compute the next Q value from the action the target actor would choose;
        # detach from the graph since gradients should not propagate through the target
        with torch.no_grad():
            next_max_q = self._target_critic(
                tensorised.successor_state,
                self._target_actor(tensorised.state))
            Q_target = tensorised.signal + (self._discount_factor *
                                            next_max_q *
                                            tensorised.non_terminal_numerical)
            # Compute the target of the current Q values

        # Compute current Q value, critic takes state and action chosen
        td_error = self._critic_criteria(
            self._critic(tensorised.state, tensorised.action),
            Q_target.detach())
        self._critic_optimiser.zero_grad()
        td_error.backward()
        self.post_process_gradients(self._critic.parameters())
        self._critic_optimiser.step()

        with frozen_model(self._critic):
            policy_loss = -torch.mean(
                self._critic(tensorised.state, self._actor(tensorised.state)))
            self._actor_optimiser.zero_grad()
            policy_loss.backward()
            self.post_process_gradients(self._actor.parameters())
            self._actor_optimiser.step()

        if is_zero_or_mod_zero(self._update_target_interval, self.update_i):
            self.update_targets(self._copy_percentage,
                                metric_writer=metric_writer)

        if metric_writer:
            metric_writer.scalar("td_error", td_error.cpu().item())
            metric_writer.scalar("policy_loss", policy_loss.cpu().item())

        with torch.no_grad():
            return (td_error + policy_loss).cpu().item()
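
update_targets is not shown; given a copy_percentage, it is presumably a Polyak (soft) update of the frozen target networks. A minimal sketch under that assumption, with copy_percentage playing the role of tau:

import torch


@torch.no_grad()
def soft_update_sketch(target_model, source_model, copy_percentage: float) -> None:
    for t, s in zip(target_model.parameters(), source_model.parameters()):
        t.mul_(1.0 - copy_percentage).add_(copy_percentage * s)  # t = (1 - tau) * t + tau * s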
Example #13
    def _update(self, *, metric_writer=MockWriter()) -> float:
        """

:param metric_writer:

:returns:
"""

        if not len(self._memory_buffer) > 0:
            raise NoTrajectoryException

        trajectory = self._memory_buffer.retrieve_trajectory()
        self._memory_buffer.clear()

        log_probs = to_tensor(
            [
                self.get_log_prob(d, a)
                for d, a in zip(trajectory.distribution, trajectory.action)
            ],
            device=self._device,
        )

        signal = to_tensor(trajectory.signal, device=self._device)
        non_terminal = to_tensor(non_terminal_numerical_mask(
            trajectory.terminated),
                                 device=self._device)

        discounted_signal = discount_rollout_signal_torch(
            signal,
            self._discount_factor,
            device=self._device,
            non_terminal=non_terminal,
        )

        loss = -(log_probs * discounted_signal).mean()

        self._optimiser.zero_grad()
        loss.backward()
        self.post_process_gradients(self.distributional_regressor.parameters())
        self._optimiser.step()

        if self._scheduler:
            self._scheduler.step()
            if metric_writer:
                for i, param_group in enumerate(self._optimiser.param_groups):
                    metric_writer.scalar(f"lr{i}", param_group["lr"])

        loss_cpu = loss.detach().to("cpu").numpy()
        if metric_writer:
            metric_writer.scalar("Loss", loss_cpu)

        return loss_cpu.item()
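
discount_rollout_signal_torch is not shown; it presumably implements the backward recursion G_t = r_t + gamma * G_{t+1} * non_terminal_t. A minimal sketch under that assumption:

import torch


def discount_rollout_sketch(signal: torch.Tensor,
                            discount: float,
                            non_terminal: torch.Tensor) -> torch.Tensor:
    returns = torch.zeros_like(signal)
    running = torch.zeros_like(signal[-1])
    for t in reversed(range(signal.shape[0])):  # walk the rollout backwards
        running = signal[t] + discount * running * non_terminal[t]
        returns[t] = running
    return returns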
Example #14
    def __build__(
        self,
        observation_space: ObservationSpace,
        action_space: ActionSpace,
        signal_space: SignalSpace,
        metric_writer: Writer = MockWriter(),
        print_model_repr: bool = True,
        *,
        distributional_regressor: Module = None,
        optimiser: Optimizer = None,
    ) -> None:
        """

@param observation_space:
@param action_space:
@param signal_space:
@param metric_writer:
@param print_model_repr:
@param distributional_regressor:
@param optimiser:
@return:
"""

        if distributional_regressor:
            self.distributional_regressor = distributional_regressor
        else:
            self._policy_arch_spec.kwargs["input_shape"] = self._input_shape
            if action_space.is_discrete:
                self._policy_arch_spec = GDKC(
                    constructor=CategoricalMLP,
                    kwargs=self._policy_arch_spec.kwargs)
            else:
                self._policy_arch_spec = GDKC(
                    constructor=MultiDimensionalNormalMLP,
                    kwargs=self._policy_arch_spec.kwargs,
                )

            self._policy_arch_spec.kwargs["output_shape"] = self._output_shape

            self.distributional_regressor: Module = self._policy_arch_spec(
            ).to(self._device)

        if optimiser:
            self._optimiser = optimiser
        else:
            self._optimiser = self._optimiser_spec(
                self.distributional_regressor.parameters())

        if self._scheduler_spec:
            self._scheduler = self._scheduler_spec(self._optimiser)
        else:
            self._scheduler = None
Example #15
    def __call__(
            self,
            *,
            iterations: int = 1000,
            render_frequency: int = 100,
            stat_frequency: int = 10,
            disable_stdout: bool = False,
            metric_writer: Writer = MockWriter(),
            **kwargs,
    ):
        r"""
:param log_directory:
:param disable_stdout: Whether to disable stdout statements or not
:type disable_stdout: bool
:param iterations: How many iterations to train for
:type iterations: int
:param render_frequency: How often to render environment
:type render_frequency: int
:param stat_frequency: How often to write statistics
:type stat_frequency: int
:return: A training resume containing the trained agents models and some statistics
:rtype: TR
"""

        E = range(1, iterations)
        E = tqdm(E, desc="Rollout #", leave=False)

        best_episode_return = -math.inf
        for episode_i in E:
            initial_state = self.environment.reset()

            kwargs.update(render_environment=is_positive_and_mod_zero(
                render_frequency, episode_i))
            ret, *_ = rollout_on_policy(
                self.agent,
                initial_state,
                self.environment,
                rollout_ith=episode_i,
                metric_writer=is_positive_and_mod_zero(stat_frequency,
                                                       episode_i,
                                                       ret=metric_writer),
                disable_stdout=disable_stdout,
                **kwargs,
            )

            if best_episode_return < ret:
                best_episode_return = ret
                self.call_on_improvement_callbacks(**kwargs)

            if self.early_stop:
                break
Example #16
    def __build__(
        self,
        observation_space: ObservationSpace,
        action_space: ActionSpace,
        signal_space: SignalSpace,
        metric_writer: Writer = MockWriter(),
        print_model_repr: bool = True,
    ) -> None:
        """

@param observation_space:
@param action_space:
@param signal_space:
@param metric_writer:
@param print_model_repr:
@param critic:
@param critic_optimiser:
@param actor:
@param actor_optimiser:
@return:
"""

        if action_space.is_discrete:
            raise ActionSpaceNotSupported()

        self._actor_arch_spec.kwargs["input_shape"] = self._input_shape
        self._actor_arch_spec.kwargs["output_shape"] = self._output_shape
        self._actor = self._actor_arch_spec().to(self._device)
        self._target_actor = copy.deepcopy(self._actor).to(self._device)
        freeze_model(self._target_actor, True, True)
        self._actor_optimiser = self._actor_optimiser_spec(
            self._actor.parameters())

        self._critic_arch_spec.kwargs["input_shape"] = (
            *self._input_shape,
            *self._output_shape,
        )
        self._critic_arch_spec.kwargs["output_shape"] = 1
        self._critic = self._critic_arch_spec().to(self._device)
        self._target_critic = copy.deepcopy(self._critic).to(self._device)
        freeze_model(self._target_critic, True, True)
        self._critic_optimiser = self._critic_optimiser_spec(
            self._critic.parameters())

        self._random_process = self._random_process_spec(
            sigma=mean([r.span for r in action_space.ranges]))
Example #17
    def _sample(self,
                state: EnvironmentSnapshot,
                *args,
                deterministic: bool = False,
                metric_writer: Writer = MockWriter(),
                **kwargs) -> Any:
        """

@param state:
@param args:
@param deterministic:
@param metric_writer:
@param kwargs:
@return:
"""
        self._sample_i_since_last_update += 1
        return self.action_space.sample()
Example #18
    def _sample(
            self,
            state: numpy.ndarray,
            *args,
            deterministic: bool = False,
            metric_writer: Writer = MockWriter(),
            **kwargs,
    ) -> Tuple[Any, ...]:
        """

@param state:
@param args:
@param deterministic:
@param metric_writer:
@param kwargs:
@return:
"""
        raise NotImplementedError
Example #19
    def _sample(
            self,
            state: Sequence,
            deterministic: bool = False,
            metric_writer: Writer = MockWriter(),
    ) -> numpy.ndarray:
        """

@param state:
@param deterministic:
@param metric_writer:
@return:
"""
        if not deterministic and self._exploration_sample(
                self._sample_i, metric_writer):
            return self._sample_random_process(state)

        return self._sample_model(state)
Example #20
  def _sample(
      self,
      state: Any,
      *args,
      deterministic: bool = False,
      metric_writer: Writer = MockWriter()
      ) -> Tuple[torch.Tensor, Any]:
    """

@param state:
@param args:
@param deterministic:
@param metric_writer:
@param kwargs:
@return:
"""
    distribution = self.actor(to_tensor(state, device=self._device))

    with torch.no_grad():
      return (torch.tanh(distribution.sample().detach()), distribution)
Example #21
    def update(self,
               trajectories,
               *args,
               metric_writer: Writer = MockWriter(),
               attempts=5,
               **kwargs) -> Any:
        """

    Fit the linear baseline model (signal estimator) with the provided paths via damped least squares


    @param trajectories:
    @type trajectories:
    @param args:
    @type args:
    @param metric_writer:
    @type metric_writer:
    @param attempts:
    @type attempts:
    @param kwargs:
    @type kwargs:
    @return:
    @rtype:
    """
        features_matrix = numpy.concatenate(
            [self.extract_features(trajectory) for trajectory in trajectories])
        returns_matrix = numpy.concatenate(
            [trajectory["returns"] for trajectory in trajectories])
        # returns_matrix = numpy.concatenate([path.returns for path in states])
        c_regularisation_coeff = self._l2_reg_coefficient
        id_fm = numpy.identity(features_matrix.shape[1])
        for _ in range(attempts):
            self._linear_coefficients = numpy.linalg.lstsq(
                features_matrix.T.dot(features_matrix) +
                c_regularisation_coeff * id_fm,
                features_matrix.T.dot(returns_matrix),
                rcond=-1,
            )[0]
            if not numpy.any(numpy.isnan(self._linear_coefficients)):
                break  # Non-Nan solution found
            c_regularisation_coeff *= 10
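
The loop above solves the ridge-damped normal equations (X^T X + c I) w = X^T y, multiplying the damping coefficient by 10 until a NaN-free solution appears. A minimal synthetic-data sketch of a single solve; the data is illustrative:

import numpy

rng = numpy.random.default_rng(0)
features_matrix = rng.normal(size=(100, 5))
returns_matrix = features_matrix @ rng.normal(size=5) + 0.01 * rng.normal(size=100)

c = 1e-5  # damping coefficient
coefficients = numpy.linalg.lstsq(
    features_matrix.T.dot(features_matrix) + c * numpy.identity(5),
    features_matrix.T.dot(returns_matrix),
    rcond=-1,
)[0]
assert not numpy.any(numpy.isnan(coefficients))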
Example #22
    def train_once(self,
                   iteration_number: int,
                   trajectories: Sequence,
                   *,
                   writer: Writer = MockWriter()):
        """Perform one step of policy optimization given one batch of samples.
    Args:
        iteration_number (int): Iteration number.
        trajectories (list[dict]): A list of collected paths.
    Returns:
        float: The average return in last epoch cycle.
        @param writer:
        @type writer:
    """

        undiscounted_returns = []
        for trajectory in TrajectoryBatch.from_trajectory_list(
                self._env_spec, trajectories).split():  # TODO: EEEEW
            undiscounted_returns.append(sum(trajectory.rewards))

        sample_returns = np.mean(undiscounted_returns)
        self._all_returns.append(sample_returns)

        epoch = iteration_number // self._num_candidate_policies
        i_sample = iteration_number - epoch * self._num_candidate_policies
        writer.scalar("Epoch", epoch)
        writer.scalar("# Sample", i_sample)

        if (
                iteration_number + 1
        ) % self._num_candidate_policies == 0:  # when looped all the way around, update shared parameters (warning: race conditions!)
            sample_returns = max(self._all_returns)
            self.update()

        self.policy.set_param_values(
            self._shared_params[(i_sample + 1) % self._num_candidate_policies])

        return sample_returns
Example #23
    def sample(self,
               state: EnvironmentSnapshot,
               *args,
               metric_writer: Writer = MockWriter(),
               **kwargs) -> Any:
        """

    samples signal

    @param state:
    @type state:
    @param args:
    @type args:
    @param metric_writer:
    @type metric_writer:
    @param kwargs:
    @type kwargs:
    @return:
    @rtype:
    """
        if self._linear_coefficients is None:
            return numpy.zeros(len(state["rewards"]))
        return self.extract_features(state).dot(self._linear_coefficients)
Example #24
  def _update(self, *args, metric_writer: Writer = MockWriter(), **kwargs) -> float:
    """

@param args:
@param metric_writer:
@param kwargs:
@return:
"""
    accum_loss = 0
    for ith_inner_update in tqdm(
        range(self._num_inner_updates), desc="Inner update #", leave=False, postfix=f"Agent update #{self.update_i}"
        ):
      self.inner_update_i += 1
      batch = self._memory_buffer.sample(self._batch_size)
      tensorised = TransitionPoint(
          *[to_tensor(a, device=self._device) for a in batch]
          )

      with frozen_parameters(self.actor.parameters()):
        accum_loss += self.update_critics(
            tensorised, metric_writer=metric_writer
            )

      with frozen_parameters(
          itertools.chain(self.critic_1.parameters(), self.critic_2.parameters())
          ):
        accum_loss += self.update_actor(tensorised, metric_writer=metric_writer)

      if is_zero_or_mod_zero(self._target_update_interval, self.inner_update_i):
        self.update_targets(self._copy_percentage, metric_writer=metric_writer)

    if metric_writer:
      metric_writer.scalar("Accum_loss", accum_loss, self.update_i)
      metric_writer.scalar("num_inner_updates_i", ith_inner_update, self.update_i)

    return accum_loss
Example #25
    def __build__(
        self,
        observation_space: ObservationSpace,
        action_space: ActionSpace,
        signal_space: SignalSpace,
        metric_writer: Writer = MockWriter(),
        print_model_repr: bool = True,
    ) -> None:
        """

@param observation_space:
@param action_space:
@param signal_space:
@param metric_writer:
@param print_model_repr:
@return:
"""
        if action_space.is_mixed:
            raise ActionSpaceNotSupported()
        elif action_space.is_continuous:
            self._continuous_arch_spec.kwargs[
                "input_shape"] = self._input_shape
            self._continuous_arch_spec.kwargs[
                "output_shape"] = self._output_shape
            self.actor_critic = self._continuous_arch_spec().to(self._device)
        else:
            self._discrete_arch_spec.kwargs["input_shape"] = self._input_shape
            self._discrete_arch_spec.kwargs[
                "output_shape"] = self._output_shape
            self.actor_critic = self._discrete_arch_spec().to(self._device)

        self._target_actor_critic = copy.deepcopy(self.actor_critic).to(
            self._device)
        freeze_model(self._target_actor_critic, True, True)

        self._optimiser = self._optimiser_spec(self.actor_critic.parameters())
Example #26
def loss_function(
        reconstruction,
        original,
        mean,
        log_var,
        beta: Number = 1,
        writer: Writer = MockWriter(),
):
    """

    Args:
      reconstruction:
      original:
      mean:
      log_var:
      beta:

    Returns:

    """

    total_kld = kl_divergence(mean, log_var, writer)

    beta_vae_loss = beta * total_kld  # weight the KL term (beta-VAE)

    return reconstruction_loss(reconstruction, original) + beta_vae_loss
Example #27
def train_siamese(
    model,
    optimiser,
    criterion,
    *,
    writer: Writer = MockWriter(),
    train_number_epochs,
    data_dir,
    train_batch_size,
    model_name,
    save_path,
    save_best=False,
    img_size,
    validation_interval: int = 1,
):
    """
    :param data_dir:
    :type data_dir:
    :param optimiser:
    :type optimiser:
    :param criterion:
    :type criterion:
    :param writer:
    :type writer:
    :param model_name:
    :type model_name:
    :param save_path:
    :type save_path:
    :param save_best:
    :type save_best:
    :param model:
    :type model:
    :param train_number_epochs:
    :type train_number_epochs:
    :param train_batch_size:
    :type train_batch_size:
    :return:
    :rtype:

      Parameters
      ----------
      img_size
      validation_interval"""

    train_dataloader = DataLoader(
        TripletDataset(
            data_path=data_dir,
            transform=transforms.Compose([
                transforms.Grayscale(),
                transforms.Resize(img_size),
                transforms.ToTensor(),
            ]),
            split=SplitEnum.training,
        ),
        shuffle=True,
        num_workers=0,
        batch_size=train_batch_size,
    )

    valid_dataloader = DataLoader(
        TripletDataset(
            data_path=data_dir,
            transform=transforms.Compose([
                transforms.Grayscale(),
                transforms.Resize(img_size),
                transforms.ToTensor(),
            ]),
            split=SplitEnum.validation,
        ),
        shuffle=True,
        num_workers=0,
        batch_size=train_batch_size,
    )

    best = math.inf

    E = tqdm(range(0, train_number_epochs))
    batch_counter = count()

    for epoch in E:
        for tss in train_dataloader:
            batch_i = next(batch_counter)
            with TorchTrainSession(model):
                optimiser.zero_grad()
                loss_contrastive = criterion(*model(
                    *[t.to(global_torch_device()) for t in tss]))
                loss_contrastive.backward()
                optimiser.step()
                a = loss_contrastive.cpu().item()
                writer.scalar("train_loss", a, batch_i)
            if batch_i % validation_interval == 0:
                with TorchEvalSession(model):
                    for tsv in valid_dataloader:
                        o = model(*[t.to(global_torch_device()) for t in tsv])
                        a_v = criterion(*o).cpu().item()
                        valid_positive_acc = (accuracy(
                            distances=pairwise_distance(o[0], o[1]),
                            is_diff=0).cpu().item())
                        valid_negative_acc = (accuracy(
                            distances=pairwise_distance(o[0], o[2]),
                            is_diff=1).cpu().item())
                        valid_acc = numpy.mean(
                            (valid_negative_acc, valid_positive_acc))
                        writer.scalar("valid_loss", a_v, batch_i)
                        writer.scalar("valid_positive_acc", valid_positive_acc,
                                      batch_i)
                        writer.scalar("valid_negative_acc", valid_negative_acc,
                                      batch_i)
                        writer.scalar("valid_acc", valid_acc, batch_i)
                        if a_v < best:
                            best = a_v
                            print(f"new best {best}")
                            if save_best:
                                save_model_parameters(
                                    model,
                                    optimiser=optimiser,
                                    model_name=model_name,
                                    save_directory=save_path,
                                )
            E.set_description(
                f"Epoch number {epoch}, Current train loss {a}, valid loss {a_v}, valid acc {valid_acc}"
            )

    return model
Example #28
from draugr.python_utilities.business import busy_indicator
from draugr.writers import LogWriter, MockWriter, Writer
from heimdallr import PROJECT_APP_PATH, PROJECT_NAME
from heimdallr.configuration.heimdallr_config import ALL_CONSTANTS
from heimdallr.configuration.heimdallr_settings import (
    HeimdallrSettings,
    SettingScopeEnum,
)
from heimdallr.utilities.gpu_utilities import pull_gpu_info
from warg import NOD

HOSTNAME = socket.gethostname()

__all__ = ["main"]

LOG_WRITER: Writer = MockWriter()


def on_publish(client, userdata, result) -> None:
    """ """
    global LOG_WRITER
    LOG_WRITER(result)


def on_disconnect(client, userdata, rc):
    """ """
    if rc != 0:
        print("Unexpected MQTT disconnection. Will auto-reconnect")
        client.reconnect()
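
The on_publish and on_disconnect callbacks follow paho-mqtt's callback signatures. A hedged sketch of wiring them up, assuming the paho-mqtt 1.x client API; the broker address is illustrative:

import paho.mqtt.client as mqtt

client = mqtt.Client()
client.on_publish = on_publish
client.on_disconnect = on_disconnect
client.connect("localhost", 1883)  # illustrative broker
client.loop_start()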

Example #29
def train_siamese(
    model: Module,
    optimiser: Optimizer,
    criterion: callable,
    *,
    writer: Writer = MockWriter(),
    train_number_epochs: int,
    data_dir: Path,
    train_batch_size: int,
    model_name: str,
    save_path: Path,
    save_best: bool = False,
    img_size: Tuple[int, int],
    validation_interval: int = 1,
):
    """
:param img_size:
:type img_size:
:param validation_interval:
:type validation_interval:
:param data_dir:
:type data_dir:
:param optimiser:
:type optimiser:
:param criterion:
:type criterion:
:param writer:
:type writer:
:param model_name:
:type model_name:
:param save_path:
:type save_path:
:param save_best:
:type save_best:
:param model:
:type model:
:param train_number_epochs:
:type train_number_epochs:
:param train_batch_size:
:type train_batch_size:
:return:
:rtype:
"""

    train_dataloader = DataLoader(
        PairDataset(
            data_path=data_dir,
            transform=transforms.Compose([
                transforms.Grayscale(),
                transforms.Resize(img_size),
                transforms.ToTensor(),
            ]),
            split=Split.Training,
        ),
        shuffle=True,
        num_workers=4,
        batch_size=train_batch_size,
    )

    valid_dataloader = DataLoader(
        PairDataset(
            data_path=data_dir,
            transform=transforms.Compose([
                transforms.Grayscale(),
                transforms.Resize(img_size),
                transforms.ToTensor(),
            ]),
            split=Split.Validation,
        ),
        shuffle=True,
        num_workers=4,
        batch_size=train_batch_size,
    )

    best = math.inf

    E = tqdm(range(0, train_number_epochs))
    batch_counter = count()

    for epoch in E:
        for tss in train_dataloader:
            batch_i = next(batch_counter)
            with TorchTrainSession(model):
                o = [t.to(global_torch_device()) for t in tss]
                optimiser.zero_grad()
                loss_contrastive = criterion(model(*o[:2]),
                                             o[2].to(dtype=torch.float))
                loss_contrastive.backward()
                optimiser.step()
                train_loss = loss_contrastive.cpu().item()
                writer.scalar("train_loss", train_loss, batch_i)
            if batch_i % validation_interval == 0:
                with TorchEvalSession(model):
                    for tsv in valid_dataloader:
                        ov = [t.to(global_torch_device()) for t in tsv]
                        v_o, fact = model(*ov[:2]), ov[2].to(dtype=torch.float)
                        valid_loss = criterion(v_o, fact).cpu().item()
                        valid_accuracy = (accuracy(distances=v_o,
                                                   is_diff=fact).cpu().item())
                        writer.scalar("valid_loss", valid_loss, batch_i)
                        if valid_loss < best:
                            best = valid_loss
                            print(f"new best {best}")
                            writer.blip("new_best", batch_i)
                            if save_best:
                                save_model_parameters(
                                    model,
                                    optimiser=optimiser,
                                    model_name=model_name,
                                    save_directory=save_path,
                                )
            E.set_description(
                f"Epoch number {epoch}, Current train loss {train_loss}, valid loss {valid_loss}, valid_accuracy {valid_accuracy}"
            )

    return model
Example #30
    def __call__(self,
                 *,
                 num_environment_steps=500000,
                 batch_size=128,
                 stat_frequency=10,
                 render_frequency=10000,
                 initial_observation_period=1000,
                 render_duration=1000,
                 update_agent_frequency: int = 1,
                 disable_stdout: bool = False,
                 train_agent: bool = True,
                 metric_writer: Writer = MockWriter(),
                 rollout_drawer: MplDrawer = MockDrawer(),
                 **kwargs) -> None:
        """

:param log_directory:
:param num_environment_steps:
:param stat_frequency:
:param render_frequency:
:param disable_stdout:
:return:
"""

        state = self.agent.extract_features(self.environment.reset())

        running_signal = mean_accumulator()
        best_running_signal = None
        running_mean_action = mean_accumulator()
        termination_i = 0
        signal_since_last_termination = 0
        duration_since_last_termination = 0

        for step_i in tqdm(range(num_environment_steps),
                           desc="Step #",
                           leave=False):

            sample = self.agent.sample(state)
            action = self.agent.extract_action(sample)

            snapshot = self.environment.react(action)
            successor_state = self.agent.extract_features(snapshot)
            signal = self.agent.extract_signal(snapshot)
            terminated = snapshot.terminated

            if train_agent:
                self.agent.remember(
                    state=state,
                    signal=signal,
                    terminated=terminated,
                    sample=sample,
                    successor_state=successor_state,
                )

            state = successor_state

            duration_since_last_termination += 1
            mean_signal = signal.mean().item()
            signal_since_last_termination += mean_signal

            running_mean_action.send(action.mean())
            running_signal.send(mean_signal)

            if (train_agent and is_positive_and_mod_zero(
                    update_agent_frequency * batch_size, step_i)
                    and len(self.agent.memory_buffer) > batch_size
                    and step_i > initial_observation_period):
                loss = self.agent.update(metric_writer=metric_writer)

                sig = next(running_signal)
                if not best_running_signal or sig > best_running_signal:
                    best_running_signal = sig
                    self.call_on_improvement_callbacks(loss=loss,
                                                       signal=sig,
                                                       **kwargs)

            if terminated.any():
                termination_i += 1
                if metric_writer:
                    metric_writer.scalar(
                        "duration_since_last_termination",
                        duration_since_last_termination,
                    )
                    metric_writer.scalar("signal_since_last_termination",
                                         signal_since_last_termination)
                    metric_writer.scalar("running_mean_action",
                                         next(running_mean_action))
                    metric_writer.scalar("running_signal",
                                         next(running_signal))
                signal_since_last_termination = 0
                duration_since_last_termination = 0

            if (is_zero_or_mod_below(render_frequency, render_duration, step_i)
                    and render_frequency != 0):
                self.environment.render()
                if rollout_drawer:
                    rollout_drawer.draw(action)

            if self.early_stop:
                break