from torch.optim import LBFGS


def update_value_net(value_net, states, returns, l2_reg):
    """Fit the value network to the empirical returns using L-BFGS."""
    optimizer = LBFGS(value_net.parameters(), max_iter=25, history_size=5)

    def closure():
        optimizer.zero_grad()
        values_pred = value_net(states)
        value_loss = (values_pred - returns).pow(2).mean()
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg
        value_loss.backward()
        return value_loss

    optimizer.step(closure)
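# Hedged usage sketch (not part of the original snippet): the toy value
# network, state batch, and return targets below are made-up stand-ins,
# only meant to show the expected call shape of update_value_net.
import torch
import torch.nn as nn

value_net = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 1))
states = torch.randn(256, 4)      # batch of observed states
returns = torch.randn(256, 1)     # empirical (e.g. discounted) returns
update_value_net(value_net, states, returns, l2_reg=1e-3)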
import torch.nn as nn
from torch.optim import LBFGS


def train():
    """Run neural style transfer by optimizing the input image with L-BFGS."""
    content_image, style_image, image_size = load_img_tensor()
    # imshow(content_image)
    # imshow(style_image)
    # input_params = input_image(image_size)
    input_params = nn.Parameter(content_image, requires_grad=True)
    cnn = cnn_loader()
    model, content_losses, style_losses = get_model_losses(
        cnn, content_image, style_image, 1, 1000)

    epoch = [0]
    num_epoches = 100
    optimizer = LBFGS([input_params])
    content_loss_list = []
    style_loss_list = []

    while epoch[0] < num_epoches:
        def closure():
            optimizer.zero_grad()
            model(input_params)
            content_score = 0
            style_score = 0
            for cs in content_losses:
                content_score += cs.loss
            for ss in style_losses:
                style_score += ss.loss
            loss = content_score + style_score
            loss.backward()

            epoch[0] += 1
            if epoch[0] % 50 == 1:
                print('content score: {}, style score: {}'.format(
                    content_score.item(), style_score.item()))
            # store plain floats so the autograd graph is not kept alive
            content_loss_list.append(content_score.item())
            style_loss_list.append(style_score.item())
            return loss

        optimizer.step(closure)

    return input_params, content_loss_list, style_loss_list
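# Sketch of the loss-recording module that get_model_losses is assumed to
# insert into `model` (the closure above reads `cs.loss` / `ss.loss` after
# the forward pass, which implies this pattern). Names and details here are
# illustrative, not taken from the original helpers.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContentLoss(nn.Module):
    def __init__(self, target):
        super().__init__()
        self.target = target.detach()   # fixed content features
        self.loss = torch.tensor(0.)

    def forward(self, x):
        # record the loss for the current input, then pass activations through
        self.loss = F.mse_loss(x, self.target)
        return x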
import torch
import torch.nn as nn
from torch.optim import LBFGS
from tqdm import tqdm


def L_BFGS(spec, transform_fn, samples=None, init_x0=None, maxiter=1000, tol=1e-6,
           verbose=1, evaiter=10, metric='sc', **kwargs):
    r"""
    Reconstruct spectrogram phase using `Inversion of Auditory Spectrograms, Traditional Spectrograms,
    and Other Envelope Representations`_, where I directly use the :class:`torch.optim.LBFGS` optimizer
    provided in PyTorch. This method is not restricted to the traditional short-time Fourier transform;
    it works with any kind of representation (e.g. a Mel-scaled spectrogram) as long as the transform
    function is differentiable.

    .. _`Inversion of Auditory Spectrograms, Traditional Spectrograms, and Other Envelope Representations`:
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=6949659

    Args:
        spec (Tensor): the input representation.
        transform_fn: a function of the form ``spec = transform_fn(x)`` where ``x`` is a 1D tensor.
        samples (int, optional): number of samples in the time domain. Default: :obj:`None`
        init_x0 (Tensor, optional): a 1D tensor used as the initial time-domain samples. If not
            provided, a random tensor of length ``samples`` is used.
        maxiter (int): maximum number of iterations before timing out.
        tol (float): tolerance of the stopping condition based on the L2 loss. Default: ``1e-6``.
        verbose (bool): whether to be verbose. Default: :obj:`True`
        evaiter (int): evaluation interval in steps. Every ``evaiter`` steps, the function given in
            ``metric`` is evaluated. Default: ``10``
        metric (str): evaluation function. Currently available functions: ``'sc'`` (spectral
            convergence), ``'snr'`` or ``'ser'``. Default: ``'sc'``
        **kwargs: other arguments passed to :class:`torch.optim.LBFGS`.

    Returns:
        A 1D tensor converted from the given representation.
    """
    if init_x0 is None:
        init_x0 = spec.new_empty(samples).normal_(std=1e-6)
    x = nn.Parameter(init_x0)
    T = spec

    criterion = nn.MSELoss()
    optimizer = LBFGS([x], **kwargs)

    def closure():
        optimizer.zero_grad()
        V = transform_fn(x)
        loss = criterion(V, T)
        loss.backward()
        return loss

    bar_dict = {}
    if metric == 'snr':
        metric_func = SNR
        bar_dict['SNR'] = 0
        metric = metric.upper()
    elif metric == 'ser':
        metric_func = SER
        bar_dict['SER'] = 0
        metric = metric.upper()
    else:
        metric_func = spectral_convergence
        bar_dict['spectral_convergence'] = 0
        metric = 'spectral_convergence'

    init_loss = None
    with tqdm(total=maxiter, disable=not verbose) as pbar:
        for i in range(maxiter):
            optimizer.step(closure)

            if i % evaiter == evaiter - 1:
                with torch.no_grad():
                    V = transform_fn(x)
                    bar_dict[metric] = metric_func(V, spec).item()
                    l2_loss = criterion(V, spec).item()
                    pbar.set_postfix(**bar_dict, loss=l2_loss)
                    pbar.update(evaiter)

                    if not init_loss:
                        init_loss = l2_loss
                    elif (previous_loss - l2_loss) / init_loss < tol * evaiter:
                        break
                    previous_loss = l2_loss

    return x.detach()
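# Hedged usage sketch: recover a waveform from an STFT magnitude with the
# L_BFGS routine above. The transform below is an assumption (any
# differentiable spec = transform_fn(x) works), and the FFT sizes and the
# random target signal are placeholders.
import torch

def transform_fn(x):
    # magnitude STFT as the differentiable forward transform
    return torch.stft(x, n_fft=1024, hop_length=256,
                      window=torch.hann_window(1024),
                      return_complex=True).abs()

target_wav = torch.randn(16000)          # stand-in for a real signal
target_spec = transform_fn(target_wav)   # magnitude spectrogram to invert
recovered = L_BFGS(target_spec, transform_fn, samples=target_wav.numel(),
                   maxiter=200, tol=1e-6)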