Example #1
def run_style_transfer(cnn,
                       normalization,
                       content_img,
                       style_img,
                       input_img,
                       mask_img,
                       num_steps=500,
                       style_weight=100,
                       content_weight=5):
    """Run the style transfer."""
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, normalization, style_img, content_img, mask_img)
    optimizer = LBFGS([input_img.requires_grad_()], max_iter=num_steps, lr=1)

    print('Optimizing..')
    run = [0]

    def closure():
        optimizer.zero_grad()
        model(input_img)
        style_score = 0
        content_score = 0

        for sl in style_losses:
            style_score += sl.loss
        for cl in content_losses:
            content_score += cl.loss

        style_score *= style_weight
        content_score *= content_weight

        loss = style_score + content_score
        loss.backward()

        if run[0] % 100 == 0:
            print("run {}:".format(run))
            print('Style Loss : {} Content Loss: {}'.format(
                style_score.item(), content_score.item()))
            # print()
            # plt.figure(figsize = (8, 8))
            #imshow(input_img.clone())
        run[0] += 1

        return style_score + content_score

    optimizer.step(closure)

    # a last correction...
    input_img.data.clamp_(0, 1)

    return input_img
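# The example above shows the standard torch.optim.LBFGS contract: the closure must
# re-evaluate the objective, call backward(), and return the loss, because LBFGS
# re-invokes it internally. A minimal self-contained sketch of the same pattern on a
# toy quadratic (no style-transfer specifics; the objective is an illustrative stand-in):
import torch
from torch.optim import LBFGS

x = torch.nn.Parameter(torch.tensor([4.0, -3.0]))  # stands in for input_img above
optimizer = LBFGS([x], lr=1, max_iter=50)

def closure():
    optimizer.zero_grad()
    loss = ((x - 1.0) ** 2).sum()  # toy objective
    loss.backward()
    return loss

optimizer.step(closure)
print(x.data)  # should be close to [1., 1.]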
Example #2
def train(model, X_u, u, X_f,
          nu=1.0, num_epoch=100,
          device=torch.device('cpu'), optim='LBFGS'):
    model.to(device)
    model.train()
    optimizer = LBFGS(model.parameters(),
                      lr=1.0,
                      max_iter=50000,
                      max_eval=50000,
                      history_size=50,
                      tolerance_grad=1e-5,
                      tolerance_change=1.0 * np.finfo(float).eps,
                      line_search_fn="strong_wolfe")
    mse = nn.MSELoss()
    # training stage
    xts = torch.from_numpy(X_u).float().to(device)
    us = torch.from_numpy(u).float().to(device)

    xs = torch.from_numpy(X_f[:, 0:1]).float().to(device)
    ts = torch.from_numpy(X_f[:, 1:2]).float().to(device)
    xs.requires_grad = True
    ts.requires_grad = True
    iter = 0

    def loss_closure():
        nonlocal iter
        iter = iter + 1

        optimizer.zero_grad()

        zero_grad(xs)
        zero_grad(ts)
        # print(xs.grad)
        # MSE loss of prediction error
        pred_u = model(xts)
        mse_u = mse(pred_u, us)

        # MSE loss of PDE constraint
        f = PDELoss(model, xs, ts, nu)

        mse_f = torch.mean(f ** 2)
        loss = mse_u + mse_f
        loss.backward()

        if iter % 200 == 0:
            print('Iter: {}, total loss: {}, mse_u: {}, mse_f: {}'.
                  format(iter, loss.item(), mse_u.item(), mse_f.item()))
        return loss

    optimizer.step(loss_closure)

    return model
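# The training loop above delegates the PDE residual to an external PDELoss helper
# that is not shown. As a rough sketch only (assuming a Burgers-type residual
# u_t + u * u_x - nu * u_xx, which may differ from the real PDELoss), input gradients
# are typically obtained with torch.autograd.grad on inputs that require grad:
import torch

def pde_residual_sketch(model, xs, ts, nu):
    u = model(torch.cat([xs, ts], dim=1))
    u_x = torch.autograd.grad(u, xs, grad_outputs=torch.ones_like(u),
                              create_graph=True)[0]
    u_t = torch.autograd.grad(u, ts, grad_outputs=torch.ones_like(u),
                              create_graph=True)[0]
    u_xx = torch.autograd.grad(u_x, xs, grad_outputs=torch.ones_like(u_x),
                               create_graph=True)[0]
    return u_t + u * u_x - nu * u_xx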
Example #3
    def set_temperature(self,
                        logits: torch.Tensor,
                        labels: torch.Tensor,
                        criterion_fn: Callable[[torch.Tensor, torch.Tensor],
                                               Tuple[torch.Tensor, torch.Tensor]],
                        use_gpu: bool,
                        logger: Optional[AzureAndTensorboardLogger] = None) -> float:
        """
        Tune the temperature of the model using the provided logits and labels.
        :param logits: Logits to use to learn the temperature parameter
        :param labels: Labels to use to learn the temperature parameter
        :param criterion_fn: A criterion function s.t: (logits, labels) => (loss, ECE)
        :param use_gpu: If True then GPU will be used otherwise CPU will be used.
        :param logger: If provided, the intermediate loss and ECE values in the optimization will be reported
        :return: Optimal temperature value
        """
        if use_gpu:
            logits = logits.cuda()
            labels = labels.cuda()

        # Calculate loss values before scaling
        before_temperature_loss, before_temperature_ece = criterion_fn(logits, labels)
        print('Before temperature scaling - LOSS: {:.3f} ECE: {:.3f}'
              .format(before_temperature_loss.item(), before_temperature_ece.item()))

        # Next: optimize the temperature w.r.t. the provided criterion function
        optimizer = LBFGS([self.temperature], lr=self.temperature_scaling_config.lr,
                          max_iter=self.temperature_scaling_config.max_iter)

        def eval_criterion() -> torch.Tensor:
            # zero the gradients for the next optimization step
            optimizer.zero_grad()
            loss, ece = criterion_fn(self.temperature_scale(logits), labels)
            if logger:
                logger.log_to_azure_and_tensorboard("Temp_Scale_LOSS", loss.item())
                logger.log_to_azure_and_tensorboard("Temp_Scale_ECE", ece.item())
            loss.backward()
            return loss

        optimizer.step(eval_criterion)  # type: ignore

        after_temperature_loss, after_temperature_ece = criterion_fn(self.temperature_scale(logits), labels)
        print('Optimal temperature: {:.3f}'.format(self.temperature.item()))
        print('After temperature scaling - LOSS: {:.3f} ECE: {:.3f}'
              .format(after_temperature_loss.item(), after_temperature_ece.item()))
        return self.temperature.item()
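# temperature_scale and self.temperature are not shown above; in the usual temperature
# scaling setup (an assumption here) the temperature is a single learnable scalar that
# divides the logits. A minimal self-contained sketch of the same LBFGS fit:
import torch
from torch import nn
from torch.optim import LBFGS

logits = torch.randn(128, 10)           # toy logits
labels = torch.randint(0, 10, (128,))   # toy labels
temperature = nn.Parameter(torch.ones(1) * 1.5)
nll = nn.CrossEntropyLoss()
optimizer = LBFGS([temperature], lr=0.01, max_iter=50)

def eval_criterion():
    optimizer.zero_grad()
    loss = nll(logits / temperature, labels)
    loss.backward()
    return loss

optimizer.step(eval_criterion)
print(temperature.item())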
Example #4
    def optimize(self, content_tensor, style_desc_dict, steps):
        optimizer = LBFGS([content_tensor], lr=0.8, max_iter=steps)
        self.n_iter = 0

        def closure():
            self.n_iter += 1
            optimizer.zero_grad()
            loss = self.infer_loss(
                content_tensor, style_desc_dict)
            LOGGER.info("Step %d: loss %.2f", self.n_iter, loss)
            loss.backward(retain_graph=True)
            if self.log_interval > 0 and self.n_iter % self.log_interval == 0:
                self.save_images(content_tensor.unsqueeze(0).to(
                    "cpu").detach().numpy(), LOG_DIR / f"{self.n_iter:03d}.jpg")
            return loss

        optimizer.step(closure)
        return content_tensor.unsqueeze(0)
Example #5
def lr_many_pytorch_lbfgs(
    x,
    y,
    history_size=10,
    max_iter=100,
    max_ls=25,
    tol=1e-4,
    C=1,
):
    from torch.optim import LBFGS

    model = StackedRegLogitModel(x.shape[0], x.shape[-1], 1, C=C).to(x.device)
    optimizer = LBFGS(
        model.parameters(),
        lr=1,
        history_size=history_size,
        max_iter=max_iter,
        # XXX: Cannot pass max_ls to strong_wolfe
        line_search_fn="strong_wolfe",
        tolerance_change=0,
        tolerance_grad=tol,
    )

    x_var = x.detach()
    x_var.requires_grad_(True)
    y_var = y.detach().float()

    def closure():
        if torch.is_grad_enabled():
            optimizer.zero_grad()
        loss = model.forward_loss(x_var, y_var)
        if torch.is_grad_enabled():
            loss.backward()
        return loss

    optimizer.step(closure)
    state = optimizer.state[next(iter(optimizer.state))]
    weights = []
    biases = []
    for linear in model.linears:
        weights.append(linear.weight.detach())
        biases.append(linear.bias.detach())
    return (torch.stack(weights, dim=0), torch.stack(biases, dim=0),
            state["n_iter"])
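# Two details in the example above are worth noting: the closure only zeroes and
# backpropagates when gradients are enabled, so the same closure can be reused for
# no-grad loss evaluations, and LBFGS keeps its bookkeeping (n_iter, func_evals, ...)
# under a single entry of optimizer.state. A minimal sketch of reading that state:
import torch
from torch.optim import LBFGS

w = torch.nn.Parameter(torch.zeros(3))
optimizer = LBFGS([w], max_iter=10, line_search_fn="strong_wolfe")

def closure():
    if torch.is_grad_enabled():
        optimizer.zero_grad()
    loss = ((w - 2.0) ** 2).sum()
    if torch.is_grad_enabled():
        loss.backward()
    return loss

optimizer.step(closure)
state = optimizer.state[next(iter(optimizer.state))]
print(state["n_iter"], state["func_evals"])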
Example #6
    def transfer(self,
                 content_img_raw,
                 style_img_raw,
                 n_iter,
                 alpha,
                 beta,
                 size,
                 print_every=50):
        content_img = self.transform_from_pil(content_img_raw, size)
        style_img = self.transform_from_pil(style_img_raw, size)
        random_img = content_img.clone().requires_grad_(True)
        random_img.data.clamp_(0., 1.)

        self.extract_content(content_img)
        self.extract_style(style_img)

        optimizer = LBFGS([random_img])
        itr = [0]
        while itr[0] <= n_iter:

            def closure():
                optimizer.zero_grad()
                Lc, Ls = self(random_img)
                Lc, Ls = Lc * alpha, Ls * beta
                loss = Lc + Ls
                if not itr[0] % print_every:
                    print(
                        "i: %d, loss: %5.3f, content_loss: %5.3f, style_loss: %5.3f"
                        % (itr[0], loss.item(), Lc.item(), Ls.item()))
                loss.backward()
                itr[0] += 1
                return loss

            optimizer.step(closure)
            random_img.data.clamp_(0., 1.)

        return self.transform_to_pil(random_img)
Example #7
class TRPO:
    '''
    Optimizes the given policy using Trust Region Policy Optimization (Schulman 2015)
    with Generalized Advantage Estimation (Schulman 2016).

    Attributes
    ----------
    policy : torch.nn.Sequential
        the policy to be optimized

    value_fun : torch.nn.Sequential
        the value function to be optimized and used when calculating the advantages

    simulator : Simulator
        the simulator to be used when generating training experiences

    max_kl_div : float
        the maximum kl divergence of the policy before and after each step

    max_value_step : float
        the learning rate for the value function

    vf_iters : int
        the number of times to optimize the value function over each set of
        training experiences

    vf_l2_reg_coef : float
        the regularization term when calculating the L2 loss of the value function

    discount : float
        the coefficient to use when discounting the rewards

    lam : float
        the bias reduction parameter to use when calculating advantages using GAE

    cg_damping : float
        the multiple of the identity matrix to add to the Hessian when calculating
        Hessian-vector products

    cg_max_iters : int
        the maximum number of iterations to use when solving for the optimal
        search direction using the conjugate gradient method

    line_search_coef : float
        the proportion by which to reduce the step length on each iteration of
        the line search

    line_search_max_iter : int
        the maximum number of line search iterations before returning 0.0 as the
        step length

    line_search_accept_ratio : float
        the minimum proportion of error to accept from linear extrapolation when
        doing the line search

    mse_loss : torch.nn.MSELoss
        an MSELoss object used to calculate the value function loss

    value_optimizer : torch.optim.LBFGS
        an LBFGS object used to optimize the value function

    model_name : str
        an identifier for the model to be used when generating filepath names

    continue_from_file : bool
        whether to continue training from a previous saved session

    save_every : int
        the number of training iterations to go between saving the training session

    episode_num : int
        the number of episodes already completed

    elapsed_time : datetime.timedelta
        the elapsed training time so far

    device : torch.device
        device to be used for pytorch tensor operations

    mean_rewards : list
        a list of the mean rewards obtained by the agent for each episode so far

    Methods
    -------
    train(n_episodes)
        train the policy and value function for the n_episodes episodes

    unroll_samples(samples)
        unroll the samples generated by the simulator and return a flattened
        version of all states, actions, rewards, and estimated Q-values

    get_advantages(samples)
        return the GAE advantages and a version of the unrolled states with
        a time variable concatenated to each state

    update_value_fun(states, q_vals)
        calculate one update step and apply it to the value function

    update_policy(states, actions, advantages)
        calculate one update step using TRPO and apply it to the policy

    surrogate_loss(log_action_probs, imp_sample_probs, advantages)
        calculate the loss for the policy on a batch of experiences

    get_max_step_len(search_dir, Hvp_fun, max_step, retain_graph=False)
        calculate the coefficient for search_dir s.t. the change in the function
        approximator of interest will be equal to max_step

    save_session()
        save the current training session

    load_session()
        load a previously saved training session

    print_update()
        print an update message that displays statistics about the most recent
        training iteration
    '''
    def __init__(self,
                 policy,
                 value_fun,
                 simulator,
                 max_kl_div=0.01,
                 max_value_step=0.01,
                 vf_iters=1,
                 vf_l2_reg_coef=1e-3,
                 discount=0.995,
                 lam=0.98,
                 cg_damping=1e-3,
                 cg_max_iters=10,
                 line_search_coef=0.9,
                 line_search_max_iter=10,
                 line_search_accept_ratio=0.1,
                 model_name=None,
                 continue_from_file=False,
                 save_every=1):
        '''
        Parameters
        ----------

        policy : torch.nn.Sequential
            the policy to be optimized

        value_fun : torch.nn.Sequential
            the value function to be optimized and used when calculating the advantages

        simulator : Simulator
            the simulator to be used when generating training experiences

        max_kl_div : float
            the maximum kl divergence of the policy before and after each step
            (default is 0.01)

        max_value_step : float
            the learning rate for the value function (default is 0.01)

        vf_iters : int
            the number of times to optimize the value function over each set of
            training experiences (default is 1)

        vf_l2_reg_coef : float
            the regularization term when calculating the L2 loss of the value function
            (default is 0.001)

        discount : float
            the coefficient to use when discounting the rewards (default is 0.995)

        lam : float
            the bias reduction parameter to use when calculating advantages using GAE
            (default is 0.98)

        cg_damping : float
            the multiple of the identity matrix to add to the Hessian when calculating
            Hessian-vector products (default is 0.001)

        cg_max_iters : int
            the maximum number of iterations to use when solving for the optimal
            search direction using the conjugate gradient method (default is 10)

        line_search_coef : float
            the proportion by which to reduce the step length on each iteration of
            the line search (default is 0.9)

        line_search_max_iter : int
            the maximum number of line search iterations before returning 0.0 as the
            step length (default is 10)

        line_search_accept_ratio : float
            the minimum proportion of error to accept from linear extrapolation when
            doing the line search (default is 0.1)

        model_name : str
            an identifier for the model to be used when generating filepath names
            (default is None)

        continue_from_file : bool
            whether to continue training from a previous saved session (default is False)

        save_every : int
            the number of training iterations to go between saving the training session
            (default is 1)
        '''

        self.policy = policy
        self.value_fun = value_fun
        self.simulator = simulator
        self.max_kl_div = max_kl_div
        self.max_value_step = max_value_step
        self.vf_iters = vf_iters
        self.vf_l2_reg_coef = vf_l2_reg_coef
        self.discount = discount
        self.lam = lam
        self.cg_damping = cg_damping
        self.cg_max_iters = cg_max_iters
        self.line_search_coef = line_search_coef
        self.line_search_max_iter = line_search_max_iter
        self.line_search_accept_ratio = line_search_accept_ratio
        self.mse_loss = MSELoss(reduction='mean')
        self.value_optimizer = LBFGS(self.value_fun.parameters(),
                                     lr=max_value_step,
                                     max_iter=25)
        self.model_name = model_name
        self.continue_from_file = continue_from_file
        self.save_every = save_every
        self.episode_num = 0
        self.elapsed_time = timedelta(0)
        self.device = get_device()
        self.mean_rewards = []

        if not model_name and continue_from_file:
            raise Exception('Argument continue_from_file to __init__ method of ' \
                            'TRPO class was set to True but model_name was not ' \
                            'specified.')

        if not model_name and save_every:
            raise Exception('Argument save_every to __init__ method of TRPO ' \
                            'was set to a value greater than 0 but model_name ' \
                            'was not specified.')

        if continue_from_file:
            self.load_session()

    def train(self, n_episodes):
        last_q = None
        last_states = None

        while self.episode_num < n_episodes:
            start_time = dt.now()
            self.episode_num += 1

            # Sample n_trajectories trajectories under the currently parameterized policy
            samples = self.simulator.sample_trajectories()
            states, actions, rewards, q_vals = self.unroll_samples(samples)

            advantages, states_with_time = self.get_advantages(samples)
            advantages -= torch.mean(advantages)
            advantages /= torch.std(advantages)

            # Feed all sampled states, actions, and advantages back in to update the policy parameters
            self.update_policy(states, actions, advantages)

            if last_q is not None:
                self.update_value_fun(
                    torch.cat([states_with_time, last_states]),
                    torch.cat([q_vals, last_q]))
            else:
                self.update_value_fun(states_with_time, q_vals)

            last_q = q_vals
            last_states = states_with_time

            mean_reward = np.mean(
                [np.sum(trajectory['rewards']) for trajectory in samples])
            mean_reward_np = mean_reward
            self.mean_rewards.append(mean_reward_np)
            self.elapsed_time += dt.now() - start_time
            self.print_update()

            if self.save_every and not self.episode_num % self.save_every:
                self.save_session()

    def unroll_samples(self, samples):
        q_vals = []

        for trajectory in samples:
            rewards = torch.tensor(trajectory['rewards'])
            reverse = torch.arange(rewards.size(0) - 1, -1, -1)
            discount_pows = torch.pow(self.discount,
                                      torch.arange(0, rewards.size(0)).float())
            discounted_rewards = rewards * discount_pows
            disc_reward_sums = torch.cumsum(discounted_rewards[reverse],
                                            dim=-1)[reverse]
            trajectory_q_vals = disc_reward_sums / discount_pows
            q_vals.append(trajectory_q_vals)

        states = torch.cat(
            [torch.stack(trajectory['states']) for trajectory in samples])
        actions = torch.cat(
            [torch.stack(trajectory['actions']) for trajectory in samples])
        rewards = torch.cat(
            [torch.stack(trajectory['rewards']) for trajectory in samples])
        q_vals = torch.cat(q_vals)

        return states, actions, rewards, q_vals

    def get_advantages(self, samples):
        advantages = []
        states_with_time = []
        T = self.simulator.trajectory_len

        for trajectory in samples:
            time = torch.arange(0, len(
                trajectory['rewards'])).unsqueeze(1).float() / T
            states = torch.stack(trajectory['states'])
            states = torch.cat([states, time], dim=-1)
            states = states.to(self.device)
            states_with_time.append(states.cpu())
            rewards = torch.tensor(trajectory['rewards'])

            state_values = self.value_fun(states)
            state_values = state_values.view(-1)
            state_values = state_values.cpu()
            state_values_next = torch.cat(
                [state_values[1:], torch.tensor([0.0])])

            td_residuals = rewards + self.discount * state_values_next - state_values
            reverse = torch.arange(rewards.size(0) - 1, -1, -1)
            discount_pows = torch.pow(self.discount * self.lam,
                                      torch.arange(0, rewards.size(0)).float())
            discounted_residuals = td_residuals * discount_pows
            disc_res_sums = torch.cumsum(discounted_residuals[reverse],
                                         dim=-1)[reverse]
            trajectory_advs = disc_res_sums / discount_pows
            advantages.append(trajectory_advs)

        advantages = torch.cat(advantages)

        states_with_time = torch.cat(states_with_time)

        return advantages, states_with_time

    def update_value_fun(self, states, q_vals):
        self.value_fun.train()

        states = states.to(self.device)
        q_vals = q_vals.to(self.device)

        for i in range(self.vf_iters):

            def mse():
                self.value_optimizer.zero_grad()
                state_values = self.value_fun(states).view(-1)

                loss = self.mse_loss(state_values, q_vals)

                flat_params = get_flat_params(self.value_fun)
                l2_loss = self.vf_l2_reg_coef * torch.sum(
                    torch.pow(flat_params, 2))
                loss += l2_loss

                loss.backward()

                return loss

            self.value_optimizer.step(mse)

    def update_policy(self, states, actions, advantages):
        self.policy.train()

        states = states.to(self.device)
        actions = actions.to(self.device)
        advantages = advantages.to(self.device)

        action_dists = self.policy(states)
        log_action_probs = action_dists.log_prob(actions)

        loss = self.surrogate_loss(log_action_probs, log_action_probs.detach(),
                                   advantages)
        loss_grad = flat_grad(loss,
                              self.policy.parameters(),
                              retain_graph=True)

        mean_kl = mean_kl_first_fixed(action_dists, action_dists)

        Fvp_fun = get_Hvp_fun(mean_kl, self.policy.parameters())
        search_dir = cg_solver(Fvp_fun, loss_grad, self.cg_max_iters)

        expected_improvement = torch.matmul(loss_grad, search_dir)

        def constraints_satisfied(step, beta):
            apply_update(self.policy, step)

            with torch.no_grad():
                new_action_dists = self.policy(states)
                new_log_action_probs = new_action_dists.log_prob(actions)

                new_loss = self.surrogate_loss(new_log_action_probs,
                                               log_action_probs, advantages)

                mean_kl = mean_kl_first_fixed(action_dists, new_action_dists)

            actual_improvement = new_loss - loss
            improvement_ratio = actual_improvement / (expected_improvement *
                                                      beta)

            apply_update(self.policy, -step)

            surrogate_cond = improvement_ratio >= self.line_search_accept_ratio and actual_improvement > 0.0
            kl_cond = mean_kl <= self.max_kl_div

            return surrogate_cond and kl_cond

        max_step_len = self.get_max_step_len(search_dir,
                                             Fvp_fun,
                                             self.max_kl_div,
                                             retain_graph=True)
        step_len = line_search(search_dir, max_step_len, constraints_satisfied)

        opt_step = step_len * search_dir
        apply_update(self.policy, opt_step)

    def surrogate_loss(self, log_action_probs, imp_sample_probs, advantages):
        return torch.mean(
            torch.exp(log_action_probs - imp_sample_probs) * advantages)

    def get_max_step_len(self,
                         search_dir,
                         Hvp_fun,
                         max_step,
                         retain_graph=False):
        num = 2 * max_step
        denom = torch.matmul(search_dir, Hvp_fun(search_dir, retain_graph))
        max_step_len = torch.sqrt(num / denom)

        return max_step_len

    def save_session(self):
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        save_path = os.path.join(save_dir, self.model_name + '.pt')

        ckpt = {
            'policy_state_dict': self.policy.state_dict(),
            'value_state_dict': self.value_fun.state_dict(),
            'mean_rewards': self.mean_rewards,
            'episode_num': self.episode_num,
            'elapsed_time': self.elapsed_time
        }

        if self.simulator.state_filter:
            ckpt['state_filter'] = self.simulator.state_filter

        torch.save(ckpt, save_path)

    def load_session(self):
        load_path = os.path.join(save_dir, self.model_name + '.pt')
        ckpt = torch.load(load_path)

        self.policy.load_state_dict(ckpt['policy_state_dict'])
        self.value_fun.load_state_dict(ckpt['value_state_dict'])
        self.mean_rewards = ckpt['mean_rewards']
        self.episode_num = ckpt['episode_num']
        self.elapsed_time = ckpt['elapsed_time']

        try:
            self.simulator.state_filter = ckpt['state_filter']
        except KeyError:
            pass

    def print_update(self):
        update_message = '[EPISODE]: {0}\t[AVG. REWARD]: {1:.4f}\t [ELAPSED TIME]: {2}'
        elapsed_time_str = ''.join(str(self.elapsed_time).split('.')[0])
        format_args = (self.episode_num, self.mean_rewards[-1],
                       elapsed_time_str)
        print(update_message.format(*format_args))
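# TRPO fits its value function with LBFGS (see update_value_fun above): the closure
# recomputes the MSE plus an L2 penalty over the flattened parameters on every call.
# A minimal self-contained sketch of that fit, with the external get_flat_params helper
# replaced by an explicit parameter concat (an assumption about what that helper does):
import torch
from torch import nn
from torch.optim import LBFGS

states = torch.randn(256, 4)
q_vals = torch.randn(256)
value_fun = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 1))
value_optimizer = LBFGS(value_fun.parameters(), lr=0.01, max_iter=25)
mse_loss = nn.MSELoss(reduction='mean')
vf_l2_reg_coef = 1e-3

def mse():
    value_optimizer.zero_grad()
    state_values = value_fun(states).view(-1)
    flat_params = torch.cat([p.view(-1) for p in value_fun.parameters()])
    loss = mse_loss(state_values, q_vals) + vf_l2_reg_coef * torch.sum(flat_params ** 2)
    loss.backward()
    return loss

value_optimizer.step(mse)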
Example #8
        print(torch.sigmoid(logits))

        # bce = loss_fun(logits, y)
        bce = loss_fun(mission_logits, y)

        flat_params = torch.cat(
            [get_flat_params(mission_model),
             get_flat_params(maint_model)])
        l2_loss = l2_reg_coef * torch.sum(torch.pow(flat_params, 2))
        reg_loss = bce + l2_loss

        reg_loss.backward()

        return reg_loss

    optimizer.step(bce_loss)

    with torch.no_grad():
        mission_logits_daily = mission_model(x_mission)
        maint_logits_daily = maint_model(x_maint)

        mission_logits = torch.stack([torch.sum(mission_logits_daily[slice]) \
                          for slice in mission_hist_slices])
        maint_logits = torch.stack([torch.sum(maint_logits_daily[slice]) \
                        for slice in maint_hist_slices])
        logits = mission_logits + maint_logits

        bce = loss_fun(logits, y)
        bce = loss_fun(mission_logits, y)
        train_losses.append(bce.cpu().detach().numpy())
Example #9

def l1_loss(x, y):
    return torch.abs(x - y).mean()


while n_iter[0] <= max_iter:

    def closure():
        optimizer.zero_grad()
        style, content = vgg(opt_img, model)
        style_loss = sum(
            alpha * l1_loss(u, v)
            for alpha, u, v in zip(style_weights, style, style_targets))
        content_loss = sum(
            beta * l1_loss(u, v)
            for beta, u, v in zip(content_weights, content, content_targets))
        loss = style_loss + content_loss
        loss.backward()
        n_iter[0] += 1
        if n_iter[0] % show_iter == (show_iter - 1):
            print('Iteration: %d, style loss: %f, content loss: %f' %
                  (n_iter[0] + 1, style_loss.item(), content_loss.item()))
            out_img = postp(opt_img.data[0].cpu().squeeze())
            torchvision.utils.save_image(out_img,
                                         'out_%d.png' % (n_iter[0] + 1))

        return loss

    optimizer.step(closure)
Example #10
class LSTMRegressor(nn.Module):
    def __init__(self,
                 input_size,
                 target_size,
                 hidden_size,
                 nb_layers,
                 device='cpu'):
        super(LSTMRegressor, self).__init__()

        if device == 'gpu' and torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

        self.input_size = input_size
        self.target_size = target_size
        self.hidden_size = hidden_size
        self.nb_layers = nb_layers

        self.lstm = nn.LSTM(input_size,
                            hidden_size,
                            nb_layers,
                            batch_first=True).to(self.device)

        self.linear = nn.Linear(hidden_size, target_size).to(self.device)

        self.criterion = nn.MSELoss().to(self.device)
        self.optim = None

        self.input_trans = None
        self.target_trans = None

    @property
    def model(self):
        return self

    def init_hidden(self, batch_size):
        return torch.zeros(self.nb_layers,
                           batch_size,
                           self.hidden_size,
                           dtype=torch.double).to(self.device)

    def forward(self, inputs, hidden=None):
        output, hidden = self.lstm(inputs, hidden)
        output = self.linear(output)
        return output, hidden

    def init_preprocess(self, target, input):
        self.target_trans = StandardScaler()
        self.input_trans = StandardScaler()

        self.target_trans.fit(target.reshape(-1, self.target_size))
        self.input_trans.fit(input.reshape(-1, self.input_size))

    @ensure_args_torch_doubles
    @ensure_args_atleast_3d
    def fit(self,
            target,
            input,
            nb_epochs,
            lr=0.5,
            l2=1e-32,
            verbose=True,
            preprocess=True):

        if preprocess:
            self.init_preprocess(target, input)
            target = transform(target, self.target_trans)
            input = transform(input, self.input_trans)

        target = target.to(self.device)
        input = input.to(self.device)

        self.model.double()

        self.optim = LBFGS(self.parameters(), lr=lr)
        # self.optim = Adam(self.parameters(), lr=lr, weight_decay=l2)

        for n in range(nb_epochs):

            def closure():
                self.optim.zero_grad()
                _output, hidden = self.model(input)
                loss = self.criterion(_output, target)
                loss.backward()
                return loss

            self.optim.step(closure)

            if verbose:
                if n % 10 == 0:
                    output, _ = self.forward(input)
                    print('Epoch: {}/{}.............'.format(n, nb_epochs),
                          end=' ')
                    print("Loss: {:.6f}".format(self.criterion(output,
                                                               target)))

    @ensure_args_torch_doubles
    @ensure_res_numpy_floats
    def predict(self, input, hidden):
        input = transform(input.reshape(-1, 1, self.input_size),
                          self.input_trans)

        with torch.no_grad():
            output, hidden = self.forward(input, hidden)

        output = inverse_transform(output, self.target_trans)
        return output, list(hidden)

    def forcast(self, state, exogenous=None, horizon=1):
        self.device = torch.device('cpu')
        self.model.to(self.device)

        assert exogenous is None

        _hidden = None

        if state.ndim < 3:
            state = atleast_3d(state, self.input_size)

        buffer_size = state.shape[1] - 1
        if buffer_size == 0:
            _state = state
        else:
            for t in range(buffer_size):
                _state, _hidden = self.predict(state[:, t, :], _hidden)

        forcast = [_state]
        for _ in range(horizon):
            _state, _hidden = self.predict(_state[:, -1, :], _hidden)
            forcast.append(_state)

        forcast = np.hstack(forcast)
        return forcast
Example #11
def _find_rotation_lbfgs(
    X,
    Y,
    tol=1e-6,
    max_iter=100,
    verbose=True,
    center_columns=True,
):
    """
    Finds orthogonal matrix Q, scaling s, and translation b, to

        minimize   sum(norm(X - s * Y @ Q - b)).

    Note that the solution is not in closed form because we are
    minimizing the sum of norms, which is non-trivial given the
    orthogonality constraint on Q. Without the orthogonality
    constraint, the problem can be formulated as a cone program:

        Guoliang Xue & Yinyu Ye (2000). "An Efficient Algorithm for
        Minimizing a Sum of p-Norms." SIAM J. Optim., 10(2), 551–579.

    However, the orthogonality constraint complicates things, so
    we just minimize by gradient methods used in manifold optimization.

        Mario Lezcano-Casado (2019). "Trivializations for gradient-based
        optimization on manifolds." NeurIPS.
    """

    # Convert X and Y to pytorch tensors.
    X = torch.tensor(X)
    Y = torch.tensor(Y)

    # Check inputs.
    m, n = X.shape
    assert Y.shape == X.shape

    # Orthogonal linear transformation.
    Q = nn.Linear(n, n, bias=False)
    geotorch.orthogonal(Q, "weight")
    Q = Q.double()

    # Allow a rigid translation.
    bias = nn.Parameter(torch.zeros(n, dtype=torch.float64))

    # Collect trainable parameters
    trainable_params = list(Q.parameters())

    if center_columns:
        trainable_params.append(bias)

    # Define rotational alignment, and optimizer.
    optimizer = LBFGS(
        trainable_params,
        max_iter=100,  # number of inner iterations.
        line_search_fn="strong_wolfe",
    )

    def closure():
        optimizer.zero_grad()
        loss = torch.mean(torch.norm(X - Q(Y) - bias, dim=1))
        loss.backward()
        return loss

    # Fit parameters.
    converged = False
    itercount = 0
    while (not converged) and (itercount < max_iter):

        # Update parameters.
        new_loss = optimizer.step(closure).item()

        # Check convergence.
        if itercount != 0:
            improvement = (last_loss - new_loss) / last_loss
            converged = improvement < tol

        last_loss = new_loss

        # Display progress.
        itercount += 1
        if verbose:
            print(f"Iter {itercount}: {last_loss}")
            if converged:
                print("Converged!")

    # Extract result in numpy.
    Q_ = Q.weight.detach().numpy()
    bias_ = bias.detach().numpy()

    return Q_, bias_
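# A possible way to exercise _find_rotation_lbfgs on synthetic data (assuming numpy and
# geotorch are available): rotate X by a random orthogonal matrix, then check that
# Y @ Q.T + bias approximately reconstructs X.
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((500, 10))
Q_true, _ = np.linalg.qr(rng.standard_normal((10, 10)))
Y = X @ Q_true.T
Q_est, bias_est = _find_rotation_lbfgs(X, Y, verbose=False)
print(np.abs(X - (Y @ Q_est.T + bias_est)).max())  # should be small if the fit converged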
Example #12
def train(training_config):
    # (tensorboard) writer will output to ./runs/ directory by default
    writer = SummaryWriter()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # prepare data loader
    train_loader = utils.get_training_data_loader(training_config)

    # prepare neural networks
    transformer_net = TransformerNet().train().to(device)
    perceptual_loss_net = PerceptualLossNet(requires_grad=False).to(device)

    optimizer = LBFGS(transformer_net.parameters(),
                      line_search_fn='strong_wolfe')

    # Calculate style image's Gram matrices (style representation)
    # Built over feature maps as produced by the perceptual net - VGG16
    style_img_path = os.path.join(training_config['style_images_path'],
                                  training_config['style_img_name'])
    style_img = utils.prepare_img(style_img_path,
                                  target_shape=None,
                                  device=device,
                                  batch_size=training_config['batch_size'])
    style_img_set_of_feature_maps = perceptual_loss_net(style_img)
    target_style_representation = [
        utils.gram_matrix(x) for x in style_img_set_of_feature_maps
    ]

    utils.print_header(training_config)
    # Tracking loss metrics; NST is ill-posed, so we can only track the loss and the visual appearance of the stylized images
    acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]
    ts = time.time()
    for epoch in range(training_config['num_of_epochs']):
        for batch_id, (content_batch, _) in enumerate(train_loader):
            content_batch = content_batch.to(device)

            # LBFGS re-evaluates the closure several times per step (here with a
            # strong_wolfe line search), so the whole forward pass and loss
            # computation live inside it.
            content_loss = style_loss = tv_loss = total_loss = None
            stylized_batch = None

            def closure():
                nonlocal content_loss, style_loss, tv_loss, total_loss, stylized_batch
                optimizer.zero_grad()

                # step1: Feed content batch through transformer net
                stylized_batch = transformer_net(content_batch)

                # step2: Feed content and stylized batch through perceptual net (VGG16)
                content_batch_set_of_feature_maps = perceptual_loss_net(
                    content_batch)
                stylized_batch_set_of_feature_maps = perceptual_loss_net(
                    stylized_batch)

                # step3: Calculate content representations and content loss
                target_content_representation = content_batch_set_of_feature_maps.relu2_2
                current_content_representation = stylized_batch_set_of_feature_maps.relu2_2
                content_loss = training_config['content_weight'] * torch.nn.MSELoss(
                    reduction='mean')(target_content_representation,
                                      current_content_representation)

                # step4: Calculate style representation and style loss
                style_loss = 0.0
                current_style_representation = [
                    utils.gram_matrix(x)
                    for x in stylized_batch_set_of_feature_maps
                ]
                for gram_gt, gram_hat in zip(target_style_representation,
                                             current_style_representation):
                    style_loss += torch.nn.MSELoss(reduction='mean')(gram_gt,
                                                                     gram_hat)
                style_loss /= len(target_style_representation)
                style_loss *= training_config['style_weight']

                # step5: Calculate total variation loss - enforces image smoothness
                tv_loss = training_config['tv_weight'] * utils.total_variation(
                    stylized_batch)

                # step6: Combine losses and do a backprop
                total_loss = content_loss + style_loss + tv_loss
                total_loss.backward()
                return total_loss

            optimizer.step(closure)

            #
            # Logging and checkpoint creation
            #
            acc_content_loss += content_loss.item()
            acc_style_loss += style_loss.item()
            acc_tv_loss += tv_loss.item()

            if training_config['enable_tensorboard']:
                # log scalars
                writer.add_scalar('Loss/content-loss', content_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalar('Loss/style-loss', style_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalar('Loss/tv-loss', tv_loss.item(),
                                  len(train_loader) * epoch + batch_id + 1)
                writer.add_scalars(
                    'Statistics/min-max-mean-median', {
                        'min': torch.min(stylized_batch),
                        'max': torch.max(stylized_batch),
                        'mean': torch.mean(stylized_batch),
                        'median': torch.median(stylized_batch)
                    },
                    len(train_loader) * epoch + batch_id + 1)
                # log stylized image
                if batch_id % training_config['image_log_freq'] == 0:
                    stylized = utils.post_process_image(
                        stylized_batch[0].detach().to('cpu').numpy())
                    stylized = np.moveaxis(
                        stylized, 2, 0)  # writer expects channel first image
                    writer.add_image('stylized_img', stylized,
                                     len(train_loader) * epoch + batch_id + 1)

            if training_config[
                    'console_log_freq'] is not None and batch_id % training_config[
                        'console_log_freq'] == 0:
                print(
                    f'time elapsed={(time.time() - ts) / 60:.2f}[min]|epoch={epoch + 1}|batch=[{batch_id + 1}/{len(train_loader)}]|c-loss={acc_content_loss / training_config["console_log_freq"]}|s-loss={acc_style_loss / training_config["console_log_freq"]}|tv-loss={acc_tv_loss / training_config["console_log_freq"]}|total loss={(acc_content_loss + acc_style_loss + acc_tv_loss) / training_config["console_log_freq"]}'
                )
                acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]

            if training_config['checkpoint_freq'] is not None and (
                    batch_id + 1) % training_config['checkpoint_freq'] == 0:
                training_state = utils.get_training_metadata(training_config)
                training_state["state_dict"] = transformer_net.state_dict()
                training_state["optimizer_state"] = optimizer.state_dict()
                ckpt_model_name = f"ckpt_style_{training_config['style_img_name'].split('.')[0]}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}_epoch_{epoch}_batch_{batch_id}.pth"
                torch.save(
                    training_state,
                    os.path.join(training_config['checkpoints_path'],
                                 ckpt_model_name))

    #
    # Save model with additional metadata - like which commit was used to train the model, style/content weights, etc.
    #
    training_state = utils.get_training_metadata(training_config)
    training_state["state_dict"] = transformer_net.state_dict()
    training_state["optimizer_state"] = optimizer.state_dict()
    model_name = f"style_{training_config['style_img_name'].split('.')[0]}_datapoints_{training_state['num_of_datapoints']}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}.pth"
    torch.save(
        training_state,
        os.path.join(training_config['model_binaries_path'], model_name))