Ejemplo n.º 1
0
def test_module(setup):
    """Log ``FakeModule`` under the name ``Test`` at step 0, then sleep for 1.0.

    ``setup`` is not used in the body; presumably it is a test fixture that
    prepares ``logger`` — confirm against the fixture definition.
    """
    modules_to_log = {"Test": FakeModule}
    # keyword order (Test, then step) is kept identical to the original call
    logger.log_module(**modules_to_log, step=0)
    # NOTE(review): presumably gives the logger time to finish writing — confirm
    sleep(1.0)
Ejemplo n.º 2
0
def maml(model=None, test_fn=None):
    """Meta-train ``model`` with second-order MAML on randomly sampled ``Sine`` tasks.

    Every epoch samples ``G.task_batch_n`` tasks. For each task the model takes
    ``G.n_gradient_steps`` inner gradient steps of size ``G.alpha`` on
    ``G.k_shot`` samples. The inner gradients are built with
    ``create_graph=True``, so the post-adaptation loss can be differentiated with
    respect to the original (pre-adaptation) parameters. Those meta-gradients are
    summed into ``p.grad``, divided by ``G.task_batch_n`` and applied with Adam
    (learning rate ``G.beta``).

    :param model: model to meta-train. Defaults to ``StandardMLP(1, 1)`` when
        ``G.debug`` is set, otherwise ``FunctionalMLP(1, 1)``. It must expose a
        ``params`` mapping that accepts updated tensors by name. Models with a
        truthy ``is_autoregressive`` or ``is_recurrent`` attribute must also
        provide ``h0_init()``.
    :param test_fn: optional callback. It is invoked when the epoch index is a
        multiple of ``G.test_interval`` and the gradient step is in
        ``G.test_grad_steps``. It is skipped entirely when
        ``G.test_interval`` is falsy.
    """
    from playground.maml.maml_torch.tasks import Sine

    model = model or (StandardMLP(1, 1) if G.debug else FunctionalMLP(1, 1))
    meta_optimizer = t.optim.Adam(model.parameters(), lr=G.beta)
    mse = t.nn.MSELoss()

    def _forward(xs, labels):
        """Run ``model`` on one task batch; return ``(ys, hidden_state_or_None)``."""
        if getattr(model, "is_autoregressive", False):
            h0 = model.h0_init()
            # sequence models get inputs reshaped to (k_shot, 1, 1)
            ys, ht = model(xs.view(G.k_shot, 1, 1), h0,
                           labels.view(G.k_shot, 1, 1))
            return ys.squeeze(-1), ht
        if getattr(model, "is_recurrent", False):
            h0 = model.h0_init()
            ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
            return ys.squeeze(-1), ht
        return model(xs.unsqueeze(-1)), None

    def _should_test(ep_ind, grad_ind):
        """True when ``test_fn`` should run at this epoch / gradient step."""
        # guard against a falsy test_interval, which would otherwise raise on `%`
        return (callable(test_fn) and bool(G.test_interval)
                and ep_ind % G.test_interval == 0
                and grad_ind in G.test_grad_steps)

    def _eval_task(ep_ind, task, task_ind, grad_ind, xs, labels, silent):
        """Forward + MSE loss on the task batch; log it and maybe call ``test_fn``."""
        ys, ht = _forward(xs, labels)
        loss = mse(ys, labels.unsqueeze(-1))
        logger.log_keyvalue(ep_ind,
                            f"{task_ind}-grad-{grad_ind}-loss",
                            loss.item(),
                            silent=silent)
        if _should_test(ep_ind, grad_ind):
            test_fn(model,
                    task=task,
                    task_id=task_ind,
                    epoch=ep_ind,
                    grad_step=grad_ind,
                    silent=silent,
                    h0=ht)
        return loss

    M.tic('start')
    M.tic('epoch')
    for ep_ind in range(G.n_epochs):
        dt = M.split('epoch', silent=True)
        dt_ = M.toc('start', silent=True)
        print(f"epoch {ep_ind} @ {dt:.4f}sec/ep, {dt_:.1f} sec from start")
        # pre-adaptation parameters; meta-gradients are taken w.r.t. these
        original_ps = OrderedDict(model.named_parameters())

        tasks = [Sine() for _ in range(G.task_batch_n)]

        for task_ind, task in enumerate(tasks):
            # reset the adapted parameters left over from the previous task
            if task_ind != 0:
                model.params.update(original_ps)
            if G.test_mode:
                _gradient = original_ps['bias_var'].grad
                if task_ind == 0:
                    assert _gradient is None or _gradient.sum().item(
                    ) == 0, f"{_gradient} is not zero or None, epoch {ep_ind}."
                else:
                    assert _gradient.sum().item(
                    ) != 0, f"{_gradient} should be non-zero"
                assert (
                    original_ps['bias_var'] == model.params['bias_var']
                ).all().item() == 1, 'the two parameters should be the same'

            xs, labels = t.DoubleTensor(task.samples(G.k_shot))

            # only the first task of each epoch logs verbosely
            _silent = task_ind != 0

            for grad_ind in range(G.n_gradient_steps):
                loss = _eval_task(ep_ind, task, task_ind, grad_ind, xs, labels,
                                  _silent)
                # keep the graph so the meta-gradient flows through this step
                dps = t.autograd.grad(loss,
                                      model.parameters(),
                                      create_graph=True,
                                      retain_graph=True)
                for (name, p), dp in zip(model.named_parameters(), dps):
                    model.params[name] = p - G.alpha * dp

            # meta gradient: loss of the adapted params w.r.t. the original ones
            loss = _eval_task(ep_ind, task, task_ind, G.n_gradient_steps, xs,
                              labels, _silent)
            meta_dps = t.autograd.grad(loss, original_ps.values())
            with t.no_grad():
                for p, meta_dp in zip(original_ps.values(), meta_dps):
                    p.grad = (0 if p.grad is None else p.grad) + meta_dp

        # average the accumulated meta-gradient over the task batch
        with t.no_grad():
            for p in original_ps.values():
                p.grad /= G.task_batch_n

        model.params.update(original_ps)
        meta_optimizer.step()
        meta_optimizer.zero_grad()

        if G.save_interval and ep_ind % G.save_interval == 0:
            logger.log_module(ep_ind, **{type(model).__name__: model})

    logger.flush()
Ejemplo n.º 3
0
def maml_supervised(model,
                    Task,
                    n_epochs,
                    task_batch_n,
                    npts,
                    k_shot,
                    n_gradient_steps,
                    test_fn=None,
                    **_):
    """
    Supervised MAML. ``Task`` needs to implement ``.proper()`` and ``.samples(k)``,
    where ``proper`` returns the full, dense set of data from which samples are drawn.

    :param model: module to meta-train. Besides ``parameters()``/``state_dict()``
        it must provide ``gradients()`` and ``meta_step(lr=...)``.
    :param Task: task class, instantiated as ``Task(npts=npts)``.
    :param n_epochs: number of meta-training epochs.
    :param task_batch_n: number of tasks sampled per epoch.
    :param npts: the total number of samples for the sinusoidal task
    :param k_shot: number of samples drawn per adaptation / meta-loss batch.
    :param n_gradient_steps: number of inner-loop gradient steps per task.
    :param test_fn: optional callback, called as ``test_fn(grad_step, model, ...)``.
    :param _: extra keyword arguments are accepted and ignored.
    :return: None
    """
    import playground.maml.maml_torch.archive.paper_metrics as metrics

    device = t.device('cuda' if t.cuda.is_available() else 'cpu')
    model.to(device)

    ps = list(model.parameters())

    M.tic('epoch')
    for ep_ind in trange(n_epochs):
        M.split('epoch', silent=True)
        meta_grads = defaultdict(lambda: 0)
        # snapshot of the pre-adaptation weights, restored before every task
        theta = copy.deepcopy(model.state_dict())
        tasks = [Task(npts=npts) for _ in range(task_batch_n)]
        for task_ind, task in enumerate(tasks):  # sample a new problem
            # todo: this part is highly-parallelizable
            if task_ind != 0:
                model.load_state_dict(theta)

            task_grads = defaultdict(deque)
            proper = t.tensor(task.proper()).to(device)
            samples = t.tensor(task.samples(k_shot)).to(device)

            for grad_ind in range(n_gradient_steps):
                loss, _ = metrics.comp_loss(*samples, model)
                model.zero_grad()
                # back-propagate once, retain graph.
                loss.backward(t.ones(1).to(device), retain_graph=True)

                U, grad_outputs = [], []
                for p in model.parameters():
                    U.append(p - G.alpha * p.grad)  # meta update
                    grad_outputs.append(t.ones(1).to(device).expand_as(p))

                # t.autograd.grad returns the vector-Jacobian product of U with
                # grad_outputs: the row sum of d theta_prime / d theta.
                # NOTE(review): backward() above was called without
                # create_graph=True, so p.grad is a constant here and dU/dp is
                # the identity. Every entry of dU is therefore 1, and the product
                # below reduces to the post-adaptation gradient (first-order
                # MAML). Confirm this is intended.
                dU = t.autograd.grad(outputs=U,
                                     grad_outputs=grad_outputs,
                                     inputs=model.parameters())

                # Now update the params in place. These are leaf nodes, so the
                # data attribute can be replaced directly.
                for p, updated_p, du in zip(ps, U, dU):
                    p.data = updated_p.data
                    task_grads[p].append(du)

                # Evaluate the (grad_ind + 1)-step loss on the dense set, every
                # `test_interval` epochs. (Previously `== 0` was missing, which
                # made this run on every epoch *except* the multiples.)
                if (G.test_interval and ep_ind % G.test_interval == 0
                        and grad_ind + 1 in G.test_grad_steps):
                    with t.no_grad():
                        if test_fn is not None:
                            test_fn(grad_ind + 1,
                                    model,
                                    task=task,
                                    epoch=ep_ind)
                        _loss, _ = metrics.comp_loss(*proper, model)
                    logger.log_keyvalue(
                        ep_ind,
                        key=f"{grad_ind + 1:d}-grad-loss-{task_ind:02d}",
                        value=_loss.item(),
                        silent=True)

            # compute Loss_theta_prime on a fresh sample from this problem
            samples = t.tensor(task.samples(k_shot)).to(device)
            loss, _ = metrics.comp_loss(*samples, model)
            model.zero_grad()
            loss.backward()

            # accumulate the meta-gradient: elementwise product of the stored
            # per-step factors and the post-adaptation gradient
            for p, grad in zip(ps, model.gradients()):
                task_grads[p].append(grad)
                meta_grads[p] += t.prod(t.stack(list(task_grads[p]), dim=-1),
                                        dim=-1)

        model.load_state_dict(theta)
        for p in ps:
            # detach().clone() is the supported way to copy a tensor
            # (t.tensor(<tensor>) emits a UserWarning)
            p.grad = (meta_grads[p] / task_batch_n).detach().clone().to(device)

        model.meta_step(lr=G.beta)

        with t.no_grad():
            if test_fn is not None:
                test_fn(0, model, epoch=ep_ind)
            # `proper` is the dense data of the last task sampled this epoch
            _loss, _ = metrics.comp_loss(*proper, model)
        # it is very easy to use the wrong loss here.
        logger.log_keyvalue(ep_ind, '0-grad-loss', _loss.item(), silent=True)

        # save model weight
        if G.save_interval and ep_ind % G.save_interval == 0:
            logger.log_module(ep_ind, **{type(model).__name__: model})
Ejemplo n.º 4
0
def reptile(model=None, test_fn=None):
    """Meta-train ``model`` with first-order Reptile on randomly sampled ``Sine`` tasks.

    Every epoch samples ``G.task_batch_n`` tasks. For each task the model takes
    ``G.n_gradient_steps`` plain gradient steps of size ``G.alpha`` on
    ``G.k_shot`` samples; no higher-order graph is kept. The difference
    ``theta - theta'_task`` is summed into the original parameters' ``.grad``,
    divided by ``G.task_batch_n`` and applied with Adam (learning rate ``G.beta``).

    :param model: model to meta-train, default ``FunctionalMLP(1, 1)``. It must
        expose a ``params`` mapping that accepts updated tensors by name. Models
        with a truthy ``is_autoregressive`` or ``is_recurrent`` attribute must
        also provide ``h0_init()``.
    :param test_fn: optional callback. It is invoked when the epoch index is a
        multiple of ``G.test_interval`` and the gradient step is in
        ``G.test_grad_steps``. It is skipped entirely when
        ``G.test_interval`` is falsy.
    """
    from playground.maml.maml_torch.tasks import Sine

    model = model or FunctionalMLP(1, 1)

    meta_optimizer = t.optim.Adam(model.parameters(), lr=G.beta)
    mse = t.nn.MSELoss()

    def _forward(xs, labels):
        """Run ``model`` on one task batch; return ``(ys, hidden_state_or_None)``."""
        if getattr(model, "is_autoregressive", False):
            h0 = model.h0_init()
            # sequence models get inputs reshaped to (k_shot, 1, 1)
            ys, ht = model(xs.view(G.k_shot, 1, 1), h0,
                           labels.view(G.k_shot, 1, 1))
            return ys.squeeze(-1), ht
        if getattr(model, "is_recurrent", False):
            h0 = model.h0_init()
            ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
            return ys.squeeze(-1), ht
        return model(xs.unsqueeze(-1)), None

    def _log_and_test(ep_ind, task, task_ind, grad_ind, loss, ht, silent):
        """Log the task loss and call ``test_fn`` when this step is scheduled."""
        logger.log_keyvalue(ep_ind,
                            f"{task_ind}-grad-{grad_ind}-loss",
                            loss.item(),
                            silent=silent)
        # guard against a falsy test_interval, which would otherwise raise on `%`
        if (callable(test_fn) and G.test_interval
                and ep_ind % G.test_interval == 0
                and grad_ind in G.test_grad_steps):
            test_fn(model,
                    task,
                    task_id=task_ind,
                    epoch=ep_ind,
                    grad_step=grad_ind,
                    h0=ht,
                    silent=silent)

    M.tic('epoch')
    for ep_ind in range(G.n_epochs):
        M.split('epoch')
        # pre-adaptation parameters; the Reptile update is relative to these
        original_ps = OrderedDict(model.named_parameters())

        tasks = [Sine() for _ in range(G.task_batch_n)]

        for task_ind, task in enumerate(tasks):
            # reset the adapted parameters left over from the previous task
            if task_ind != 0:
                model.params.update(original_ps)
            xs, labels = t.DoubleTensor(task.samples(G.k_shot))
            # only the first task of each epoch logs verbosely
            _silent = task_ind != 0
            for grad_ind in range(G.n_gradient_steps):
                ys, ht = _forward(xs, labels)
                loss = mse(ys, labels.unsqueeze(-1))
                with t.no_grad():
                    _log_and_test(ep_ind, task, task_ind, grad_ind, loss, ht,
                                  _silent)
                dps = t.autograd.grad(loss, model.parameters())
                with t.no_grad():
                    for (name, p), dp in zip(model.named_parameters(), dps):
                        model.params[name] = p - G.alpha * dp
                        # new leaf tensor: re-enable grad for the next step
                        model.params[name].requires_grad = True

            # domain adaptation: evaluate the adapted parameters
            with t.no_grad():
                ys, ht = _forward(xs, labels)
                loss = mse(ys, labels.unsqueeze(-1))
                _log_and_test(ep_ind, task, task_ind, G.n_gradient_steps, loss,
                              ht, _silent)

            # Compute REPTILE 1st-order gradient: theta - theta'_task
            with t.no_grad():
                for name, p in original_ps.items():
                    # the division by task_batch_n happens after the task loop
                    p.grad = (0 if p.grad is None else p.grad) + (
                        p - model.params[name])

        with t.no_grad():
            for p in original_ps.values():
                p.grad /= G.task_batch_n

        model.params.update(original_ps)
        meta_optimizer.step()
        meta_optimizer.zero_grad()

        if G.save_interval and ep_ind % G.save_interval == 0:
            logger.log_module(ep_ind, **{type(model).__name__: model})

    logger.flush()