def sample(
    self,
    state: numpy.ndarray,
    *args,
    deterministic: bool = False,
    metric_writer: Writer = MockWriter(),
    **kwargs,
) -> Tuple[Any, ...]:
    """

    @param state:
    @param args:
    @param deterministic:
    @param metric_writer:
    @param kwargs:
    @return:
    """
    self._sample_i += 1
    self._sample_i_since_last_update += 1
    action = self._sample(
        state,
        *args,
        deterministic=deterministic,
        metric_writer=metric_writer,
        **kwargs,
    )
    if self._action_clipping.enabled:
        action = numpy.clip(
            action, self._action_clipping.low, self._action_clipping.high
        )
    return action

def test_invalid_val_type_scalars(tag, val, step):
    try:
        with MockWriter() as w:
            w.scalar(tag, val, step)
        assert False
    except Exception as e:
        assert True

def _update(self, metric_writer: Writer = MockWriter()) -> float:
    """

    @param metric_writer:
    @return:
    """
    transitions = self._prepare_transitions()

    accum_loss = mean_accumulator()
    for ith_inner_update in tqdm(
        range(self._num_inner_updates), desc="#Inner updates", leave=False
    ):
        self.inner_update_i += 1
        loss, early_stop_inner = self.inner_update(
            *transitions, metric_writer=metric_writer
        )
        accum_loss.send(loss)

        if is_none_or_zero_or_negative_or_mod_zero(
            self._update_target_interval, self.inner_update_i
        ):
            self._update_targets(self._copy_percentage, metric_writer=metric_writer)

        if early_stop_inner:
            break

    mean_loss = next(accum_loss)

    if metric_writer:
        metric_writer.scalar("Inner Updates", ith_inner_update)
        metric_writer.scalar("Mean Loss", mean_loss)

    return mean_loss

def kl_divergence(mean, log_var, writer: Writer = MockWriter()) -> torch.Tensor:
    """

    Args:
        mean:
        log_var:
        writer:

    Returns:

    """
    batch_size = mean.size(0)
    assert batch_size != 0
    if mean.data.ndimension() == 4:
        mean = mean.view(mean.size(0), mean.size(1))
    if log_var.data.ndimension() == 4:
        log_var = log_var.view(log_var.size(0), log_var.size(1))

    klds = -0.5 * (1 + log_var - mean.pow(2) - log_var.exp())

    if writer:
        writer.scalar("dimension_wise_kld", klds.mean(0))
        writer.scalar("mean_kld", klds.mean(1).mean(0, True))

    return klds.sum(1).mean(0, True)  # total_kld

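# For reference, the per-dimension term computed above is the closed-form KL divergence
# between the approximate posterior N(mu, sigma^2) (with log_var = log sigma^2) and a
# standard normal prior:
#
#     KL(N(mu, sigma^2) || N(0, 1)) = -1/2 * (1 + log sigma^2 - mu^2 - sigma^2)
#
# A minimal, hypothetical sanity check of that identity; the names and shapes below are
# illustrative only and not taken from the surrounding codebase:
mean_demo = torch.zeros(4, 8)
log_var_demo = torch.zeros(4, 8)
total_kld_demo = kl_divergence(mean_demo, log_var_demo, writer=None)
# -> tensor([0.]) for a posterior that already equals the standard-normal prior
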
def _update(self, *, metric_writer: Writer = MockWriter()) -> float:
    """

    @param metric_writer:
    @return:
    """
    loss_ = math.inf

    if self.update_i > self._initial_observation_period:
        if is_zero_or_mod_zero(self._learning_frequency, self.update_i):
            if len(self._memory_buffer) > self._batch_size:
                transitions = self._memory_buffer.sample(self._batch_size)

                td_error, Q_expected, Q_state = self._td_error(transitions)
                td_error = td_error.detach().squeeze(-1).cpu().numpy()

                if self._use_per:
                    self._memory_buffer.update_last_batch(td_error)

                loss = self._loss_function(Q_state, Q_expected)

                self._optimiser.zero_grad()
                loss.backward()
                self.post_process_gradients(self.value_model.parameters())
                self._optimiser.step()

                loss_ = to_scalar(loss)
                if metric_writer:
                    metric_writer.scalar("td_error", td_error.mean(), self.update_i)
                    metric_writer.scalar("loss", loss_, self.update_i)

                if self._scheduler:
                    self._scheduler.step()
                    if metric_writer:
                        for i, param_group in enumerate(self._optimiser.param_groups):
                            metric_writer.scalar(
                                f"lr{i}", param_group["lr"], self.update_i
                            )
            else:
                logging.info(
                    "Batch size is larger than current memory size, skipping update"
                )

    if self._double_dqn:
        if is_zero_or_mod_zero(self._sync_target_model_frequency, self.update_i):
            update_target(
                target_model=self._target_value_model,
                source_model=self.value_model,
                copy_percentage=self._copy_percentage,
            )
            if metric_writer:
                metric_writer.blip("Target Model Synced", self.update_i)

    return loss_

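# The `_td_error` helper is not shown here, but for a standard (double) DQN update the
# quantities above are assumed to follow the usual bootstrapped target, roughly
#
#     Q_expected = r + gamma * (1 - done) * max_a' Q_target(s', a')
#     Q_state    = Q(s, a)
#     td_error   = Q_state - Q_expected
#
# so `loss = self._loss_function(Q_state, Q_expected)` is typically an MSE/Huber loss
# between the online estimate and that target.
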
def test_invalid_tag_scalars(tag, val, step):
    try:
        with MockWriter() as w:
            w.scalar(tag, val, step)
        assert False
    except Exception as e:
        print(e)
        assert True

def test_global_writer():
    with MockWriter() as writer_o:
        mw2 = MockWriter()
        assert writer_o == global_writer()
        assert mw2 != global_writer()
        assert writer_o != mw2

        with MockWriter() as writer_i:
            assert writer_i == global_writer()
            assert writer_o != global_writer()
            assert writer_o != mw2

        assert writer_o == global_writer()
        assert writer_o != mw2

        set_global_writer(mw2)
        assert mw2 == global_writer()
        assert writer_o != global_writer()

def _update(self, *args, metric_writer: Writer = MockWriter(), **kwargs) -> Any:
    """

    @param args:
    @param metric_writer:
    @param kwargs:
    @return:
    """
    raise NotImplementedError

def __build__(
    self,
    observation_space: ObservationSpace,
    action_space: ActionSpace,
    signal_space: SignalSpace,
    metric_writer: Writer = MockWriter(),
    print_model_repr: bool = True,
) -> None:
    """

    @param observation_space:
    @param action_space:
    @param signal_space:
    @param metric_writer:
    @param print_model_repr:
    @return:
    """
    if action_space.is_discrete:
        raise ActionSpaceNotSupported(
            "discrete action space not supported in this implementation"
        )

    self._critic_arch_spec.kwargs["input_shape"] = (
        self._input_shape + self._output_shape
    )
    self._critic_arch_spec.kwargs["output_shape"] = 1

    self.critic_1 = self._critic_arch_spec().to(self._device)
    self.critic_1_target = copy.deepcopy(self.critic_1).to(self._device)
    freeze_model(self.critic_1_target, True, True)

    self.critic_2 = self._critic_arch_spec().to(self._device)
    self.critic_2_target = copy.deepcopy(self.critic_2).to(self._device)
    freeze_model(self.critic_2_target, True, True)

    self.critic_optimiser = self._critic_optimiser_spec(
        itertools.chain(self.critic_1.parameters(), self.critic_2.parameters())
    )

    self._actor_arch_spec.kwargs["input_shape"] = self._input_shape
    self._actor_arch_spec.kwargs["output_shape"] = self._output_shape
    self.actor = self._actor_arch_spec().to(self._device)
    self.actor_optimiser = self._actor_optimiser_spec(self.actor.parameters())

    if self._auto_tune_sac_alpha:
        self._target_entropy = -torch.prod(
            to_tensor(self._output_shape, device=self._device)
        ).item()
        self._log_sac_alpha = nn.Parameter(
            torch.log(to_tensor(self._sac_alpha, device=self._device)),
            requires_grad=True,
        )
        self.sac_alpha_optimiser = self._auto_tune_sac_alpha_optimiser_spec(
            [self._log_sac_alpha]
        )

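# When `_auto_tune_sac_alpha` is enabled above, the temperature is parameterised as
# log(alpha) and the target entropy is set to the negative action dimensionality. The
# corresponding temperature objective (computed in the update step, which is not shown
# here) is usually of the form
#
#     J(alpha) = E_{a ~ pi}[ -alpha * (log pi(a|s) + target_entropy) ]
#
# which drives the policy entropy towards `self._target_entropy`.
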
def build(
    self,
    observation_space: ObservationSpace,
    action_space: ActionSpace,
    signal_space: SignalSpace,
    *,
    metric_writer: Writer = MockWriter(),
    print_model_repr: bool = True,
    verbose: bool = False,
    **kwargs,
) -> None:
    """

    @param observation_space:
    @param action_space:
    @param signal_space:
    @param metric_writer:
    @param print_model_repr:
    @param verbose:
    @param kwargs:
    @return:
    """
    super().build(
        observation_space,
        action_space,
        signal_space,
        print_model_repr=print_model_repr,
        metric_writer=metric_writer,
        **kwargs,
    )

    if print_model_repr:
        for k, w in self.models.items():
            sprint(f"{k}: {w}", highlight=True, color="cyan")

            if metric_writer:
                try:
                    model = copy.deepcopy(w).to("cpu")
                    dummy_input = model.sample_input()
                    sprint(f"{k} input: {dummy_input.shape}")

                    import contextlib

                    with contextlib.redirect_stdout(
                        None
                    ):  # So much useless frame info printed... Suppress it
                        if isinstance(metric_writer, GraphWriterMixin):
                            metric_writer.graph(
                                model, dummy_input, verbose=verbose
                            )  # No naming available at the moment...
                except RuntimeError as ex:
                    sprint(
                        f"Tensorboard(Pytorch) does not support your model! No graph was added: {str(ex).splitlines()[0]}",
                        color="red",
                        highlight=True,
                    )

def update(self, *args, metric_writer: Writer = MockWriter(), **kwargs) -> float:
    """

    @param args:
    @param metric_writer:
    @param kwargs:
    @return:
    """
    self._update_i += 1
    self._sample_i_since_last_update = 0
    return self._update(*args, metric_writer=metric_writer, **kwargs)

def _update(self, *, metric_writer: Writer = MockWriter()) -> float:
    """
    Update

    :param metric_writer:
    :return:
    :rtype:
    """
    tensorised = TransitionPoint(
        *[to_tensor(a, device=self._device) for a in self._memory_buffer.sample()]
    )

    self._memory_buffer.clear()

    # Compute the next Q value based on which action the target actor would choose.
    # Detach from the current graph since we don't want gradients of the next Q to propagate.
    with torch.no_grad():
        next_max_q = self._target_critic(
            tensorised.successor_state, self._target_actor(tensorised.successor_state)
        )
        Q_target = tensorised.signal + (
            self._discount_factor * next_max_q * tensorised.non_terminal_numerical
        )  # The target of the current Q values

    # Compute the current Q value; the critic takes the state and the chosen action
    td_error = self._critic_criteria(
        self._critic(tensorised.state, tensorised.action), Q_target.detach()
    )
    self._critic_optimiser.zero_grad()
    td_error.backward()
    self.post_process_gradients(self._critic.parameters())
    self._critic_optimiser.step()

    with frozen_model(self._critic):
        policy_loss = -torch.mean(
            self._critic(tensorised.state, self._actor(tensorised.state))
        )
        self._actor_optimiser.zero_grad()
        policy_loss.backward()
        self.post_process_gradients(self._actor.parameters())
        self._actor_optimiser.step()

    if is_zero_or_mod_zero(self._update_target_interval, self.update_i):
        self.update_targets(self._copy_percentage, metric_writer=metric_writer)

    if metric_writer:
        metric_writer.scalar("td_error", td_error.cpu().item())
        metric_writer.scalar("policy_loss", policy_loss.cpu().item())

    with torch.no_grad():
        return (td_error + policy_loss).cpu().item()

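# In summary, the update above implements the usual deterministic actor-critic (DDPG-style)
# objectives, with d = 1 - non_terminal_numerical:
#
#     y           = r + gamma * (1 - d) * Q_target(s', mu_target(s'))
#     critic loss = criterion(Q(s, a), y)
#     actor loss  = -mean( Q(s, mu(s)) )
#
# followed by a (possibly soft, via `_copy_percentage`) synchronisation of the target networks.
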
def _update(self, *, metric_writer=MockWriter()) -> float:
    """

    :param metric_writer:
    :returns:
    """
    if not len(self._memory_buffer) > 0:
        raise NoTrajectoryException

    trajectory = self._memory_buffer.retrieve_trajectory()
    self._memory_buffer.clear()

    log_probs = to_tensor(
        [
            self.get_log_prob(d, a)
            for d, a in zip(trajectory.distribution, trajectory.action)
        ],
        device=self._device,
    )

    signal = to_tensor(trajectory.signal, device=self._device)
    non_terminal = to_tensor(
        non_terminal_numerical_mask(trajectory.terminated), device=self._device
    )

    discounted_signal = discount_rollout_signal_torch(
        signal,
        self._discount_factor,
        device=self._device,
        non_terminal=non_terminal,
    )

    loss = -(log_probs * discounted_signal).mean()

    self._optimiser.zero_grad()
    loss.backward()
    self.post_process_gradients(self.distributional_regressor.parameters())
    self._optimiser.step()

    if self._scheduler:
        self._scheduler.step()
        if metric_writer:
            for i, param_group in enumerate(self._optimiser.param_groups):
                metric_writer.scalar(f"lr{i}", param_group["lr"])

    loss_cpu = loss.detach().to("cpu").numpy()

    if metric_writer:
        metric_writer.scalar("Loss", loss_cpu)

    return loss_cpu.item()

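# The loss above is the REINFORCE / vanilla policy-gradient objective, estimated from the
# single retrieved trajectory:
#
#     loss = -1/T * sum_t log pi(a_t | s_t) * G_t
#
# where G_t is the (non-terminal-masked) discounted return produced by
# `discount_rollout_signal_torch`.
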
def __build__(
    self,
    observation_space: ObservationSpace,
    action_space: ActionSpace,
    signal_space: SignalSpace,
    metric_writer: Writer = MockWriter(),
    print_model_repr: bool = True,
    *,
    distributional_regressor: Module = None,
    optimiser: Optimizer = None,
) -> None:
    """

    @param observation_space:
    @param action_space:
    @param signal_space:
    @param metric_writer:
    @param print_model_repr:
    @param distributional_regressor:
    @param optimiser:
    @return:
    """
    if distributional_regressor:
        self.distributional_regressor = distributional_regressor
    else:
        self._policy_arch_spec.kwargs["input_shape"] = self._input_shape
        if action_space.is_discrete:
            self._policy_arch_spec = GDKC(
                constructor=CategoricalMLP, kwargs=self._policy_arch_spec.kwargs
            )
        else:
            self._policy_arch_spec = GDKC(
                constructor=MultiDimensionalNormalMLP,
                kwargs=self._policy_arch_spec.kwargs,
            )
        self._policy_arch_spec.kwargs["output_shape"] = self._output_shape
        self.distributional_regressor: Module = self._policy_arch_spec().to(
            self._device
        )

    if optimiser:
        self._optimiser = optimiser
    else:
        self._optimiser = self._optimiser_spec(
            self.distributional_regressor.parameters()
        )

    if self._scheduler_spec:
        self._scheduler = self._scheduler_spec(self._optimiser)
    else:
        self._scheduler = None

def __call__(
    self,
    *,
    iterations: int = 1000,
    render_frequency: int = 100,
    stat_frequency: int = 10,
    disable_stdout: bool = False,
    metric_writer: Writer = MockWriter(),
    **kwargs,
):
    r"""
    :param disable_stdout: Whether to disable stdout statements or not
    :type disable_stdout: bool
    :param iterations: How many iterations to train for
    :type iterations: int
    :param render_frequency: How often to render the environment
    :type render_frequency: int
    :param stat_frequency: How often to write statistics
    :type stat_frequency: int
    :return: A training resume containing the trained agent's models and some statistics
    :rtype: TR
    """
    E = range(1, iterations)
    E = tqdm(E, desc="Rollout #", leave=False)

    best_episode_return = -math.inf
    for episode_i in E:
        initial_state = self.environment.reset()

        kwargs.update(
            render_environment=is_positive_and_mod_zero(render_frequency, episode_i)
        )
        ret, *_ = rollout_on_policy(
            self.agent,
            initial_state,
            self.environment,
            rollout_ith=episode_i,
            metric_writer=is_positive_and_mod_zero(
                stat_frequency, episode_i, ret=metric_writer
            ),
            disable_stdout=disable_stdout,
            **kwargs,
        )

        if best_episode_return < ret:
            best_episode_return = ret
            self.call_on_improvement_callbacks(**kwargs)

        if self.early_stop:
            break

def __build__(
    self,
    observation_space: ObservationSpace,
    action_space: ActionSpace,
    signal_space: SignalSpace,
    metric_writer: Writer = MockWriter(),
    print_model_repr: bool = True,
) -> None:
    """

    @param observation_space:
    @param action_space:
    @param signal_space:
    @param metric_writer:
    @param print_model_repr:
    @return:
    """
    if action_space.is_discrete:
        raise ActionSpaceNotSupported()

    self._actor_arch_spec.kwargs["input_shape"] = self._input_shape
    self._actor_arch_spec.kwargs["output_shape"] = self._output_shape
    self._actor = self._actor_arch_spec().to(self._device)
    self._target_actor = copy.deepcopy(self._actor).to(self._device)
    freeze_model(self._target_actor, True, True)
    self._actor_optimiser = self._actor_optimiser_spec(self._actor.parameters())

    self._critic_arch_spec.kwargs["input_shape"] = (
        *self._input_shape,
        *self._output_shape,
    )
    self._critic_arch_spec.kwargs["output_shape"] = 1
    self._critic = self._critic_arch_spec().to(self._device)
    self._target_critic = copy.deepcopy(self._critic).to(self._device)
    freeze_model(self._target_critic, True, True)
    self._critic_optimiser = self._critic_optimiser_spec(self._critic.parameters())

    self._random_process = self._random_process_spec(
        sigma=mean([r.span for r in action_space.ranges])
    )

def _sample(
    self,
    state: EnvironmentSnapshot,
    *args,
    deterministic: bool = False,
    metric_writer: Writer = MockWriter(),
    **kwargs,
) -> Any:
    """

    @param state:
    @param args:
    @param deterministic:
    @param metric_writer:
    @param kwargs:
    @return:
    """
    self._sample_i_since_last_update += 1
    return self.action_space.sample()

def _sample(
    self,
    state: numpy.ndarray,
    *args,
    deterministic: bool = False,
    metric_writer: Writer = MockWriter(),
    **kwargs,
) -> Tuple[Any, ...]:
    """

    @param state:
    @param args:
    @param deterministic:
    @param metric_writer:
    @param kwargs:
    @return:
    """
    raise NotImplementedError

def _sample(
    self,
    state: Sequence,
    deterministic: bool = False,
    metric_writer: Writer = MockWriter(),
) -> numpy.ndarray:
    """

    @param state:
    @param deterministic:
    @param metric_writer:
    @return:
    """
    if not deterministic and self._exploration_sample(self._sample_i, metric_writer):
        return self._sample_random_process(state)

    return self._sample_model(state)

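# `_exploration_sample` is not shown here; below is a minimal sketch of one possible gate,
# assuming an exponentially decaying epsilon. All names in this sketch
# (`make_exploration_gate`, `start`, `end`, `decay`) are hypothetical and only illustrate
# how such a gate could report its exploration rate through the writer's `scalar` interface.
import math
import random


def make_exploration_gate(start: float = 1.0, end: float = 0.05, decay: float = 10000.0):
    """Return a callable gate(sample_i, metric_writer) -> bool (True means explore)."""

    def gate(sample_i: int, metric_writer: Writer = None) -> bool:
        eps = end + (start - end) * math.exp(-sample_i / decay)  # decayed exploration rate
        if metric_writer:
            metric_writer.scalar("epsilon", eps, sample_i)
        return random.random() < eps

    return gate
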
def _sample(
    self,
    state: Any,
    *args,
    deterministic: bool = False,
    metric_writer: Writer = MockWriter(),
) -> Tuple[torch.Tensor, Any]:
    """

    @param state:
    @param args:
    @param deterministic:
    @param metric_writer:
    @return:
    """
    distribution = self.actor(to_tensor(state, device=self._device))

    with torch.no_grad():
        return torch.tanh(distribution.sample().detach()), distribution

def update(
    self,
    trajectories,
    *args,
    metric_writer: Writer = MockWriter(),
    attempts=5,
    **kwargs,
) -> Any:
    """
    Fit the linear baseline model (signal estimator) to the provided paths via damped least squares

    @param trajectories:
    @type trajectories:
    @param args:
    @type args:
    @param metric_writer:
    @type metric_writer:
    @param attempts:
    @type attempts:
    @param kwargs:
    @type kwargs:
    @return:
    @rtype:
    """
    features_matrix = numpy.concatenate(
        [self.extract_features(trajectory) for trajectory in trajectories]
    )
    returns_matrix = numpy.concatenate(
        [trajectory["returns"] for trajectory in trajectories]
    )
    # returns_matrix = numpy.concatenate([path.returns for path in states])

    c_regularisation_coeff = self._l2_reg_coefficient
    id_fm = numpy.identity(features_matrix.shape[1])
    for _ in range(attempts):
        self._linear_coefficients = numpy.linalg.lstsq(
            features_matrix.T.dot(features_matrix) + c_regularisation_coeff * id_fm,
            features_matrix.T.dot(returns_matrix),
            rcond=-1,
        )[0]
        if not numpy.any(numpy.isnan(self._linear_coefficients)):
            break  # Non-NaN solution found
        c_regularisation_coeff *= 10

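# The loop above solves a ridge-regularised ("damped") least-squares problem for the
# baseline coefficients w:
#
#     w = argmin_w ||F w - R||^2 + c ||w||^2   <=>   (F^T F + c I) w = F^T R
#
# where F is `features_matrix`, R is `returns_matrix`, and c is the regularisation
# coefficient, which is increased tenfold on each attempt until a NaN-free solution
# is found.
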
def train_once(
    self,
    iteration_number: int,
    trajectories: Sequence,
    *,
    writer: Writer = MockWriter(),
):
    """Perform one step of policy optimization given one batch of samples.

    Args:
        iteration_number (int): Iteration number.
        trajectories (list[dict]): A list of collected paths.
        writer (Writer):

    Returns:
        float: The average return in the last epoch cycle.
    """
    undiscounted_returns = []
    for trajectory in TrajectoryBatch.from_trajectory_list(
        self._env_spec, trajectories
    ).split():  # TODO: EEEEW
        undiscounted_returns.append(sum(trajectory.rewards))

    sample_returns = np.mean(undiscounted_returns)
    self._all_returns.append(sample_returns)

    epoch = iteration_number // self._num_candidate_policies
    i_sample = iteration_number - epoch * self._num_candidate_policies

    writer.scalar("Epoch", epoch)
    writer.scalar("# Sample", i_sample)

    if (iteration_number + 1) % self._num_candidate_policies == 0:
        # When looped all the way around, update shared parameters. WARNING: RACE CONDITIONS!
        sample_returns = max(self._all_returns)
        self.update()

    self.policy.set_param_values(
        self._shared_params[(i_sample + 1) % self._num_candidate_policies]
    )

    return sample_returns

def sample(
    self,
    state: EnvironmentSnapshot,
    *args,
    metric_writer: Writer = MockWriter(),
    **kwargs,
) -> Any:
    """
    samples signal

    @param state:
    @type state:
    @param args:
    @type args:
    @param metric_writer:
    @type metric_writer:
    @param kwargs:
    @type kwargs:
    @return:
    @rtype:
    """
    if self._linear_coefficients is None:
        return numpy.zeros(len(state["rewards"]))
    return self.extract_features(state).dot(self._linear_coefficients)

def _update(self, *args, metric_writer: Writer = MockWriter(), **kwargs) -> float:
    """

    @param args:
    @param metric_writer:
    @param kwargs:
    @return:
    """
    accum_loss = 0
    for ith_inner_update in tqdm(
        range(self._num_inner_updates),
        desc="Inner update #",
        leave=False,
        postfix=f"Agent update #{self.update_i}",
    ):
        self.inner_update_i += 1
        batch = self._memory_buffer.sample(self._batch_size)
        tensorised = TransitionPoint(
            *[to_tensor(a, device=self._device) for a in batch]
        )

        with frozen_parameters(self.actor.parameters()):
            accum_loss += self.update_critics(tensorised, metric_writer=metric_writer)

        with frozen_parameters(
            itertools.chain(self.critic_1.parameters(), self.critic_2.parameters())
        ):
            accum_loss += self.update_actor(tensorised, metric_writer=metric_writer)

        if is_zero_or_mod_zero(self._target_update_interval, self.inner_update_i):
            self.update_targets(self._copy_percentage, metric_writer=metric_writer)

    if metric_writer:
        metric_writer.scalar("Accum_loss", accum_loss, self.update_i)
        metric_writer.scalar("num_inner_updates_i", ith_inner_update, self.update_i)

    return accum_loss

def __build__(
    self,
    observation_space: ObservationSpace,
    action_space: ActionSpace,
    signal_space: SignalSpace,
    metric_writer: Writer = MockWriter(),
    print_model_repr: bool = True,
) -> None:
    """

    @param observation_space:
    @param action_space:
    @param signal_space:
    @param metric_writer:
    @param print_model_repr:
    @return:
    """
    if action_space.is_mixed:
        raise ActionSpaceNotSupported()
    elif action_space.is_continuous:
        self._continuous_arch_spec.kwargs["input_shape"] = self._input_shape
        self._continuous_arch_spec.kwargs["output_shape"] = self._output_shape
        self.actor_critic = self._continuous_arch_spec().to(self._device)
    else:
        self._discrete_arch_spec.kwargs["input_shape"] = self._input_shape
        self._discrete_arch_spec.kwargs["output_shape"] = self._output_shape
        self.actor_critic = self._discrete_arch_spec().to(self._device)

    self._target_actor_critic = copy.deepcopy(self.actor_critic).to(self._device)
    freeze_model(self._target_actor_critic, True, True)

    self._optimiser = self._optimiser_spec(self.actor_critic.parameters())

def loss_function(
    reconstruction,
    original,
    mean,
    log_var,
    beta: Number = 1,
    writer: Writer = MockWriter(),
):
    """

    Args:
        reconstruction:
        original:
        mean:
        log_var:
        beta:
        writer:

    Returns:

    """
    total_kld = kl_divergence(mean, log_var, writer)

    if True:
        beta_vae_loss = beta * total_kld
    else:
        beta_vae_loss = (
            recon_loss
            + self.gamma
            * (
                total_kld
                - torch.clamp(
                    self.C_max / self.C_stop_iter * self.global_iter,
                    0,
                    self.C_max.data[0],
                )  # C
            ).abs()
        )

    return reconstruction_loss(reconstruction, original) + beta_vae_loss

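# A minimal, hypothetical usage sketch of `loss_function` with dummy tensors. The shapes,
# the latent size of 32 and the choice beta=4 are illustrative only; it assumes the
# module-level `reconstruction_loss` referenced above is available, and in real training
# `reconstruction`, `mean` and `log_var` would come from a VAE's decoder/encoder so that
# the loss can be back-propagated.
reconstruction_demo = torch.rand(8, 3, 64, 64)
original_demo = torch.rand(8, 3, 64, 64)
mu_demo = torch.zeros(8, 32)
lv_demo = torch.zeros(8, 32)
loss_demo = loss_function(
    reconstruction_demo, original_demo, mu_demo, lv_demo, beta=4, writer=None
)
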
def train_siamese(
    model,
    optimiser,
    criterion,
    *,
    writer: Writer = MockWriter(),
    train_number_epochs,
    data_dir,
    train_batch_size,
    model_name,
    save_path,
    save_best=False,
    img_size,
    validation_interval: int = 1,
):
    """
    :param model:
    :type model:
    :param optimiser:
    :type optimiser:
    :param criterion:
    :type criterion:
    :param writer:
    :type writer:
    :param train_number_epochs:
    :type train_number_epochs:
    :param data_dir:
    :type data_dir:
    :param train_batch_size:
    :type train_batch_size:
    :param model_name:
    :type model_name:
    :param save_path:
    :type save_path:
    :param save_best:
    :type save_best:
    :param img_size:
    :param validation_interval:
    :return:
    :rtype:
    """
    train_dataloader = DataLoader(
        TripletDataset(
            data_path=data_dir,
            transform=transforms.Compose(
                [
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]
            ),
            split=SplitEnum.training,
        ),
        shuffle=True,
        num_workers=0,
        batch_size=train_batch_size,
    )

    valid_dataloader = DataLoader(
        TripletDataset(
            data_path=data_dir,
            transform=transforms.Compose(
                [
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]
            ),
            split=SplitEnum.validation,
        ),
        shuffle=True,
        num_workers=0,
        batch_size=train_batch_size,
    )

    best = math.inf
    E = tqdm(range(0, train_number_epochs))
    batch_counter = count()

    for epoch in E:
        for tss in train_dataloader:
            batch_i = next(batch_counter)
            with TorchTrainSession(model):
                optimiser.zero_grad()
                loss_contrastive = criterion(
                    *model(*[t.to(global_torch_device()) for t in tss])
                )
                loss_contrastive.backward()
                optimiser.step()
                a = loss_contrastive.cpu().item()
                writer.scalar("train_loss", a, batch_i)
            if batch_i % validation_interval == 0:
                with TorchEvalSession(model):
                    for tsv in valid_dataloader:
                        o = model(*[t.to(global_torch_device()) for t in tsv])
                        a_v = criterion(*o).cpu().item()
                        valid_positive_acc = (
                            accuracy(distances=pairwise_distance(o[0], o[1]), is_diff=0)
                            .cpu()
                            .item()
                        )
                        valid_negative_acc = (
                            accuracy(distances=pairwise_distance(o[0], o[2]), is_diff=1)
                            .cpu()
                            .item()
                        )
                        valid_acc = numpy.mean((valid_negative_acc, valid_positive_acc))
                        writer.scalar("valid_loss", a_v, batch_i)
                        writer.scalar("valid_positive_acc", valid_positive_acc, batch_i)
                        writer.scalar("valid_negative_acc", valid_negative_acc, batch_i)
                        writer.scalar("valid_acc", valid_acc, batch_i)
                        if a_v < best:
                            best = a_v
                            print(f"new best {best}")
                            if save_best:
                                save_model_parameters(
                                    model,
                                    optimiser=optimiser,
                                    model_name=model_name,
                                    save_directory=save_path,
                                )
            E.set_description(
                f"Epoch number {epoch}, Current train loss {a}, valid loss {a_v}, valid acc {valid_acc}"
            )
    return model

import socket

from draugr.python_utilities.business import busy_indicator
from draugr.writers import LogWriter, MockWriter, Writer
from heimdallr import PROJECT_APP_PATH, PROJECT_NAME
from heimdallr.configuration.heimdallr_config import ALL_CONSTANTS
from heimdallr.configuration.heimdallr_settings import (
    HeimdallrSettings,
    SettingScopeEnum,
)
from heimdallr.utilities.gpu_utilities import pull_gpu_info
from warg import NOD

HOSTNAME = socket.gethostname()

__all__ = ["main"]

LOG_WRITER: Writer = MockWriter()


def on_publish(client, userdata, result) -> None:
    """ """
    global LOG_WRITER
    LOG_WRITER(result)


def on_disconnect(client, userdata, rc):
    """ """
    if rc != 0:
        print("Unexpected MQTT disconnection. Will auto-reconnect")
        client.reconnect()

def train_siamese(
    model: Module,
    optimiser: Optimizer,
    criterion: callable,
    *,
    writer: Writer = MockWriter(),
    train_number_epochs: int,
    data_dir: Path,
    train_batch_size: int,
    model_name: str,
    save_path: Path,
    save_best: bool = False,
    img_size: Tuple[int, int],
    validation_interval: int = 1,
):
    """
    :param img_size:
    :type img_size:
    :param validation_interval:
    :type validation_interval:
    :param data_dir:
    :type data_dir:
    :param optimiser:
    :type optimiser:
    :param criterion:
    :type criterion:
    :param writer:
    :type writer:
    :param model_name:
    :type model_name:
    :param save_path:
    :type save_path:
    :param save_best:
    :type save_best:
    :param model:
    :type model:
    :param train_number_epochs:
    :type train_number_epochs:
    :param train_batch_size:
    :type train_batch_size:
    :return:
    :rtype:
    """
    train_dataloader = DataLoader(
        PairDataset(
            data_path=data_dir,
            transform=transforms.Compose(
                [
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]
            ),
            split=Split.Training,
        ),
        shuffle=True,
        num_workers=4,
        batch_size=train_batch_size,
    )

    valid_dataloader = DataLoader(
        PairDataset(
            data_path=data_dir,
            transform=transforms.Compose(
                [
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]
            ),
            split=Split.Validation,
        ),
        shuffle=True,
        num_workers=4,
        batch_size=train_batch_size,
    )

    best = math.inf
    E = tqdm(range(0, train_number_epochs))
    batch_counter = count()

    for epoch in E:
        for tss in train_dataloader:
            batch_i = next(batch_counter)
            with TorchTrainSession(model):
                o = [t.to(global_torch_device()) for t in tss]
                optimiser.zero_grad()
                loss_contrastive = criterion(model(*o[:2]), o[2].to(dtype=torch.float))
                loss_contrastive.backward()
                optimiser.step()
                train_loss = loss_contrastive.cpu().item()
                writer.scalar("train_loss", train_loss, batch_i)
            if batch_i % validation_interval == 0:
                with TorchEvalSession(model):
                    for tsv in valid_dataloader:
                        ov = [t.to(global_torch_device()) for t in tsv]
                        v_o, fact = model(*ov[:2]), ov[2].to(dtype=torch.float)
                        valid_loss = criterion(v_o, fact).cpu().item()
                        valid_accuracy = (
                            accuracy(distances=v_o, is_diff=fact).cpu().item()
                        )
                        writer.scalar("valid_loss", valid_loss, batch_i)
                        if valid_loss < best:
                            best = valid_loss
                            print(f"new best {best}")
                            writer.blip("new_best", batch_i)
                            if save_best:
                                save_model_parameters(
                                    model,
                                    optimiser=optimiser,
                                    model_name=model_name,
                                    save_directory=save_path,
                                )
            E.set_description(
                f"Epoch number {epoch}, Current train loss {train_loss}, valid loss {valid_loss}, valid_accuracy {valid_accuracy}"
            )
    return model

def __call__(
    self,
    *,
    num_environment_steps=500000,
    batch_size=128,
    stat_frequency=10,
    render_frequency=10000,
    initial_observation_period=1000,
    render_duration=1000,
    update_agent_frequency: int = 1,
    disable_stdout: bool = False,
    train_agent: bool = True,
    metric_writer: Writer = MockWriter(),
    rollout_drawer: MplDrawer = MockDrawer(),
    **kwargs,
) -> None:
    """
    :param num_environment_steps:
    :param stat_frequency:
    :param render_frequency:
    :param disable_stdout:
    :return:
    """
    state = self.agent.extract_features(self.environment.reset())

    running_signal = mean_accumulator()
    best_running_signal = None
    running_mean_action = mean_accumulator()
    termination_i = 0
    signal_since_last_termination = 0
    duration_since_last_termination = 0

    for step_i in tqdm(range(num_environment_steps), desc="Step #", leave=False):
        sample = self.agent.sample(state)
        action = self.agent.extract_action(sample)

        snapshot = self.environment.react(action)

        successor_state = self.agent.extract_features(snapshot)
        signal = self.agent.extract_signal(snapshot)
        terminated = snapshot.terminated

        if train_agent:
            self.agent.remember(
                state=state,
                signal=signal,
                terminated=terminated,
                sample=sample,
                successor_state=successor_state,
            )

        state = successor_state

        duration_since_last_termination += 1
        mean_signal = signal.mean().item()
        signal_since_last_termination += mean_signal

        running_mean_action.send(action.mean())
        running_signal.send(mean_signal)

        if (
            train_agent
            and is_positive_and_mod_zero(update_agent_frequency * batch_size, step_i)
            and len(self.agent.memory_buffer) > batch_size
            and step_i > initial_observation_period
        ):
            loss = self.agent.update(metric_writer=metric_writer)

            sig = next(running_signal)
            if not best_running_signal or sig > best_running_signal:
                best_running_signal = sig
                self.call_on_improvement_callbacks(loss=loss, signal=sig, **kwargs)

        if terminated.any():
            termination_i += 1
            if metric_writer:
                metric_writer.scalar(
                    "duration_since_last_termination", duration_since_last_termination
                )
                metric_writer.scalar(
                    "signal_since_last_termination", signal_since_last_termination
                )
                metric_writer.scalar("running_mean_action", next(running_mean_action))
                metric_writer.scalar("running_signal", next(running_signal))
            signal_since_last_termination = 0
            duration_since_last_termination = 0

        if (
            is_zero_or_mod_below(render_frequency, render_duration, step_i)
            and render_frequency != 0
        ):
            self.environment.render()
            if rollout_drawer:
                rollout_drawer.draw(action)

        if self.early_stop:
            break
