def single_epoch_fitting(
    model: torch.nn.Module,
    optimiser,
    train_loader_,
    *,
    epoch: int = None,
    writer: Writer = None,
    device_: torch.device = global_torch_device(),
) -> None:
    accum_loss = 0
    num_batches = len(train_loader_)

    with TorchTrainSession(model):
        for batch_idx, (data, target) in tqdm(
            enumerate(train_loader_), desc='train batch #', total=num_batches
        ):
            loss = nll_loss(
                model(data.to(device_)).squeeze(), target.to(device_)
            )  # negative log-likelihood for a tensor of size (batch x 1 x n_output)
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
            accum_loss += loss.item()

    if writer:
        writer.scalar('loss', accum_loss / num_batches, epoch)
def update_targets(
    self, copy_percentage: float = 0.005, *, metric_writer: Writer = None
) -> None:
    r"""Softly update the target critics towards the main critics (polyak averaging).

    Target networks are updated towards the main networks according to:

        \theta_{\text{targ}} \leftarrow \rho \theta_{\text{targ}} + (1 - \rho) \theta

    where \rho is the polyak coefficient (always between 0 and 1, usually close to 1)
    and ``copy_percentage`` corresponds to 1 - \rho.

    @param copy_percentage: interpolation factor used when blending the parameters
    @param metric_writer: optional writer used to log the synchronisation event
    @return: None
    """
    if metric_writer:
        metric_writer.blip("Target Models Synced", self.update_i)

    update_target(
        target_model=self.critic_1_target,
        source_model=self.critic_1,
        copy_percentage=copy_percentage,
    )
    update_target(
        target_model=self.critic_2_target,
        source_model=self.critic_2,
        copy_percentage=copy_percentage,
    )
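# A minimal sketch, assuming the semantics of the project's `update_target` helper,
# of the interpolation described in the docstring above:
# target <- (1 - copy_percentage) * target + copy_percentage * source,
# i.e. `copy_percentage` plays the role of 1 - \rho. Illustrative only.
def polyak_update_sketch(
    target_model: torch.nn.Module,
    source_model: torch.nn.Module,
    copy_percentage: float = 0.005,
) -> None:
    with torch.no_grad():
        for target_param, source_param in zip(
            target_model.parameters(), source_model.parameters()
        ):
            target_param.mul_(1.0 - copy_percentage).add_(
                copy_percentage * source_param
            )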
def kl_divergence(mean, log_var, writer: Writer = MockWriter()) -> torch.Tensor:
    """Kullback-Leibler divergence of a diagonal Gaussian N(mean, exp(log_var)) from N(0, I).

    Args:
        mean: per-dimension means; 4D inputs (e.g. (batch, latent, 1, 1)) are flattened
        log_var: per-dimension log variances, same shape as ``mean``
        writer: optional writer for dimension-wise and mean KLD logging

    Returns:
        The KLD summed over dimensions and averaged over the batch, shape (1,)
    """
    batch_size = mean.size(0)
    assert batch_size != 0

    if mean.data.ndimension() == 4:
        mean = mean.view(mean.size(0), mean.size(1))
    if log_var.data.ndimension() == 4:
        log_var = log_var.view(log_var.size(0), log_var.size(1))

    klds = -0.5 * (1 + log_var - mean.pow(2) - log_var.exp())

    if writer:
        writer.scalar("dimension_wise_kld", klds.mean(0))
        writer.scalar("mean_kld", klds.mean(1).mean(0, True))

    return klds.sum(1).mean(0, True)  # total_kld
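# An illustrative sanity check, not part of the original module: the closed-form term
# -0.5 * (1 + log_var - mean^2 - exp(log_var)) used above matches the KL divergence of
# N(mean, exp(log_var)) from a standard normal as computed by torch.distributions.
def _kl_divergence_sanity_check() -> None:
    from torch.distributions import Normal, kl_divergence as torch_kl

    mean = torch.randn(4, 8)
    log_var = torch.randn(4, 8)

    manual = (-0.5 * (1 + log_var - mean.pow(2) - log_var.exp())).sum(1).mean(0, True)
    reference = (
        torch_kl(Normal(mean, (0.5 * log_var).exp()), Normal(0.0, 1.0))
        .sum(1)
        .mean(0, True)
    )
    assert torch.allclose(manual, reference, atol=1e-5)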
def update_alpha(
    self, log_prob: torch.Tensor, metric_writer: Writer = None
) -> float:
    """Update the (log) entropy temperature alpha towards the target entropy.

    @param log_prob: log probabilities of the sampled actions (must be detached)
    @param metric_writer: optional writer for logging the alpha loss and value
    @return: the scalar alpha loss
    """
    assert not log_prob.requires_grad

    alpha_loss = -torch.mean(self._log_sac_alpha * (log_prob + self._target_entropy))

    self.sac_alpha_optimiser.zero_grad()
    alpha_loss.backward()
    self.post_process_gradients(self._log_sac_alpha)
    self.sac_alpha_optimiser.step()

    self._sac_alpha = self._log_sac_alpha.exp()

    out_loss = alpha_loss.detach().cpu().item()

    if metric_writer:
        metric_writer.scalar("Sac_Alpha_Loss", out_loss, self.update_i)
        metric_writer.scalar("Sac_Alpha", to_scalar(self._sac_alpha), self.update_i)

    return out_loss
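# Hedged aside, not taken from this class: a common SAC heuristic sets the entropy
# target to the negative dimensionality of the action space and learns the temperature
# through its logarithm, which is what `_log_sac_alpha` above suggests. The names below
# are hypothetical and only illustrate that setup.
def _sac_temperature_setup_sketch(action_space_dim: int = 6):
    target_entropy = -float(action_space_dim)
    log_sac_alpha = torch.zeros(1, requires_grad=True)
    sac_alpha_optimiser = torch.optim.Adam([log_sac_alpha], lr=3e-4)
    return target_entropy, log_sac_alpha, sac_alpha_optimiser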
def _update(self, metric_writer: Writer = MockWriter()) -> float:
    """Run the configured number of inner updates and return the mean loss.

    @param metric_writer: optional writer for logging
    @return: the mean loss over the performed inner updates
    """
    transitions = self._prepare_transitions()

    accum_loss = mean_accumulator()
    for ith_inner_update in tqdm(
        range(self._num_inner_updates), desc="#Inner updates", leave=False
    ):
        self.inner_update_i += 1
        loss, early_stop_inner = self.inner_update(
            *transitions, metric_writer=metric_writer
        )
        accum_loss.send(loss)

        if is_none_or_zero_or_negative_or_mod_zero(
            self._update_target_interval, self.inner_update_i
        ):
            self._update_targets(self._copy_percentage, metric_writer=metric_writer)

        if early_stop_inner:
            break

    mean_loss = next(accum_loss)

    if metric_writer:
        metric_writer.scalar("Inner Updates", ith_inner_update)
        metric_writer.scalar("Mean Loss", mean_loss)

    return mean_loss
def train_model(
    model,
    optimiser,
    epoch_i: int,
    metric_writer: Writer,
    loader: DataLoader,
    log_interval=10,
):
    """Train the VAE on batches from ``loader`` and log the training loss
    (the loop currently breaks after the first logged batch)."""
    with TorchTrainSession(model):
        train_accum_loss = 0
        generator = tqdm(enumerate(loader))
        for batch_idx, (original, *_) in generator:
            original = original.to(global_torch_device())

            optimiser.zero_grad()
            reconstruction, mean, log_var = model(original)
            loss = loss_function(reconstruction, original, mean, log_var)
            loss.backward()
            optimiser.step()

            train_accum_loss += loss.item()
            metric_writer.scalar("train_loss", loss.item())

            if batch_idx % log_interval == 0:
                generator.set_description(
                    f"Train Epoch: {epoch_i}"
                    f" [{batch_idx * len(original)}/"
                    f"{len(loader.dataset)}"
                    f" ({100. * batch_idx / len(loader):.0f}%)]\t"
                    f"Loss: {loss.item() / len(original):.6f}"
                )
                break

        print(
            f"====> Epoch: {epoch_i}"
            f" Average loss: {train_accum_loss / len(loader.dataset):.4f}"
        )
def maskrcnn_evaluate(
    model: Module,
    data_loader: DataLoader,
    *,
    device=global_torch_device(),
    writer: Writer = None,
) -> CocoEvaluator:
    """Evaluate a Mask R-CNN model on a COCO-style dataset.

    Args:
        model: the detection model to evaluate
        data_loader: loader yielding (images, targets) pairs
        device: device to run inference on
        writer: optional writer for timing scalars

    Returns:
        The accumulated and summarised ``CocoEvaluator``
    """
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")

    coco_evaluator = CocoEvaluator(
        get_coco_api_from_dataset(data_loader.dataset), get_iou_types(model)
    )

    with torch.no_grad():
        with TorchEvalSession(model):
            for image, targets in tqdm.tqdm(data_loader):
                image = [img.to(device) for img in image]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                torch.cuda.synchronize(device)
                model_time = time.time()
                outputs = model(image)
                outputs = [
                    {k: v.to(cpu_device) for k, v in t.items()} for t in outputs
                ]
                model_time = time.time() - model_time

                res = {
                    target["image_id"].item(): output
                    for target, output in zip(targets, outputs)
                }
                evaluator_time = time.time()
                coco_evaluator.update(res)
                evaluator_time = time.time() - evaluator_time

                if writer:
                    writer.scalar("model_time", model_time)
                    writer.scalar("evaluator_time", evaluator_time)

    coco_evaluator.synchronize_between_processes()
    coco_evaluator.accumulate()
    coco_evaluator.summarize()

    torch.set_num_threads(n_threads)
    return coco_evaluator
def build(
    self,
    observation_space: ObservationSpace,
    action_space: ActionSpace,
    signal_space: SignalSpace,
    *,
    metric_writer: Writer = MockWriter(),
    print_model_repr: bool = True,
    verbose: bool = False,
    **kwargs,
) -> None:
    """Build the agent's models and optionally log their graphs.

    @param observation_space: observation space of the environment
    @param action_space: action space of the environment
    @param signal_space: signal (reward) space of the environment
    @param metric_writer: writer used for logging, e.g. model graphs
    @param print_model_repr: whether to print a representation of each model
    @param verbose: verbosity flag forwarded to the graph writer
    @param kwargs: forwarded to the superclass ``build``
    @return: None
    """
    super().build(
        observation_space,
        action_space,
        signal_space,
        print_model_repr=print_model_repr,
        metric_writer=metric_writer,
        **kwargs,
    )

    if print_model_repr:
        for k, w in self.models.items():
            sprint(f"{k}: {w}", highlight=True, color="cyan")

            if metric_writer:
                try:
                    model = copy.deepcopy(w).to("cpu")
                    dummy_input = model.sample_input()
                    sprint(f'{k} input: {dummy_input.shape}')

                    import contextlib

                    with contextlib.redirect_stdout(
                        None
                    ):  # So much useless frame info printed... Suppress it
                        if isinstance(metric_writer, GraphWriterMixin):
                            metric_writer.graph(
                                model, dummy_input, verbose=verbose
                            )  # No naming available at the moment...
                except RuntimeError as ex:
                    sprint(
                        f"Tensorboard(Pytorch) does not support your model! No graph added: {str(ex).splitlines()[0]}",
                        color="red",
                        highlight=True,
                    )
def _update(self, *, metric_writer: Writer = MockWriter()) -> float:
    """Perform a single DDPG update of the critic and the actor.

    :return: the summed scalar critic (TD) and actor (policy) losses
    :rtype: float
    """
    tensorised = TransitionPoint(
        *[to_tensor(a, device=self._device) for a in self._memory_buffer.sample()]
    )

    self._memory_buffer.clear()

    # Compute the next Q value based on the action the target actor would choose in the
    # successor state. Detach from the current graph since we don't want gradients for
    # the next Q to propagate.
    with torch.no_grad():
        next_max_q = self._target_critic(
            tensorised.successor_state, self._target_actor(tensorised.successor_state)
        )
        Q_target = tensorised.signal + (
            self._discount_factor * next_max_q * tensorised.non_terminal_numerical
        )  # the target of the current Q values

    # Compute the current Q value; the critic takes the state and the chosen action
    td_error = self._critic_criteria(
        self._critic(tensorised.state, tensorised.action), Q_target.detach()
    )
    self._critic_optimiser.zero_grad()
    td_error.backward()
    self.post_process_gradients(self._critic.parameters())
    self._critic_optimiser.step()

    with frozen_model(self._critic):
        policy_loss = -torch.mean(
            self._critic(tensorised.state, self._actor(tensorised.state))
        )
        self._actor_optimiser.zero_grad()
        policy_loss.backward()
        self.post_process_gradients(self._actor.parameters())
        self._actor_optimiser.step()

    if is_zero_or_mod_zero(self._update_target_interval, self.update_i):
        self.update_targets(self._copy_percentage, metric_writer=metric_writer)

    if metric_writer:
        metric_writer.scalar("td_error", td_error.cpu().item())
        metric_writer.scalar("policy_loss", policy_loss.cpu().item())

    with torch.no_grad():
        return (td_error + policy_loss).cpu().item()
def maskrcnn_train_single_epoch(
    *,
    model: Module,
    optimiser: torch.optim.Optimizer,
    data_loader: DataLoader,
    device: torch.device = global_torch_device(),
    writer: Writer = None,
) -> None:
    """Train a Mask R-CNN model for a single epoch.

    :param model: the detection model to train
    :param optimiser: optimiser updating the model parameters
    :param data_loader: loader yielding (images, targets) pairs
    :param device: device to train on
    :param writer: optional writer for loss and learning-rate scalars
    :return: None
    """
    model.to(device)
    with TorchTrainSession(model):
        for images, targets in tqdm.tqdm(data_loader, desc="Batch #"):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # torch.cuda.synchronize(device)
            loss_dict = model(images, targets=targets)
            losses = sum(loss for loss in loss_dict.values())

            loss_dict_reduced = reduce_dict(
                loss_dict
            )  # reduce losses over all GPUs for logging purposes
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            loss_value = losses_reduced.item()

            if not math.isfinite(loss_value):
                print(f"Loss is {loss_value}, stopping training")
                print(loss_dict_reduced)
                sys.exit(1)

            optimiser.zero_grad()
            losses.backward()
            optimiser.step()

            if writer:
                for k, v in {
                    "loss": losses_reduced,
                    "lr": optimiser.param_groups[0]["lr"],
                    **loss_dict_reduced,
                }.items():
                    writer.scalar(k, v)
def test_model(
    model: VAE,
    epoch_i: int,
    metric_writer: Writer,
    loader: DataLoader,
    save_images: bool = True,
):
    global LOWEST_L

    with TorchEvalSession(model):
        test_accum_loss = 0

        with torch.no_grad():
            for i, (original, labels, *_) in enumerate(loader):
                original = original.to(global_torch_device())

                reconstruction, mean, log_var = model(original)
                loss = loss_function(reconstruction, original, mean, log_var).item()

                test_accum_loss += loss
                metric_writer.scalar("test_loss", test_accum_loss)

                if save_images:
                    if i == 0:
                        n = min(original.size(0), 8)
                        comparison = torch.cat([original[:n], reconstruction[:n]])
                        save_image(
                            comparison.cpu(),  # Torch save images
                            str(BASE_PATH / f"reconstruction_{str(epoch_i)}.png"),
                            nrow=n,
                        )
                        """
                        scatter_plot_encoding_space(str(BASE_PATH / f'encoding_space_{str(epoch_i)}.png'),
                                                    mean.to('cpu').numpy(),
                                                    log_var.to('cpu').numpy(),
                                                    labels)
                        """

                break

    # test_loss /= len(loader.dataset)
    test_accum_loss /= loader.batch_size
    print(f"====> Test set loss: {test_accum_loss:.4f}")
    torch.save(model.state_dict(), BASE_PATH / f"model_state_dict{str(epoch_i)}.pth")

    if LOWEST_L > test_accum_loss:
        LOWEST_L = test_accum_loss
        torch.save(model.state_dict(), BASE_PATH / f"best_state_dict.pth")
def _update(self, *, metric_writer: Writer = MockWriter()) -> float:
    """Perform a DQN update if the observation period has passed and enough transitions are stored.

    @param metric_writer: optional writer for logging
    @return: the scalar loss, or ``math.inf`` if no update was performed
    """
    loss_ = math.inf

    if self.update_i > self._initial_observation_period:
        if is_zero_or_mod_zero(self._learning_frequency, self.update_i):
            if len(self._memory_buffer) > self._batch_size:
                transitions = self._memory_buffer.sample(self._batch_size)

                td_error, Q_expected, Q_state = self._td_error(transitions)
                td_error = td_error.detach().squeeze(-1).cpu().numpy()

                if self._use_per:
                    self._memory_buffer.update_last_batch(td_error)

                loss = self._loss_function(Q_state, Q_expected)

                self._optimiser.zero_grad()
                loss.backward()
                self.post_process_gradients(self.value_model.parameters())
                self._optimiser.step()

                loss_ = to_scalar(loss)
                if metric_writer:
                    metric_writer.scalar("td_error", td_error.mean(), self.update_i)
                    metric_writer.scalar("loss", loss_, self.update_i)

                if self._scheduler:
                    self._scheduler.step()
                    if metric_writer:
                        for i, param_group in enumerate(self._optimiser.param_groups):
                            metric_writer.scalar(
                                f"lr{i}", param_group["lr"], self.update_i
                            )
            else:
                logging.info(
                    "Batch size is larger than current memory size, skipping update"
                )

    if self._double_dqn:
        if is_zero_or_mod_zero(self._sync_target_model_frequency, self.update_i):
            update_target(
                target_model=self._target_value_model,
                source_model=self.value_model,
                copy_percentage=self._copy_percentage,
            )
            if metric_writer:
                metric_writer.blip("Target Model Synced", self.update_i)

    return loss_
def write_metrics_recursive(
    eval_result: typing.Mapping, prefix: str, summary_writer: Writer, global_step: int
) -> None:
    """Recursively write a (possibly nested) mapping of metrics as scalars with slash-separated tags.

    :param eval_result: mapping of metric names to values or nested mappings
    :param prefix: prefix prepended to every tag
    :param summary_writer: writer receiving the scalars
    :param global_step: step index used for all written scalars
    """
    for key in eval_result:
        value = eval_result[key]
        tag = f"{prefix}/{key}"
        if isinstance(value, typing.Mapping):
            write_metrics_recursive(value, tag, summary_writer, global_step)
        else:
            summary_writer.scalar(tag, value, step_i=global_step)
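# Hypothetical usage of write_metrics_recursive (values and writer below are made up):
# nested mappings are flattened into slash-separated tags under the given prefix.
#
#     eval_result = {"bbox": {"mAP": 0.41, "mAP_50": 0.62}, "loss": 0.73}
#     write_metrics_recursive(eval_result, "eval", summary_writer, global_step=1000)
#     # scalars written: eval/bbox/mAP, eval/bbox/mAP_50, eval/loss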
def _update_targets(
    self, copy_percentage: float, *, metric_writer: Writer = None
) -> None:
    """Softly update the target actor-critic towards the main actor-critic.

    @param copy_percentage: interpolation factor used when blending the parameters
    @param metric_writer: optional writer used to log the synchronisation event
    @return: None
    """
    if metric_writer:
        metric_writer.blip("Target Model Synced", self.update_i)

    update_target(
        target_model=self._target_actor_critic,
        source_model=self.actor_critic,
        copy_percentage=copy_percentage,
    )
def update_critics(
    self, tensorised: TransitionPoint, metric_writer: Writer = None
) -> float:
    """Update both soft Q critics towards the entropy-regularised Bellman target.

    @param tensorised: batch of transitions as tensors
    @param metric_writer: optional writer for logging losses and Q statistics
    @return: the scalar summed critic loss
    """
    with torch.no_grad():
        successor_action, successor_log_prob = normal_tanh_reparameterised_sample(
            self.actor(tensorised.successor_state)
        )

        min_successor_q = (
            torch.min(
                self.critic_1_target(tensorised.successor_state, successor_action),
                self.critic_2_target(tensorised.successor_state, successor_action),
            )
            - successor_log_prob * self._sac_alpha
        )

        successor_q_value = (
            tensorised.signal
            + tensorised.non_terminal_numerical
            * self._discount_factor
            * min_successor_q
        ).detach()
        assert not successor_q_value.requires_grad

    q_value_loss1 = self._critic_criterion(
        self.critic_1(tensorised.state, tensorised.action), successor_q_value
    )
    q_value_loss2 = self._critic_criterion(
        self.critic_2(tensorised.state, tensorised.action), successor_q_value
    )
    critic_loss = q_value_loss1 + q_value_loss2
    assert critic_loss.requires_grad

    self.critic_optimiser.zero_grad()
    critic_loss.backward()
    self.post_process_gradients(self.critic_1.parameters())
    self.post_process_gradients(self.critic_2.parameters())
    self.critic_optimiser.step()

    out_loss = to_scalar(critic_loss)

    if metric_writer:
        metric_writer.scalar("Critics_loss", out_loss, self.update_i)
        metric_writer.scalar("q_value_loss1", to_scalar(q_value_loss1), self.update_i)
        metric_writer.scalar("q_value_loss2", to_scalar(q_value_loss2), self.update_i)
        metric_writer.scalar(
            "min_successor_q", to_scalar(min_successor_q), self.update_i
        )
        metric_writer.scalar(
            "successor_q_value", to_scalar(successor_q_value), self.update_i
        )

    return out_loss
def single_epoch_evaluation(
    model: Module,
    evaluation_loader: DataLoader,
    subset: Split,
    *,
    epoch: int = None,
    writer: Writer = None,
    device: torch.device = global_torch_device(),
) -> float:
    correct = 0
    num_batches = len(evaluation_loader)

    with TorchEvalSession(model):
        for data, target in tqdm(
            evaluation_loader, desc=f'{subset} batch #', total=num_batches
        ):
            correct += (
                model(data.to(device))
                .argmax(dim=-1)
                .squeeze()
                .eq(target.to(device))
                .sum()
                .item()
            )

    acc = correct / len(evaluation_loader.dataset)
    if writer:
        writer.scalar(f'{subset}_accuracy', acc, epoch)
    return acc
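# A hypothetical sketch wiring single_epoch_fitting (above) and single_epoch_evaluation
# into a plain training loop; the loaders, optimiser and the `Split` member passed as
# `subset` are assumed to be constructed elsewhere in the project.
def fit_and_evaluate_sketch(
    model: Module,
    optimiser,
    train_loader: DataLoader,
    val_loader: DataLoader,
    subset: Split,
    *,
    num_epochs: int = 10,
    writer: Writer = MockWriter(),
) -> float:
    acc = 0.0
    for epoch in range(1, num_epochs + 1):
        single_epoch_fitting(model, optimiser, train_loader, epoch=epoch, writer=writer)
        acc = single_epoch_evaluation(
            model, val_loader, subset, epoch=epoch, writer=writer
        )
    return acc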
def _policy_loss(
    self,
    new_distribution,
    action_batch,
    log_prob_batch_old,
    adv_batch,
    *,
    metric_writer: Writer = None,
):
    action_log_probs_new = self.get_log_prob(new_distribution, action_batch)
    ratio = torch.exp(action_log_probs_new - log_prob_batch_old)
    # If the ratio explodes to inf or NaN because the residual is too large, check the initialisation!
    # The ratio compares action probabilities under the new and the old policy:
    # values in [0..1] mean the action is less likely under the new policy,
    # while values > 1 mean the action is more likely now.

    clamped_ratio = torch.clamp(
        ratio,
        min=1.0 - self._surrogate_clipping_value,
        max=1.0 + self._surrogate_clipping_value,
    )

    policy_loss = -torch.min(ratio * adv_batch, clamped_ratio * adv_batch).mean()
    entropy_loss = new_distribution.entropy().mean() * self._entropy_reg_coefficient

    with torch.no_grad():
        approx_kl = to_scalar((log_prob_batch_old - action_log_probs_new))

    if metric_writer:
        metric_writer.scalar("ratio", to_scalar(ratio))
        metric_writer.scalar("entropy_loss", to_scalar(entropy_loss))
        metric_writer.scalar("clamped_ratio", to_scalar(clamped_ratio))

    return policy_loss - entropy_loss, approx_kl
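# A small numeric illustration of the clipped surrogate objective above, assuming a
# clipping value of 0.2 (not necessarily this agent's setting): with a positive
# advantage, ratios above 1.2 gain no extra objective, so the gradient through the
# clamped branch vanishes there.
def _clipped_surrogate_illustration() -> torch.Tensor:
    ratio = torch.tensor([0.5, 1.0, 1.5])
    advantage = torch.tensor([1.0, 1.0, 1.0])
    clamped_ratio = torch.clamp(ratio, min=1.0 - 0.2, max=1.0 + 0.2)
    return torch.min(ratio * advantage, clamped_ratio * advantage)  # [0.5, 1.0, 1.2]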
def sample(
    self,
    signals: numpy.ndarray,
    states: numpy.ndarray,
    actions: numpy.ndarray,
    *,
    writer: Writer = None,
) -> numpy.ndarray:
    """Blend the extrinsic signals with the intrinsic curiosity signal.

    @param signals: extrinsic signals, shape (n, t) matching the leading dimensions of ``actions``
    @param states: observed states with one extra time step, shape (n, t + 1, ...)
    @param actions: taken actions, shape (n, t, ...)
    @param writer: optional writer for logging the mean intrinsic signal
    @return: the blended signal, same shape as ``signals``
    @rtype: numpy.ndarray
    """
    n, t = actions.shape[0], actions.shape[1]
    states, next_states = states[:, :-1], states[:, 1:]

    states = to_tensor(states.reshape(n * t, -1))  # flatten
    next_states = to_tensor(next_states.reshape(n * t, -1))
    actions = to_tensor(actions.reshape(n * t, *actions.shape[2:]))

    next_states_latent, next_states_hat, _ = self.forward(states, next_states, actions)

    intrinsic_signal = (
        (
            self.reward_scale
            / 2
            * (next_states_hat - next_states_latent).norm(2, dim=-1).pow(2)
        )
        .cpu()
        .detach()
        .numpy()
        .reshape(n, t)
    )

    if writer is not None:
        writer.scalar("icm/signal", intrinsic_signal.mean().item())

    return (
        1.0 - self.intrinsic_signal_factor
    ) * signals + self.intrinsic_signal_factor * intrinsic_signal
def train_once(
    self,
    iteration_number: int,
    trajectories: Sequence,
    *,
    writer: Writer = MockWriter(),
):
    """Perform one step of policy optimisation given one batch of samples.

    Args:
        iteration_number (int): Iteration number.
        trajectories (list[dict]): A list of collected paths.
        writer: Writer used for logging the epoch and sample indices.

    Returns:
        float: The average return in the last epoch cycle.
    """
    undiscounted_returns = []
    for trajectory in TrajectoryBatch.from_trajectory_list(
        self._env_spec, trajectories
    ).split():  # TODO: EEEEW
        undiscounted_returns.append(sum(trajectory.rewards))

    sample_returns = np.mean(undiscounted_returns)
    self._all_returns.append(sample_returns)

    epoch = iteration_number // self._num_candidate_policies
    i_sample = iteration_number - epoch * self._num_candidate_policies
    writer.scalar("Epoch", epoch)
    writer.scalar("# Sample", i_sample)

    if (
        iteration_number + 1
    ) % self._num_candidate_policies == 0:
        # Looped all the way around; update the shared parameters. WARNING: race conditions!
        sample_returns = max(self._all_returns)
        self.update()

    self.policy.set_param_values(
        self._shared_params[(i_sample + 1) % self._num_candidate_policies]
    )

    return sample_returns
def update_targets(
    self, update_percentage: float, *, metric_writer: Writer = None
) -> None:
    """Softly update the target actor and critic towards the main networks.

    @param update_percentage: interpolation factor used when blending the parameters
    @param metric_writer: optional writer used to log the synchronisation event
    @return: None
    """
    with torch.no_grad():
        if metric_writer:
            metric_writer.blip("Target Model Synced", self.update_i)

        update_target(
            target_model=self._target_critic,
            source_model=self._critic,
            copy_percentage=update_percentage,
        )
        update_target(
            target_model=self._target_actor,
            source_model=self._actor,
            copy_percentage=update_percentage,
        )
def _update(self, *args, metric_writer: Writer = MockWriter(), **kwargs) -> float:
    """Run the configured number of inner SAC updates and return the accumulated loss.

    @param args: unused positional arguments
    @param metric_writer: optional writer for logging
    @param kwargs: unused keyword arguments
    @return: the accumulated loss over the inner updates
    """
    accum_loss = 0
    for ith_inner_update in tqdm(
        range(self._num_inner_updates),
        desc="Inner update #",
        leave=False,
        postfix=f"Agent update #{self.update_i}",
    ):
        self.inner_update_i += 1
        batch = self._memory_buffer.sample(self._batch_size)
        tensorised = TransitionPoint(
            *[to_tensor(a, device=self._device) for a in batch]
        )

        with frozen_parameters(self.actor.parameters()):
            accum_loss += self.update_critics(tensorised, metric_writer=metric_writer)

        with frozen_parameters(
            itertools.chain(self.critic_1.parameters(), self.critic_2.parameters())
        ):
            accum_loss += self.update_actor(tensorised, metric_writer=metric_writer)

        if is_zero_or_mod_zero(self._target_update_interval, self.inner_update_i):
            self.update_targets(self._copy_percentage, metric_writer=metric_writer)

    if metric_writer:
        metric_writer.scalar("Accum_loss", accum_loss, self.update_i)
        metric_writer.scalar("num_inner_updates_i", ith_inner_update, self.update_i)

    return accum_loss
def loss(
    self,
    policy_loss: torch.Tensor,
    states: torch.Tensor,
    next_states: torch.Tensor,
    actions: torch.Tensor,
    *,
    writer: Writer = None,
) -> torch.Tensor:
    """Combine the policy loss with the curiosity (forward + inverse) loss.

    @param policy_loss: the loss of the policy being trained
    @param states: batch of states
    @param next_states: batch of successor states
    @param actions: batch of actions
    @param writer: optional writer for logging the curiosity loss
    @return: the weighted sum of the policy loss and the curiosity loss
    @rtype: torch.Tensor
    """
    next_states_latent, next_states_hat, actions_hat = self.forward(
        states, next_states, actions
    )

    forward_loss = (
        0.5
        * (next_states_hat - next_states_latent.detach()).norm(2, dim=-1).pow(2).mean()
    )
    ca = Categorical(logits=actions_hat).sample()
    inverse_loss = self.a_loss(ca, actions)
    curiosity_loss = self.weight * forward_loss + (1 - self.weight) * inverse_loss

    if writer is not None:
        writer.scalar("icm/loss", curiosity_loss.item())

    return self.policy_weight * policy_loss + curiosity_loss
def update_actor(
    self, tensorised: TransitionPoint, metric_writer: Writer = None
) -> float:
    """Update the actor (and optionally the entropy temperature) on a batch of transitions.

    @param tensorised: batch of transitions as tensors
    @param metric_writer: optional writer for logging losses and statistics
    @return: the scalar policy loss (plus the alpha loss when auto-tuning is enabled)
    """
    dist = self.actor(tensorised.state)
    action, log_prob = normal_tanh_reparameterised_sample(dist)

    # Check gradient paths
    assert action.requires_grad
    assert log_prob.requires_grad

    q_values = (
        self.critic_1(tensorised.state, action),
        self.critic_2(tensorised.state, action),
    )
    assert q_values[0].requires_grad and q_values[1].requires_grad

    policy_loss = torch.mean(self._sac_alpha * log_prob - torch.min(*q_values))

    self.actor_optimiser.zero_grad()
    policy_loss.backward()
    self.post_process_gradients(self.actor.parameters())
    self.actor_optimiser.step()

    out_loss = to_scalar(policy_loss)

    if metric_writer:
        metric_writer.scalar("Policy_loss", out_loss)
        metric_writer.scalar("q_value_1", to_scalar(q_values[0]))
        metric_writer.scalar("q_value_2", to_scalar(q_values[1]))
        metric_writer.scalar("policy_stddev", to_scalar(dist.stddev))
        metric_writer.scalar("policy_log_prob", to_scalar(log_prob))

    if self._auto_tune_sac_alpha:
        out_loss += self.update_alpha(log_prob.detach(), metric_writer=metric_writer)

    return out_loss
def inner_update(self, *transitions, metric_writer: Writer = None) -> Tuple:
    """Run PPO mini-batch updates over the given transitions.

    Returns a (loss, early_stop) tuple, where early_stop signals that the approximate
    KL exceeded 1.5 times the target KL.
    """
    batch_generator = shuffled_batches(
        *transitions, size=transitions[0].size(0), batch_size=self._mini_batch_size
    )
    for (
        state,
        action,
        log_prob_old,
        discounted_signal,
        advantage,
    ) in batch_generator:
        new_distribution, value_estimate = self.actor_critic(state)

        policy_loss, approx_kl = self._policy_loss(
            new_distribution,
            action,
            log_prob_old,
            advantage,
            metric_writer=metric_writer,
        )
        critic_loss = (
            self._critic_criterion(value_estimate, discounted_signal)
            * self._value_reg_coefficient
        )
        loss = policy_loss + critic_loss

        self._optimiser.zero_grad()
        loss.backward()
        self.post_process_gradients(self.actor_critic.parameters())
        self._optimiser.step()

        if metric_writer:
            metric_writer.scalar("policy_stddev", to_scalar(new_distribution.stddev))
            metric_writer.scalar("policy_loss", to_scalar(policy_loss))
            metric_writer.scalar("critic_loss", to_scalar(critic_loss))
            metric_writer.scalar("policy_approx_kl", approx_kl)
            metric_writer.scalar("merged_loss", to_scalar(loss))

        if approx_kl > 1.5 * self._target_kl:
            return to_scalar(loss), True

    return to_scalar(loss), False
def train_siamese(
    model,
    optimiser,
    criterion,
    *,
    writer: Writer = MockWriter(),
    train_number_epochs,
    data_dir,
    train_batch_size,
    model_name,
    save_path,
    save_best=False,
    img_size,
    validation_interval: int = 1,
):
    """Train a triplet siamese model and optionally save the best validation checkpoint.

    :param model: the siamese model to train
    :param optimiser: optimiser updating the model parameters
    :param criterion: triplet loss criterion
    :param writer: writer for train/validation scalars
    :param train_number_epochs: number of epochs to train for
    :param data_dir: root directory of the triplet dataset
    :param train_batch_size: batch size for training and validation
    :param model_name: name used when saving model parameters
    :param save_path: directory where model parameters are saved
    :param save_best: whether to save the model whenever the validation loss improves
    :param img_size: image size the inputs are resized to
    :param validation_interval: run validation every this many batches
    :return: the trained model
    """
    train_dataloader = DataLoader(
        TripletDataset(
            data_path=data_dir,
            transform=transforms.Compose(
                [
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]
            ),
            split=SplitEnum.training,
        ),
        shuffle=True,
        num_workers=0,
        batch_size=train_batch_size,
    )

    valid_dataloader = DataLoader(
        TripletDataset(
            data_path=data_dir,
            transform=transforms.Compose(
                [
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]
            ),
            split=SplitEnum.validation,
        ),
        shuffle=True,
        num_workers=0,
        batch_size=train_batch_size,
    )

    best = math.inf
    E = tqdm(range(0, train_number_epochs))
    batch_counter = count()

    for epoch in E:
        for tss in train_dataloader:
            batch_i = next(batch_counter)

            with TorchTrainSession(model):
                optimiser.zero_grad()
                loss_contrastive = criterion(
                    *model(*[t.to(global_torch_device()) for t in tss])
                )
                loss_contrastive.backward()
                optimiser.step()
                a = loss_contrastive.cpu().item()
                writer.scalar("train_loss", a, batch_i)

            if batch_i % validation_interval == 0:
                with TorchEvalSession(model):
                    for tsv in valid_dataloader:
                        o = model(*[t.to(global_torch_device()) for t in tsv])
                        a_v = criterion(*o).cpu().item()
                        valid_positive_acc = (
                            accuracy(distances=pairwise_distance(o[0], o[1]), is_diff=0)
                            .cpu()
                            .item()
                        )
                        valid_negative_acc = (
                            accuracy(distances=pairwise_distance(o[0], o[2]), is_diff=1)
                            .cpu()
                            .item()
                        )
                        valid_acc = numpy.mean(
                            (valid_negative_acc, valid_positive_acc)
                        )
                        writer.scalar("valid_loss", a_v, batch_i)
                        writer.scalar("valid_positive_acc", valid_positive_acc, batch_i)
                        writer.scalar("valid_negative_acc", valid_negative_acc, batch_i)
                        writer.scalar("valid_acc", valid_acc, batch_i)
                        if a_v < best:
                            best = a_v
                            print(f"new best {best}")
                            if save_best:
                                save_model_parameters(
                                    model,
                                    optimiser=optimiser,
                                    model_name=model_name,
                                    save_directory=save_path,
                                )

            E.set_description(
                f"Epoch number {epoch}, Current train loss {a}, valid loss {a_v}, valid acc {valid_acc}"
            )

    return model
def rollout_on_policy(
    agent: Agent,
    initial_snapshot: EnvironmentSnapshot,
    env: Environment,
    *,
    rollout_ith: int = None,
    render_environment: bool = False,
    metric_writer: Writer = MockWriter(),
    rollout_drawer: MplDrawer = MockDrawer(),
    train_agent: bool = True,
    max_length: int = None,
    disable_stdout: bool = False,
):
    """Perform a single on-policy rollout until termination in the environment.

    :param agent: the agent to roll out
    :param initial_snapshot: the initial state observation in the environment
    :param env: the environment the agent interacts with
    :param rollout_ith: index of this rollout, used in the progress bar description
    :param render_environment: whether to render environment interaction
    :param metric_writer: optional writer for rollout statistics
    :param rollout_drawer: optional drawer for visualising actions
    :param train_agent: whether the agent should use the rollout to update its model
    :param max_length: maximum number of steps before the rollout is cut off
    :param disable_stdout: whether to disable the progress bar output
    :return: the episode return (float) and the episode length
    """
    state = agent.extract_features(initial_snapshot)

    running_mean_action = mean_accumulator()
    episode_signal = total_accumulator()

    rollout_description = f'Rollout'
    if rollout_ith:
        rollout_description += f" #{rollout_ith}"

    for step_i in tqdm(
        count(1),
        rollout_description,
        unit='th step',
        leave=False,
        disable=disable_stdout,
        postfix=f"Agent update #{agent.update_i}",
    ):
        sample = agent.sample(state)
        action = agent.extract_action(sample)

        snapshot = env.react(action)

        successor_state = agent.extract_features(snapshot)
        terminated = snapshot.terminated
        signal = agent.extract_signal(snapshot)

        if train_agent:
            agent.remember(
                state=state,
                signal=signal,
                terminated=terminated,
                sample=sample,
                successor_state=successor_state,
            )

        state = successor_state

        running_mean_action.send(action.mean())
        episode_signal.send(signal.mean())

        if render_environment:
            env.render()
            if rollout_drawer:
                if env.action_space.is_discrete:
                    action = to_one_hot(agent.output_shape, action)
                rollout_drawer.draw(action)

        if numpy.array(terminated).all() or (max_length and step_i > max_length):
            break

    if train_agent:
        agent.update(metric_writer=metric_writer)
    else:
        logging.info("no update")

    episode_return = next(episode_signal)
    rma = next(running_mean_action)

    if metric_writer:
        metric_writer.scalar("duration", step_i, agent.update_i)
        metric_writer.scalar("running_mean_action", rma, agent.update_i)
        metric_writer.scalar("signal", episode_return, agent.update_i)

    return episode_return, step_i
def __call__(
    self,
    *,
    batch_size=1000,
    iterations=10000,
    stat_frequency=10,
    render_frequency=10,
    disable_stdout: bool = False,
    train_agent: bool = True,
    metric_writer: Writer = MockWriter(),
    **kwargs,
) -> None:
    """Run batched on-policy training.

    :param batch_size: number of environment steps collected per batch
    :param iterations: number of batches to run
    :param stat_frequency: log running statistics every this many batches
    :param render_frequency: render the environment every this many batches
    :param disable_stdout: whether to disable progress bar output
    :param train_agent: whether the agent is updated after each batch
    :param metric_writer: writer for running statistics
    :param kwargs: forwarded to the improvement callbacks
    :return: None
    """
    state = self.agent.extract_features(self.environment.reset())

    running_signal = mean_accumulator()
    best_running_signal = None
    running_mean_action = mean_accumulator()

    for batch_i in tqdm(
        range(1, iterations),
        leave=False,
        disable=disable_stdout,
        desc="Batch #",
        postfix=f"Agent update #{self.agent.update_i}",
    ):
        for _ in tqdm(
            range(batch_size),
            leave=False,
            disable=disable_stdout,
            desc="Step #",
        ):
            sample = self.agent.sample(state)
            action = self.agent.extract_action(sample)
            snapshot = self.environment.react(action)
            successor_state = self.agent.extract_features(snapshot)
            signal = self.agent.extract_signal(snapshot)

            if is_positive_and_mod_zero(render_frequency, batch_i):
                self.environment.render()

            if train_agent:
                self.agent.remember(
                    state=state,
                    signal=signal,
                    terminated=snapshot.terminated,
                    sample=sample,
                    successor_state=successor_state,
                )

            state = successor_state

            running_signal.send(signal.mean())
            running_mean_action.send(action.mean())

        sig = next(running_signal)
        rma = next(running_mean_action)

        if is_positive_and_mod_zero(stat_frequency, batch_i):
            metric_writer.scalar("Running signal", sig, batch_i)
            metric_writer.scalar("running_mean_action", rma, batch_i)

        if train_agent:
            loss = self.agent.update(metric_writer=metric_writer)

            if best_running_signal is None or sig > best_running_signal:
                best_running_signal = sig
                self.call_on_improvement_callbacks(loss=loss, **kwargs)
        else:
            logging.info("no update")

        if self.early_stop:
            break
def __call__(
    self,
    *,
    num_environment_steps=500000,
    batch_size=128,
    stat_frequency=10,
    render_frequency=10000,
    initial_observation_period=1000,
    render_duration=1000,
    update_agent_frequency: int = 1,
    disable_stdout: bool = False,
    train_agent: bool = True,
    metric_writer: Writer = MockWriter(),
    rollout_drawer: MplDrawer = MockDrawer(),
    **kwargs,
) -> None:
    """Run step-wise off-policy training.

    :param num_environment_steps: total number of environment steps to take
    :param batch_size: batch size used when sampling from the memory buffer
    :param stat_frequency: statistics logging frequency
    :param render_frequency: render the environment every this many steps
    :param initial_observation_period: number of steps before updates begin
    :param render_duration: number of consecutive steps to render
    :param update_agent_frequency: update the agent every this many batches of steps
    :param disable_stdout: whether to disable progress bar output
    :param train_agent: whether the agent is updated during training
    :param metric_writer: writer for episode statistics
    :param rollout_drawer: optional drawer for visualising actions
    :param kwargs: forwarded to the improvement callbacks
    :return: None
    """
    state = self.agent.extract_features(self.environment.reset())

    running_signal = mean_accumulator()
    best_running_signal = None
    running_mean_action = mean_accumulator()
    termination_i = 0
    signal_since_last_termination = 0
    duration_since_last_termination = 0

    for step_i in tqdm(range(num_environment_steps), desc="Step #", leave=False):
        sample = self.agent.sample(state)
        action = self.agent.extract_action(sample)
        snapshot = self.environment.react(action)
        successor_state = self.agent.extract_features(snapshot)
        signal = self.agent.extract_signal(snapshot)
        terminated = snapshot.terminated

        if train_agent:
            self.agent.remember(
                state=state,
                signal=signal,
                terminated=terminated,
                sample=sample,
                successor_state=successor_state,
            )

        state = successor_state

        duration_since_last_termination += 1
        mean_signal = signal.mean().item()
        signal_since_last_termination += mean_signal

        running_mean_action.send(action.mean())
        running_signal.send(mean_signal)

        if (
            train_agent
            and is_positive_and_mod_zero(update_agent_frequency * batch_size, step_i)
            and len(self.agent.memory_buffer) > batch_size
            and step_i > initial_observation_period
        ):
            loss = self.agent.update(metric_writer=metric_writer)

            sig = next(running_signal)
            if not best_running_signal or sig > best_running_signal:
                best_running_signal = sig
                self.call_on_improvement_callbacks(loss=loss, signal=sig, **kwargs)

        if terminated.any():
            termination_i += 1
            if metric_writer:
                metric_writer.scalar(
                    "duration_since_last_termination", duration_since_last_termination
                )
                metric_writer.scalar(
                    "signal_since_last_termination", signal_since_last_termination
                )
                metric_writer.scalar("running_mean_action", next(running_mean_action))
                metric_writer.scalar("running_signal", next(running_signal))
            signal_since_last_termination = 0
            duration_since_last_termination = 0

        if (
            is_zero_or_mod_below(render_frequency, render_duration, step_i)
            and render_frequency != 0
        ):
            self.environment.render()
            if rollout_drawer:
                rollout_drawer.draw(action)

        if self.early_stop:
            break
def train_siamese(
    model: Module,
    optimiser: Optimizer,
    criterion: callable,
    *,
    writer: Writer = MockWriter(),
    train_number_epochs: int,
    data_dir: Path,
    train_batch_size: int,
    model_name: str,
    save_path: Path,
    save_best: bool = False,
    img_size: Tuple[int, int],
    validation_interval: int = 1,
):
    """Train a pair-based siamese model and optionally save the best validation checkpoint.

    :param model: the siamese model to train
    :param optimiser: optimiser updating the model parameters
    :param criterion: contrastive loss criterion
    :param writer: writer for train/validation scalars
    :param train_number_epochs: number of epochs to train for
    :param data_dir: root directory of the pair dataset
    :param train_batch_size: batch size for training and validation
    :param model_name: name used when saving model parameters
    :param save_path: directory where model parameters are saved
    :param save_best: whether to save the model whenever the validation loss improves
    :param img_size: image size the inputs are resized to
    :param validation_interval: run validation every this many batches
    :return: the trained model
    """
    train_dataloader = DataLoader(
        PairDataset(
            data_path=data_dir,
            transform=transforms.Compose(
                [
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]
            ),
            split=Split.Training,
        ),
        shuffle=True,
        num_workers=4,
        batch_size=train_batch_size,
    )

    valid_dataloader = DataLoader(
        PairDataset(
            data_path=data_dir,
            transform=transforms.Compose(
                [
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]
            ),
            split=Split.Validation,
        ),
        shuffle=True,
        num_workers=4,
        batch_size=train_batch_size,
    )

    best = math.inf
    E = tqdm(range(0, train_number_epochs))
    batch_counter = count()

    for epoch in E:
        for tss in train_dataloader:
            batch_i = next(batch_counter)

            with TorchTrainSession(model):
                o = [t.to(global_torch_device()) for t in tss]
                optimiser.zero_grad()
                loss_contrastive = criterion(model(*o[:2]), o[2].to(dtype=torch.float))
                loss_contrastive.backward()
                optimiser.step()
                train_loss = loss_contrastive.cpu().item()
                writer.scalar("train_loss", train_loss, batch_i)

            if batch_i % validation_interval == 0:
                with TorchEvalSession(model):
                    for tsv in valid_dataloader:
                        ov = [t.to(global_torch_device()) for t in tsv]
                        v_o, fact = model(*ov[:2]), ov[2].to(dtype=torch.float)
                        valid_loss = criterion(v_o, fact).cpu().item()
                        valid_accuracy = (
                            accuracy(distances=v_o, is_diff=fact).cpu().item()
                        )
                        writer.scalar("valid_loss", valid_loss, batch_i)
                        if valid_loss < best:
                            best = valid_loss
                            print(f"new best {best}")
                            writer.blip("new_best", batch_i)
                            if save_best:
                                save_model_parameters(
                                    model,
                                    optimiser=optimiser,
                                    model_name=model_name,
                                    save_directory=save_path,
                                )

            E.set_description(
                f"Epoch number {epoch}, Current train loss {train_loss}, valid loss {valid_loss}, valid_accuracy {valid_accuracy}"
            )

    return model
def training(
    model: Module,
    data_iterator: Iterator,
    optimiser: torch.optim.Optimizer,
    scheduler,
    writer: Writer,
    interrupted_path: Path,
    *,
    num_updates=2500000,
    early_stop_threshold=1e-9,
    denoise: bool = True,
) -> Module:
    """Train a (denoising) autoencoder and return the best model found.

    :param model: the autoencoder to train
    :param data_iterator: iterator yielding batches of RGB images
    :param optimiser: optimiser updating the model parameters
    :param scheduler: learning-rate scheduler stepped after each training update
    :param writer: writer for loss scalars and reconstruction images
    :param interrupted_path: path where the best weights are saved on exit or interrupt
    :param num_updates: maximum number of updates to run
    :param early_stop_threshold: stop early when the update loss drops below this value
    :param denoise: whether to add Gaussian noise to the input (denoising objective)
    :return: the model loaded with the best validation weights
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10
    since = time.time()
    # reraser = RandomErasing()

    try:
        sess = tqdm(range(num_updates), leave=False, disable=False)
        for update_i in sess:
            for phase in [SplitEnum.training, SplitEnum.validation]:
                if phase == SplitEnum.training:
                    for param_group in optimiser.param_groups:
                        writer.scalar("lr", param_group["lr"], update_i)
                    model.train()
                else:
                    model.eval()

                rgb_imgs, *_ = next(data_iterator)

                optimiser.zero_grad()
                with torch.set_grad_enabled(phase == SplitEnum.training):
                    if denoise:  # == 'denoise':
                        model_input = rgb_imgs + torch.normal(
                            mean=0.0,
                            std=0.1,
                            size=rgb_imgs.shape,
                            device=global_torch_device(),
                        )
                    # elif recover_type == 'missing':
                    #     model_input = rgb_imgs
                    else:
                        model_input = rgb_imgs

                    recon_pred, *_ = model(torch.clamp(model_input, 0.0, 1.0))
                    ret = criterion(recon_pred, rgb_imgs)

                    if phase == SplitEnum.training:
                        ret.backward()
                        optimiser.step()
                        scheduler.step()

                update_loss = ret.data.cpu().numpy()
                writer.scalar(f"loss/accum", update_loss, update_i)

                if phase == SplitEnum.validation and update_loss < best_loss:
                    best_loss = update_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                    _format = "NCHW"
                    writer.image(f"rgb_imgs", rgb_imgs, update_i, data_formats=_format)
                    writer.image(
                        f"recon_pred", recon_pred, update_i, data_formats=_format
                    )
                    sess.write(f"New best model at update {update_i}")

            sess.set_description_str(
                f"Update {update_i} - {phase} accum_loss:{update_loss:2f}"
            )

            if update_loss < early_stop_threshold:
                break
    except KeyboardInterrupt:
        print("Interrupt")
    finally:
        model.load_state_dict(best_model_wts)  # load best model weights
        torch.save(model.state_dict(), interrupted_path)

    time_elapsed = time.time() - since
    print(f"{time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    print(f"Best val loss: {best_loss}")

    return model