def run_style_transfer(cnn, normalization, content_img, style_img, input_img, mask_img,
                       num_steps=500, style_weight=100, content_weight=5):
    """Run the style transfer."""
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, normalization, style_img, content_img, mask_img)
    optimizer = LBFGS([input_img.requires_grad_()], max_iter=num_steps, lr=1)

    print('Optimizing..')
    run = [0]

    def closure():
        optimizer.zero_grad()
        model(input_img)

        style_score = 0
        content_score = 0
        for sl in style_losses:
            style_score += sl.loss
        for cl in content_losses:
            content_score += cl.loss

        style_score *= style_weight
        content_score *= content_weight

        loss = style_score + content_score
        loss.backward()

        if run[0] % 100 == 0:
            print("run {}:".format(run[0]))
            print('Style Loss : {} Content Loss: {}'.format(
                style_score.item(), content_score.item()))
            # print()
            # plt.figure(figsize=(8, 8))
            # imshow(input_img.clone())
        run[0] += 1

        return style_score + content_score

    optimizer.step(closure)

    # a last correction...
    input_img.data.clamp_(0, 1)

    return input_img
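# A minimal sketch of the closure contract that torch.optim.LBFGS expects, shown on a toy
# least-squares problem rather than the style-transfer model above. LBFGS may evaluate the
# closure several times per .step() call (e.g. during the line search), so the closure must
# zero the gradients, recompute the loss, and call backward() every time it runs.
import torch
from torch.optim import LBFGS

x = torch.randn(8, 3)
y = torch.randn(8, 1)
w = torch.zeros(3, 1, requires_grad=True)

optimizer = LBFGS([w], lr=1.0, max_iter=100, line_search_fn="strong_wolfe")

def closure():
    optimizer.zero_grad()
    loss = ((x @ w - y) ** 2).mean()
    loss.backward()
    return loss

optimizer.step(closure)  # a single call runs up to max_iter inner iterations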
def train(model, X_u, u, X_f, nu=1.0, num_epoch=100, device=torch.device('cpu'), optim='LBFGS'):
    model.to(device)
    model.train()

    optimizer = LBFGS(model.parameters(),
                      lr=1.0,
                      max_iter=50000,
                      max_eval=50000,
                      history_size=50,
                      tolerance_grad=1e-5,
                      tolerance_change=1.0 * np.finfo(float).eps,
                      line_search_fn="strong_wolfe")
    mse = nn.MSELoss()

    # training stage
    xts = torch.from_numpy(X_u).float().to(device)
    us = torch.from_numpy(u).float().to(device)
    xs = torch.from_numpy(X_f[:, 0:1]).float().to(device)
    ts = torch.from_numpy(X_f[:, 1:2]).float().to(device)
    xs.requires_grad = True
    ts.requires_grad = True

    iter = 0

    def loss_closure():
        nonlocal iter
        iter = iter + 1
        optimizer.zero_grad()
        zero_grad(xs)
        zero_grad(ts)
        # print(xs.grad)

        # MSE loss of prediction error
        pred_u = model(xts)
        mse_u = mse(pred_u, us)

        # MSE loss of PDE constraint
        f = PDELoss(model, xs, ts, nu)
        mse_f = torch.mean(f ** 2)

        loss = mse_u + mse_f
        loss.backward()

        if iter % 200 == 0:
            print('Iter: {}, total loss: {}, mse_u: {}, mse_f: {}'.format(
                iter, loss.item(), mse_u.item(), mse_f.item()))
        return loss

    optimizer.step(loss_closure)
    return model
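# The training loop above relies on a PDELoss(model, xs, ts, nu) helper whose definition is
# not shown in this excerpt. As a hedged illustration only (not necessarily the repository's
# implementation): for a physics-informed network with viscosity parameter nu, a common
# choice is the residual of the viscous Burgers' equation u_t + u*u_x - nu*u_xx = 0,
# assuming the model takes the concatenated (x, t) collocation points as input.
import torch

def pde_residual(model, xs, ts, nu):
    """Hypothetical residual of u_t + u*u_x - nu*u_xx at collocation points (xs, ts)."""
    u = model(torch.cat([xs, ts], dim=1))
    ones = torch.ones_like(u)
    u_x = torch.autograd.grad(u, xs, grad_outputs=ones, create_graph=True)[0]
    u_t = torch.autograd.grad(u, ts, grad_outputs=ones, create_graph=True)[0]
    u_xx = torch.autograd.grad(u_x, xs, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]
    return u_t + u * u_x - nu * u_xx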
def set_temperature(self,
                    logits: torch.Tensor,
                    labels: torch.Tensor,
                    criterion_fn: Callable[[torch.Tensor, torch.Tensor],
                                           Tuple[torch.Tensor, torch.Tensor]],
                    use_gpu: bool,
                    logger: Optional[AzureAndTensorboardLogger] = None) -> float:
    """
    Tune the temperature of the model using the provided logits and labels.
    :param logits: Logits to use to learn the temperature parameter
    :param labels: Labels to use to learn the temperature parameter
    :param criterion_fn: A criterion function s.t: (logits, labels) => (loss, ECE)
    :param use_gpu: If True then GPU will be used otherwise CPU will be used.
    :param logger: If provided, the intermediate loss and ECE values in the optimization will be reported
    :return: Optimal temperature value
    """
    if use_gpu:
        logits = logits.cuda()
        labels = labels.cuda()

    # Calculate loss values before scaling
    before_temperature_loss, before_temperature_ece = criterion_fn(logits, labels)
    print('Before temperature scaling - LOSS: {:.3f} ECE: {:.3f}'
          .format(before_temperature_loss.item(), before_temperature_ece.item()))

    # Next: optimize the temperature w.r.t. the provided criterion function
    optimizer = LBFGS([self.temperature],
                      lr=self.temperature_scaling_config.lr,
                      max_iter=self.temperature_scaling_config.max_iter)

    def eval_criterion() -> torch.Tensor:
        # zero the gradients for the next optimization step
        optimizer.zero_grad()
        loss, ece = criterion_fn(self.temperature_scale(logits), labels)
        if logger:
            logger.log_to_azure_and_tensorboard("Temp_Scale_LOSS", loss.item())
            logger.log_to_azure_and_tensorboard("Temp_Scale_ECE", ece.item())
        loss.backward()
        return loss

    optimizer.step(eval_criterion)  # type: ignore

    after_temperature_loss, after_temperature_ece = criterion_fn(self.temperature_scale(logits), labels)
    print('Optimal temperature: {:.3f}'.format(self.temperature.item()))
    print('After temperature scaling - LOSS: {:.3f} ECE: {:.3f}'
          .format(after_temperature_loss.item(), after_temperature_ece.item()))
    return self.temperature.item()
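# set_temperature above calls a self.temperature_scale helper that is not included in this
# excerpt. In standard temperature scaling the logits are simply divided by a single learned
# scalar, so a plausible sketch (assuming self.temperature is a one-element nn.Parameter)
# would be the following; this is an illustration, not the library's actual method.
def temperature_scale(self, logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: divide the logits by the learned temperature parameter."""
    return logits / self.temperature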
def optimize(self, content_tensor, style_desc_dict, steps):
    optimizer = LBFGS([content_tensor], lr=0.8, max_iter=steps)
    self.n_iter = 0

    def closure():
        self.n_iter += 1
        optimizer.zero_grad()
        loss = self.infer_loss(content_tensor, style_desc_dict)
        LOGGER.info("Step %d: loss %.2f", self.n_iter, loss)
        loss.backward(retain_graph=True)
        if self.log_interval > 0 and self.n_iter % self.log_interval == 0:
            self.save_images(
                content_tensor.unsqueeze(0).to("cpu").detach().numpy(),
                LOG_DIR / f"{self.n_iter:03d}.jpg")
        return loss

    optimizer.step(closure)
    return content_tensor.unsqueeze(0)
def lr_many_pytorch_lbfgs(
    x,
    y,
    history_size=10,
    max_iter=100,
    max_ls=25,
    tol=1e-4,
    C=1,
):
    from torch.optim import LBFGS

    model = StackedRegLogitModel(x.shape[0], x.shape[-1], 1, C=C).to(x.device)
    optimizer = LBFGS(
        model.parameters(),
        lr=1,
        history_size=history_size,
        max_iter=max_iter,
        # XXX: Cannot pass max_ls to strong_wolfe
        line_search_fn="strong_wolfe",
        tolerance_change=0,
        tolerance_grad=tol,
    )

    x_var = x.detach()
    x_var.requires_grad_(True)
    y_var = y.detach().float()

    def closure():
        if torch.is_grad_enabled():
            optimizer.zero_grad()
        loss = model.forward_loss(x_var, y_var)
        if torch.is_grad_enabled():
            loss.backward()
        return loss

    optimizer.step(closure)
    state = optimizer.state[next(iter(optimizer.state))]

    weights = []
    biases = []
    for linear in model.linears:
        weights.append(linear.weight.detach())
        biases.append(linear.bias.detach())

    return (torch.stack(weights, axis=0),
            torch.stack(biases, axis=0),
            state["n_iter"])
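# StackedRegLogitModel is defined elsewhere in the repository. Its usage above (one linear
# layer per leading slice of x, an L2 penalty controlled by C, and a .linears attribute that
# is iterated at the end) suggests a batch of independent regularized logistic regressions.
# The class below is only a hypothetical stand-in written under those assumptions, to make
# the forward_loss call concrete; it is not the repository's implementation.
import torch
import torch.nn as nn

class StackedRegLogitModelSketch(nn.Module):
    """Hypothetical: n_stacks independent logistic regressions with L2 penalty scaled by 1/C."""

    def __init__(self, n_stacks, n_features, n_outputs, C=1.0):
        super().__init__()
        self.C = C
        self.linears = nn.ModuleList(
            [nn.Linear(n_features, n_outputs) for _ in range(n_stacks)])

    def forward_loss(self, x, y):
        # x: (n_stacks, n_samples, n_features), y: (n_stacks, n_samples, n_outputs)
        losses = []
        for i, linear in enumerate(self.linears):
            logits = linear(x[i])
            bce = nn.functional.binary_cross_entropy_with_logits(logits, y[i])
            l2 = (linear.weight ** 2).sum() / (2.0 * self.C * x[i].shape[0])
            losses.append(bce + l2)
        return torch.stack(losses).sum()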
def transfer(self, content_img_raw, style_img_raw, n_iter, alpha, beta, size, print_every=50):
    content_img = self.transform_from_pil(content_img_raw, size)
    style_img = self.transform_from_pil(style_img_raw, size)

    random_img = Variable(content_img.clone(), requires_grad=True)
    random_img.data.clamp_(0., 1.)

    self.extract_content(content_img)
    self.extract_style(style_img)

    optimizer = LBFGS([random_img])
    itr = [0]

    while itr[0] <= n_iter:
        def closure():
            optimizer.zero_grad()
            Lc, Ls = self(random_img)
            Lc, Ls = Lc * alpha, Ls * beta
            loss = Lc + Ls
            if not itr[0] % print_every:
                print("i: %d, loss: %5.3f, content_loss: %5.3f, style_loss: %5.3f"
                      % (itr[0], loss.item(), Lc.item(), Ls.item()))
            loss.backward()
            itr[0] += 1
            return loss

        optimizer.step(closure)

    random_img.data.clamp_(0., 1.)
    return self.transform_to_pil(random_img)
class TRPO:
    '''
    Optimizes the given policy using Trust Region Policy Optimization (Schulman 2015)
    with Generalized Advantage Estimation (Schulman 2016).

    Attributes
    ----------
    policy : torch.nn.Sequential
        the policy to be optimized

    value_fun : torch.nn.Sequential
        the value function to be optimized and used when calculating the advantages

    simulator : Simulator
        the simulator to be used when generating training experiences

    max_kl_div : float
        the maximum kl divergence of the policy before and after each step

    max_value_step : float
        the learning rate for the value function

    vf_iters : int
        the number of times to optimize the value function over each set of training experiences

    vf_l2_reg_coef : float
        the regularization term when calculating the L2 loss of the value function

    discount : float
        the coefficient to use when discounting the rewards

    lam : float
        the bias reduction parameter to use when calculating advantages using GAE

    cg_damping : float
        the multiple of the identity matrix to add to the Hessian when calculating
        Hessian-vector products

    cg_max_iters : int
        the maximum number of iterations to use when solving for the optimal search
        direction using the conjugate gradient method

    line_search_coef : float
        the proportion by which to reduce the step length on each iteration of the line search

    line_search_max_iter : int
        the maximum number of line search iterations before returning 0.0 as the step length

    line_search_accept_ratio : float
        the minimum proportion of error to accept from linear extrapolation when doing
        the line search

    mse_loss : torch.nn.MSELoss
        an MSELoss object used to calculate the value function loss

    value_optimizer : torch.optim.LBFGS
        an LBFGS object used to optimize the value function

    model_name : str
        an identifier for the model to be used when generating filepath names

    continue_from_file : bool
        whether to continue training from a previously saved session

    save_every : int
        the number of training iterations to go between saving the training session

    episode_num : int
        the number of episodes already completed

    elapsed_time : datetime.timedelta
        the elapsed training time so far

    device : torch.device
        device to be used for pytorch tensor operations

    mean_rewards : list
        a list of the mean rewards obtained by the agent for each episode so far

    Methods
    -------
    train(n_episodes)
        train the policy and value function for n_episodes episodes

    unroll_samples(samples)
        unroll the samples generated by the simulator and return a flattened version of
        all states, actions, rewards, and estimated Q-values

    get_advantages(samples)
        return the GAE advantages and a version of the unrolled states with a time
        variable concatenated to each state

    update_value_fun(states, q_vals)
        calculate one update step and apply it to the value function

    update_policy(states, actions, advantages)
        calculate one update step using TRPO and apply it to the policy

    surrogate_loss(log_action_probs, imp_sample_probs, advantages)
        calculate the loss for the policy on a batch of experiences

    get_max_step_len(search_dir, Hvp_fun, max_step, retain_graph=False)
        calculate the coefficient for search_dir s.t. the change in the function
        approximator of interest will be equal to max_step

    save_session()
        save the current training session

    load_session()
        load a previously saved training session

    print_update()
        print an update message that displays statistics about the most recent training iteration
    '''

    def __init__(self, policy, value_fun, simulator, max_kl_div=0.01, max_value_step=0.01,
                 vf_iters=1, vf_l2_reg_coef=1e-3, discount=0.995, lam=0.98, cg_damping=1e-3,
                 cg_max_iters=10, line_search_coef=0.9, line_search_max_iter=10,
                 line_search_accept_ratio=0.1, model_name=None, continue_from_file=False,
                 save_every=1):
        '''
        Parameters
        ----------
        policy : torch.nn.Sequential
            the policy to be optimized

        value_fun : torch.nn.Sequential
            the value function to be optimized and used when calculating the advantages

        simulator : Simulator
            the simulator to be used when generating training experiences

        max_kl_div : float
            the maximum kl divergence of the policy before and after each step
            (default is 0.01)

        max_value_step : float
            the learning rate for the value function (default is 0.01)

        vf_iters : int
            the number of times to optimize the value function over each set of training
            experiences (default is 1)

        vf_l2_reg_coef : float
            the regularization term when calculating the L2 loss of the value function
            (default is 0.001)

        discount : float
            the coefficient to use when discounting the rewards (default is 0.995)

        lam : float
            the bias reduction parameter to use when calculating advantages using GAE
            (default is 0.98)

        cg_damping : float
            the multiple of the identity matrix to add to the Hessian when calculating
            Hessian-vector products (default is 0.001)

        cg_max_iters : int
            the maximum number of iterations to use when solving for the optimal search
            direction using the conjugate gradient method (default is 10)

        line_search_coef : float
            the proportion by which to reduce the step length on each iteration of the
            line search (default is 0.9)

        line_search_max_iter : int
            the maximum number of line search iterations before returning 0.0 as the
            step length (default is 10)

        line_search_accept_ratio : float
            the minimum proportion of error to accept from linear extrapolation when
            doing the line search (default is 0.1)

        model_name : str
            an identifier for the model to be used when generating filepath names
            (default is None)

        continue_from_file : bool
            whether to continue training from a previously saved session (default is False)

        save_every : int
            the number of training iterations to go between saving the training session
            (default is 1)
        '''

        self.policy = policy
        self.value_fun = value_fun
        self.simulator = simulator
        self.max_kl_div = max_kl_div
        self.max_value_step = max_value_step
        self.vf_iters = vf_iters
        self.vf_l2_reg_coef = vf_l2_reg_coef
        self.discount = discount
        self.lam = lam
        self.cg_damping = cg_damping
        self.cg_max_iters = cg_max_iters
        self.line_search_coef = line_search_coef
        self.line_search_max_iter = line_search_max_iter
        self.line_search_accept_ratio = line_search_accept_ratio
        self.mse_loss = MSELoss(reduction='mean')
        self.value_optimizer = LBFGS(self.value_fun.parameters(), lr=max_value_step, max_iter=25)
        self.model_name = model_name
        self.continue_from_file = continue_from_file
        self.save_every = save_every
        self.episode_num = 0
        self.elapsed_time = timedelta(0)
        self.device = get_device()
        self.mean_rewards = []

        if not model_name and continue_from_file:
            raise Exception('Argument continue_from_file to __init__ method of '
                            'TRPO class was set to True but model_name was not '
                            'specified.')

        if not model_name and save_every:
            raise Exception('Argument save_every to __init__ method of TRPO '
                            'was set to a value greater than 0 but model_name '
                            'was not specified.')

        if continue_from_file:
            self.load_session()

    def train(self, n_episodes):
        last_q = None
        last_states = None

        while self.episode_num < n_episodes:
            start_time = dt.now()
            self.episode_num += 1

            # Sample n_trajectories trajectories under the currently parameterized policy
            samples = self.simulator.sample_trajectories()

            states, actions, rewards, q_vals = self.unroll_samples(samples)
            advantages, states_with_time = self.get_advantages(samples)
            advantages -= torch.mean(advantages)
            advantages /= torch.std(advantages)

            # Use the sampled states, actions, and advantages to update the policy parameters
            self.update_policy(states, actions, advantages)

            if last_q is not None:
                self.update_value_fun(torch.cat([states_with_time, last_states]),
                                      torch.cat([q_vals, last_q]))
            else:
                self.update_value_fun(states_with_time, q_vals)

            last_q = q_vals
            last_states = states_with_time

            mean_reward = np.mean([np.sum(trajectory['rewards']) for trajectory in samples])
            mean_reward_np = mean_reward
            self.mean_rewards.append(mean_reward_np)
            self.elapsed_time += dt.now() - start_time
            self.print_update()

            if self.save_every and not self.episode_num % self.save_every:
                self.save_session()

    def unroll_samples(self, samples):
        q_vals = []

        for trajectory in samples:
            rewards = torch.tensor(trajectory['rewards'])
            reverse = torch.arange(rewards.size(0) - 1, -1, -1)
            discount_pows = torch.pow(self.discount, torch.arange(0, rewards.size(0)).float())
            discounted_rewards = rewards * discount_pows
            disc_reward_sums = torch.cumsum(discounted_rewards[reverse], dim=-1)[reverse]
            trajectory_q_vals = disc_reward_sums / discount_pows
            q_vals.append(trajectory_q_vals)

        states = torch.cat([torch.stack(trajectory['states']) for trajectory in samples])
        actions = torch.cat([torch.stack(trajectory['actions']) for trajectory in samples])
        rewards = torch.cat([torch.stack(trajectory['rewards']) for trajectory in samples])
        q_vals = torch.cat(q_vals)

        return states, actions, rewards, q_vals

    def get_advantages(self, samples):
        advantages = []
        states_with_time = []
        T = self.simulator.trajectory_len

        for trajectory in samples:
            time = torch.arange(0, len(trajectory['rewards'])).unsqueeze(1).float() / T
            states = torch.stack(trajectory['states'])
            states = torch.cat([states, time], dim=-1)
            states = states.to(self.device)
            states_with_time.append(states.cpu())

            rewards = torch.tensor(trajectory['rewards'])

            state_values = self.value_fun(states)
            state_values = state_values.view(-1)
            state_values = state_values.cpu()
            state_values_next = torch.cat([state_values[1:], torch.tensor([0.0])])

            td_residuals = rewards + self.discount * state_values_next - state_values
            reverse = torch.arange(rewards.size(0) - 1, -1, -1)
            discount_pows = torch.pow(self.discount * self.lam,
                                      torch.arange(0, rewards.size(0)).float())
            discounted_residuals = td_residuals * discount_pows
            disc_res_sums = torch.cumsum(discounted_residuals[reverse], dim=-1)[reverse]
            trajectory_advs = disc_res_sums / discount_pows
            advantages.append(trajectory_advs)

        advantages = torch.cat(advantages)
        states_with_time = torch.cat(states_with_time)

        return advantages, states_with_time

    def update_value_fun(self, states, q_vals):
        self.value_fun.train()

        states = states.to(self.device)
        q_vals = q_vals.to(self.device)

        for i in range(self.vf_iters):
            def mse():
                self.value_optimizer.zero_grad()
                state_values = self.value_fun(states).view(-1)
                loss = self.mse_loss(state_values, q_vals)

                flat_params = get_flat_params(self.value_fun)
                l2_loss = self.vf_l2_reg_coef * torch.sum(torch.pow(flat_params, 2))
                loss += l2_loss

                loss.backward()

                return loss

            self.value_optimizer.step(mse)

    def update_policy(self, states, actions, advantages):
        self.policy.train()

        states = states.to(self.device)
        actions = actions.to(self.device)
        advantages = advantages.to(self.device)

        action_dists = self.policy(states)
        log_action_probs = action_dists.log_prob(actions)

        loss = self.surrogate_loss(log_action_probs, log_action_probs.detach(), advantages)
        loss_grad = flat_grad(loss, self.policy.parameters(), retain_graph=True)

        mean_kl = mean_kl_first_fixed(action_dists, action_dists)
        Fvp_fun = get_Hvp_fun(mean_kl, self.policy.parameters())

        search_dir = cg_solver(Fvp_fun, loss_grad, self.cg_max_iters)
        expected_improvement = torch.matmul(loss_grad, search_dir)

        def constraints_satisfied(step, beta):
            apply_update(self.policy, step)

            with torch.no_grad():
                new_action_dists = self.policy(states)
                new_log_action_probs = new_action_dists.log_prob(actions)

                new_loss = self.surrogate_loss(new_log_action_probs, log_action_probs, advantages)

                mean_kl = mean_kl_first_fixed(action_dists, new_action_dists)

            actual_improvement = new_loss - loss
            improvement_ratio = actual_improvement / (expected_improvement * beta)

            apply_update(self.policy, -step)

            surrogate_cond = improvement_ratio >= self.line_search_accept_ratio and actual_improvement > 0.0
            kl_cond = mean_kl <= self.max_kl_div

            return surrogate_cond and kl_cond

        max_step_len = self.get_max_step_len(search_dir, Fvp_fun, self.max_kl_div, retain_graph=True)
        step_len = line_search(search_dir, max_step_len, constraints_satisfied)

        opt_step = step_len * search_dir
        apply_update(self.policy, opt_step)

    def surrogate_loss(self, log_action_probs, imp_sample_probs, advantages):
        return torch.mean(torch.exp(log_action_probs - imp_sample_probs) * advantages)

    def get_max_step_len(self, search_dir, Hvp_fun, max_step, retain_graph=False):
        num = 2 * max_step
        denom = torch.matmul(search_dir, Hvp_fun(search_dir, retain_graph))
        max_step_len = torch.sqrt(num / denom)

        return max_step_len

    def save_session(self):
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        save_path = os.path.join(save_dir, self.model_name + '.pt')

        ckpt = {'policy_state_dict': self.policy.state_dict(),
                'value_state_dict': self.value_fun.state_dict(),
                'mean_rewards': self.mean_rewards,
                'episode_num': self.episode_num,
                'elapsed_time': self.elapsed_time}

        if self.simulator.state_filter:
            ckpt['state_filter'] = self.simulator.state_filter

        torch.save(ckpt, save_path)

    def load_session(self):
        load_path = os.path.join(save_dir, self.model_name + '.pt')
        ckpt = torch.load(load_path)

        self.policy.load_state_dict(ckpt['policy_state_dict'])
        self.value_fun.load_state_dict(ckpt['value_state_dict'])
        self.mean_rewards = ckpt['mean_rewards']
        self.episode_num = ckpt['episode_num']
        self.elapsed_time = ckpt['elapsed_time']

        try:
            self.simulator.state_filter = ckpt['state_filter']
        except KeyError:
            pass

    def print_update(self):
        update_message = '[EPISODE]: {0}\t[AVG. REWARD]: {1:.4f}\t [ELAPSED TIME]: {2}'
        elapsed_time_str = ''.join(str(self.elapsed_time).split('.')[0])
        format_args = (self.episode_num, self.mean_rewards[-1], elapsed_time_str)
        print(update_message.format(*format_args))
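# The TRPO class above depends on several helpers defined elsewhere in the repository
# (cg_solver, line_search, flat_grad, get_Hvp_fun, apply_update, ...). To make the core of
# update_policy concrete, here is a minimal conjugate-gradient solver for F x = b when F is
# only available through a matrix-vector product such as Fvp_fun. This is a sketch written
# from the standard algorithm, not necessarily the repository's cg_solver implementation.
import torch

def cg_solver_sketch(mvp_fun, b, max_iters=10, tol=1e-10):
    """Solve F x = b given only the matrix-vector product mvp_fun(v) = F v."""
    x = torch.zeros_like(b)
    r = b.clone()          # residual b - F x (x starts at zero)
    p = b.clone()          # current search direction
    rs_old = torch.dot(r, r)

    for _ in range(max_iters):
        Fp = mvp_fun(p)
        alpha = rs_old / torch.dot(p, Fp)
        x += alpha * p
        r -= alpha * Fp
        rs_new = torch.dot(r, r)
        if rs_new < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new

    return x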
        print(torch.sigmoid(logits))
        # bce = loss_fun(logits, y)
        bce = loss_fun(mission_logits, y)

        flat_params = torch.cat([get_flat_params(mission_model),
                                 get_flat_params(maint_model)])
        l2_loss = l2_reg_coef * torch.sum(torch.pow(flat_params, 2))
        reg_loss = bce + l2_loss
        reg_loss.backward()

        return reg_loss

    optimizer.step(bce_loss)

    with torch.no_grad():
        mission_logits_daily = mission_model(x_mission)
        maint_logits_daily = maint_model(x_maint)
        mission_logits = torch.stack([torch.sum(mission_logits_daily[slice])
                                      for slice in mission_hist_slices])
        maint_logits = torch.stack([torch.sum(maint_logits_daily[slice])
                                    for slice in maint_hist_slices])
        logits = mission_logits + maint_logits
        bce = loss_fun(logits, y)
        bce = loss_fun(mission_logits, y)
        train_losses.append(bce.cpu().detach().numpy())
def l1_loss(x, y):
    return torch.abs(x - y).mean()


while n_iter[0] <= max_iter:

    def closure():
        optimizer.zero_grad()
        style, content = vgg(opt_img, model)
        style_loss = sum(alpha * l1_loss(u, v)
                         for alpha, u, v in zip(style_weights, style, style_targets))
        content_loss = sum(beta * l1_loss(u, v)
                           for beta, u, v in zip(content_weights, content, content_targets))
        loss = style_loss + content_loss
        loss.backward()
        n_iter[0] += 1
        if n_iter[0] % show_iter == (show_iter - 1):
            # .item() replaces the long-deprecated .data[0] indexing on 0-dim tensors
            print('Iteration: %d, style loss: %f, content loss: %f'
                  % (n_iter[0] + 1, style_loss.item(), content_loss.item()))
            out_img = postp(opt_img.data[0].cpu().squeeze())
            torchvision.utils.save_image(out_img, 'out_%d.png' % (n_iter[0] + 1))
        return loss

    optimizer.step(closure)
class LSTMRegressor(nn.Module):
    def __init__(self, input_size, target_size, hidden_size, nb_layers, device='cpu'):
        super(LSTMRegressor, self).__init__()

        if device == 'gpu' and torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

        self.input_size = input_size
        self.target_size = target_size
        self.hidden_size = hidden_size
        self.nb_layers = nb_layers

        self.lstm = nn.LSTM(input_size, hidden_size, nb_layers, batch_first=True).to(self.device)
        self.linear = nn.Linear(hidden_size, target_size).to(self.device)

        self.criterion = nn.MSELoss().to(self.device)
        self.optim = None

        self.input_trans = None
        self.target_trans = None

    @property
    def model(self):
        return self

    def init_hidden(self, batch_size):
        return torch.zeros(self.nb_layers, batch_size, self.hidden_size,
                           dtype=torch.double).to(self.device)

    def forward(self, inputs, hidden=None):
        output, hidden = self.lstm(inputs, hidden)
        output = self.linear(output)
        return output, hidden

    def init_preprocess(self, target, input):
        self.target_trans = StandardScaler()
        self.input_trans = StandardScaler()

        self.target_trans.fit(target.reshape(-1, self.target_size))
        self.input_trans.fit(input.reshape(-1, self.input_size))

    @ensure_args_torch_doubles
    @ensure_args_atleast_3d
    def fit(self, target, input, nb_epochs, lr=0.5, l2=1e-32, verbose=True, preprocess=True):
        if preprocess:
            self.init_preprocess(target, input)
            target = transform(target, self.target_trans)
            input = transform(input, self.input_trans)

        target = target.to(self.device)
        input = input.to(self.device)

        self.model.double()

        self.optim = LBFGS(self.parameters(), lr=lr)
        # self.optim = Adam(self.parameters(), lr=lr, weight_decay=l2)

        for n in range(nb_epochs):
            def closure():
                self.optim.zero_grad()
                _output, hidden = self.model(input)
                loss = self.criterion(_output, target)
                loss.backward()
                return loss

            self.optim.step(closure)

            if verbose:
                if n % 10 == 0:
                    output, _ = self.forward(input)
                    print('Epoch: {}/{}.............'.format(n, nb_epochs), end=' ')
                    print("Loss: {:.6f}".format(self.criterion(output, target)))

    @ensure_args_torch_doubles
    @ensure_res_numpy_floats
    def predict(self, input, hidden):
        input = transform(input.reshape(-1, 1, self.input_size), self.input_trans)

        with torch.no_grad():
            output, hidden = self.forward(input, hidden)
            output = inverse_transform(output, self.target_trans)

        return output, list(hidden)

    def forcast(self, state, exogenous=None, horizon=1):
        self.device = torch.device('cpu')
        self.model.to(self.device)

        assert exogenous is None

        _hidden = None

        if state.ndim < 3:
            state = atleast_3d(state, self.input_size)

        buffer_size = state.shape[1] - 1
        if buffer_size == 0:
            _state = state
        else:
            for t in range(buffer_size):
                _state, _hidden = self.predict(state[:, t, :], _hidden)

        forcast = [_state]
        for _ in range(horizon):
            _state, _hidden = self.predict(_state[:, -1, :], _hidden)
            forcast.append(_state)

        forcast = np.hstack(forcast)
        return forcast
def _find_rotation_lbfgs(
    X, Y, tol=1e-6, max_iter=100, verbose=True, center_columns=True,
):
    """
    Finds orthogonal matrix Q, scaling s, and translation b, to
    minimize sum(norm(X - s * Y @ Q - b)).

    Note that the solution is not in closed form because we are
    minimizing the sum of norms, which is non-trivial given the
    orthogonality constraint on Q. Without the orthogonality
    constraint, the problem can be formulated as a cone program:

        Guoliang Xue & Yinyu Ye (2000). "An Efficient Algorithm for
        Minimizing a Sum of p-Norms." SIAM J. Optim., 10(2), 551-579.

    However, the orthogonality constraint complicates things, so
    we just minimize by gradient methods used in manifold optimization.

        Mario Lezcano-Casado (2019). "Trivializations for gradient-based
        optimization on manifolds." NeurIPS.
    """

    # Convert X and Y to pytorch tensors.
    X = torch.tensor(X)
    Y = torch.tensor(Y)

    # Check inputs.
    m, n = X.shape
    assert Y.shape == X.shape

    # Orthogonal linear transformation.
    Q = nn.Linear(n, n, bias=False)
    geotorch.orthogonal(Q, "weight")
    Q = Q.double()

    # Allow a rigid translation.
    bias = nn.Parameter(torch.zeros(n, dtype=torch.float64))

    # Collect trainable parameters
    trainable_params = list(Q.parameters())
    if center_columns:
        trainable_params.append(bias)

    # Define rotational alignment, and optimizer.
    optimizer = LBFGS(
        trainable_params,
        max_iter=100,  # number of inner iterations.
        line_search_fn="strong_wolfe",
    )

    def closure():
        optimizer.zero_grad()
        loss = torch.mean(torch.norm(X - Q(Y) - bias, dim=1))
        loss.backward()
        return loss

    # Fit parameters.
    converged = False
    itercount = 0
    while (not converged) and (itercount < max_iter):

        # Update parameters.
        new_loss = optimizer.step(closure).item()

        # Check convergence.
        if itercount != 0:
            improvement = (last_loss - new_loss) / last_loss
            converged = improvement < tol
        last_loss = new_loss

        # Display progress.
        itercount += 1
        if verbose:
            print(f"Iter {itercount}: {last_loss}")
            if converged:
                print("Converged!")

    # Extract result in numpy.
    Q_ = Q.weight.detach().numpy()
    bias_ = bias.detach().numpy()

    return Q_, bias_
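# A quick, hedged usage sketch for _find_rotation_lbfgs (it assumes numpy and geotorch are
# installed, as the function itself does): build a rotated and shifted copy of a point
# cloud and check that the fitted transform maps it back. Variable names here are
# illustrative only.
import numpy as np

rng = np.random.default_rng(0)
Y_pts = rng.standard_normal((200, 3))
R, _ = np.linalg.qr(rng.standard_normal((3, 3)))
if np.linalg.det(R) < 0:
    R[:, 0] *= -1                                  # make it a proper rotation
X_pts = Y_pts @ R + rng.standard_normal(3)         # rotate and translate

Q_hat, b_hat = _find_rotation_lbfgs(X_pts, Y_pts, verbose=False)
# Residual of the recovered alignment; should be small if the optimization converged.
print(np.abs(Y_pts @ Q_hat.T + b_hat - X_pts).max())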
def train(training_config):
    writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # prepare data loader
    train_loader = utils.get_training_data_loader(training_config)

    # prepare neural networks
    transformer_net = TransformerNet().train().to(device)
    perceptual_loss_net = PerceptualLossNet(requires_grad=False).to(device)

    optimizer = LBFGS(transformer_net.parameters(), line_search_fn='strong_wolfe')

    # Calculate style image's Gram matrices (style representation)
    # Built over feature maps as produced by the perceptual net - VGG16
    style_img_path = os.path.join(training_config['style_images_path'], training_config['style_img_name'])
    style_img = utils.prepare_img(style_img_path, target_shape=None, device=device,
                                  batch_size=training_config['batch_size'])
    style_img_set_of_feature_maps = perceptual_loss_net(style_img)
    target_style_representation = [utils.gram_matrix(x) for x in style_img_set_of_feature_maps]

    utils.print_header(training_config)

    # Tracking loss metrics; NST is ill-posed, so we can only track loss and visual appearance of the stylized images
    acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]
    ts = time.time()

    for epoch in range(training_config['num_of_epochs']):
        for batch_id, (content_batch, _) in enumerate(train_loader):
            # step1: Feed content batch through transformer net
            content_batch = content_batch.to(device)
            stylized_batch = transformer_net(content_batch)

            # step2: Feed content and stylized batch through perceptual net (VGG16)
            content_batch_set_of_feature_maps = perceptual_loss_net(content_batch)
            stylized_batch_set_of_feature_maps = perceptual_loss_net(stylized_batch)

            # step3: Calculate content representations and content loss
            target_content_representation = content_batch_set_of_feature_maps.relu2_2
            current_content_representation = stylized_batch_set_of_feature_maps.relu2_2
            content_loss = training_config['content_weight'] * torch.nn.MSELoss(reduction='mean')(
                target_content_representation, current_content_representation)

            # step4: Calculate style representation and style loss
            style_loss = 0.0
            current_style_representation = [utils.gram_matrix(x) for x in stylized_batch_set_of_feature_maps]
            for gram_gt, gram_hat in zip(target_style_representation, current_style_representation):
                style_loss += torch.nn.MSELoss(reduction='mean')(gram_gt, gram_hat)
            style_loss /= len(target_style_representation)
            style_loss *= training_config['style_weight']

            # step5: Calculate total variation loss - enforces image smoothness
            tv_loss = training_config['tv_weight'] * utils.total_variation(stylized_batch)

            # step6: Combine losses and do a backprop
            total_loss = content_loss + style_loss + tv_loss
            optimizer.zero_grad()
            total_loss.backward()

            def closure():
                # The loss and gradients were already computed above; the closure only
                # reports the precomputed loss back to LBFGS. (Calling zero_grad() inside
                # the closure, as the original code did, would erase the gradients before
                # the optimizer reads them.)
                return total_loss

            optimizer.step(closure)

            #
            # Logging and checkpoint creation
            #
            acc_content_loss += content_loss.item()
            acc_style_loss += style_loss.item()
            acc_tv_loss += tv_loss.item()

            if training_config['enable_tensorboard']:
                # log scalars
                writer.add_scalar('Loss/content-loss', content_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalar('Loss/style-loss', style_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalar('Loss/tv-loss', tv_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalars('Statistics/min-max-mean-median', {
                    'min': torch.min(stylized_batch),
                    'max': torch.max(stylized_batch),
                    'mean': torch.mean(stylized_batch),
                    'median': torch.median(stylized_batch)
                }, len(train_loader) * epoch + batch_id + 1)

                # log stylized image
                if batch_id % training_config['image_log_freq'] == 0:
                    stylized = utils.post_process_image(stylized_batch[0].detach().to('cpu').numpy())
                    stylized = np.moveaxis(stylized, 2, 0)  # writer expects channel first image
                    writer.add_image('stylized_img', stylized, len(train_loader) * epoch + batch_id + 1)

            if training_config['console_log_freq'] is not None and batch_id % training_config['console_log_freq'] == 0:
                print(f'time elapsed={(time.time() - ts) / 60:.2f}[min]|epoch={epoch + 1}|batch=[{batch_id + 1}/{len(train_loader)}]|c-loss={acc_content_loss / training_config["console_log_freq"]}|s-loss={acc_style_loss / training_config["console_log_freq"]}|tv-loss={acc_tv_loss / training_config["console_log_freq"]}|total loss={(acc_content_loss + acc_style_loss + acc_tv_loss) / training_config["console_log_freq"]}')
                acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]

            if training_config['checkpoint_freq'] is not None and (batch_id + 1) % training_config['checkpoint_freq'] == 0:
                training_state = utils.get_training_metadata(training_config)
                training_state["state_dict"] = transformer_net.state_dict()
                training_state["optimizer_state"] = optimizer.state_dict()
                ckpt_model_name = f"ckpt_style_{training_config['style_img_name'].split('.')[0]}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}_epoch_{epoch}_batch_{batch_id}.pth"
                torch.save(training_state, os.path.join(training_config['checkpoints_path'], ckpt_model_name))

    #
    # Save model with additional metadata - like which commit was used to train the model, style/content weights, etc.
    #
    training_state = utils.get_training_metadata(training_config)
    training_state["state_dict"] = transformer_net.state_dict()
    training_state["optimizer_state"] = optimizer.state_dict()
    model_name = f"style_{training_config['style_img_name'].split('.')[0]}_datapoints_{training_state['num_of_datapoints']}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}.pth"
    torch.save(training_state, os.path.join(training_config['model_binaries_path'], model_name))