Example #1
from torch.optim import LBFGS


def update_value_net(value_net, states, returns, l2_reg):
    # L-BFGS with a bounded history; step() below runs up to max_iter
    # inner iterations, re-evaluating the closure each time
    optimizer = LBFGS(value_net.parameters(), max_iter=25, history_size=5)

    def closure():
        optimizer.zero_grad()
        values_pred = value_net(states)
        value_loss = (values_pred - returns).pow(2).mean()

        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg
        value_loss.backward()
        return value_loss

    optimizer.step(closure)
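
A minimal usage sketch for the function above (the network shape and the data are made-up stand-ins; only the call pattern comes from the example):

import torch
from torch import nn

value_net = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 1))
states = torch.randn(128, 4)    # batch of observed states
returns = torch.randn(128, 1)   # return estimates for those states
update_value_net(value_net, states, returns, l2_reg=1e-3)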
Example #2
from torch import nn
from torch.optim import LBFGS


def train():
    # load_img_tensor, cnn_loader and get_model_losses are helpers from the
    # surrounding project (a standard neural-style-transfer setup)
    content_image, style_image, image_size = load_img_tensor()
    # imshow(content_image)
    # imshow(style_image)
    # input_params = input_image(image_size)
    input_params = nn.Parameter(content_image, requires_grad=True)
    cnn = cnn_loader()
    model, content_losses, style_losses = get_model_losses(
        cnn, content_image, style_image, 1, 1000)
    epoch = [0]
    num_epoches = 100
    optimizer = LBFGS([input_params])
    content_loss_list = []
    style_loss_list = []
    while epoch[0] < num_epoches:

        def closure():
            optimizer.zero_grad()
            model(input_params)
            content_score = 0
            style_score = 0
            for cs in content_losses:
                content_score += cs.loss
            for ss in style_losses:
                style_score += ss.loss
            loss = content_score + style_score
            loss.backward()
            epoch[0] += 1
            if epoch[0] % 50 == 1:
                print('content score: {}, style score: {}'.format(
                    content_score.item(), style_score.item()))
            # store plain floats so the autograd graphs are not kept alive
            content_loss_list.append(content_score.item())
            style_loss_list.append(style_score.item())
            return loss

        optimizer.step(closure)
    return input_params, content_loss_list, style_loss_list
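
Note that LBFGS may call the closure several times per optimizer.step() while it performs its inner iterations and line search, so the closure must recompute the loss and gradients from scratch on every call; this is why both examples above wrap the forward/backward pass in a function. A self-contained sketch of the same pattern on a made-up quadratic objective:

import torch
from torch.optim import LBFGS

x = torch.nn.Parameter(torch.tensor([2.0, -3.0]))
optimizer = LBFGS([x], lr=1.0, max_iter=20)

def closure():
    # must recompute loss and gradients on every call
    optimizer.zero_grad()
    loss = (x ** 2).sum()
    loss.backward()
    return loss

optimizer.step(closure)
print(x)  # both entries should now be close to 0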
Example #3


import torch
from torch import nn
from torch.optim import LBFGS
from tqdm import tqdm

# SNR, SER and spectral_convergence below are metric helpers from the
# surrounding package.


def L_BFGS(spec,
           transform_fn,
           samples=None,
           init_x0=None,
           maxiter=1000,
           tol=1e-6,
           verbose=1,
           evaiter=10,
           metric='sc',
           **kwargs):
    r"""

    Reconstruct spectrogram phase using `Inversion of Auditory Spectrograms, Traditional Spectrograms, and Other
    Envelope Representations`_, where I directly use the :class:`torch.optim.LBFGS` optimizer provided in PyTorch.
    This method doesn't restrict to traditional short-time Fourier Transform, but any kinds of presentation (ex: Mel-scaled Spectrogram) as
    long as the transform function is differentiable.

    .. _`Inversion of Auditory Spectrograms, Traditional Spectrograms, and Other Envelope Representations`:
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=6949659

    Args:
        spec (Tensor): the input presentation.
        transform_fn: a function that has the form ``spec = transform_fn(x)`` where x is an 1d tensor.
        samples (int, optional): number of samples in time domain. Default: :obj:`None`
        init_x0 (Tensor, optional): an 1d tensor that make use as initial time domain samples. If not provided, will use random
            value tensor with length equal to ``samples``.
        maxiter (int): maximum number of iterations before timing out.
        tol (float): tolerance of the stopping condition base on L2 loss. Default: ``1e-6``.
        verbose (bool): whether to be verbose. Default: :obj:`True`
        evaiter (int): steps size for evaluation. After each step, the function defined in ``metric`` will evaluate. Default: ``10``
        metric (str): evaluation function. Currently available functions: ``'sc'`` (spectral convergence), ``'snr'`` or ``'ser'``. Default: ``'sc'``
        **kwargs: other arguments that pass to :class:`torch.optim.LBFGS`.

    Returns:
        A 1d tensor converted from the given presentation
    """
    if init_x0 is None:
        init_x0 = spec.new_empty(samples).normal_(std=1e-6)
    x = nn.Parameter(init_x0)
    T = spec

    criterion = nn.MSELoss()
    optimizer = LBFGS([x], **kwargs)

    def closure():
        optimizer.zero_grad()
        V = transform_fn(x)
        loss = criterion(V, T)
        loss.backward()
        return loss

    bar_dict = {}
    if metric == 'snr':
        metric_func = SNR
        bar_dict['SNR'] = 0
        metric = metric.upper()
    elif metric == 'ser':
        metric_func = SER
        bar_dict['SER'] = 0
        metric = metric.upper()
    else:
        metric_func = spectral_convergence
        bar_dict['spectral_convergence'] = 0
        metric = 'spectral_convergence'

    init_loss = None
    with tqdm(total=maxiter, disable=not verbose) as pbar:
        for i in range(maxiter):
            optimizer.step(closure)

            if i % evaiter == evaiter - 1:
                with torch.no_grad():
                    V = transform_fn(x)
                    bar_dict[metric] = metric_func(V, spec).item()
                    l2_loss = criterion(V, spec).item()
                    pbar.set_postfix(**bar_dict, loss=l2_loss)
                    pbar.update(evaiter)

                    # stop when the relative improvement since the last
                    # evaluation falls below the tolerance
                    if init_loss is None:
                        init_loss = l2_loss
                    elif (previous_loss - l2_loss) / init_loss < tol * evaiter:
                        break
                    previous_loss = l2_loss

    return x.detach()
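
A hedged usage sketch for L_BFGS, assuming a plain STFT-magnitude representation (the FFT size, hop length and random target signal are arbitrary choices; the metric helpers used inside L_BFGS must be importable):

import torch

n_fft, hop = 1024, 256
window = torch.hann_window(n_fft)

def transform_fn(x):
    # differentiable magnitude spectrogram of a 1-d signal
    return torch.stft(x, n_fft, hop_length=hop, window=window,
                      return_complex=True).abs()

y = torch.randn(16000)          # stand-in for a real waveform
spec = transform_fn(y)          # target representation
y_hat = L_BFGS(spec, transform_fn, samples=y.numel(), maxiter=200)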