Example #1
def maml_supervised(model,
                    Task,
                    n_epochs,
                    task_batch_n,
                    npts,
                    k_shot,
                    n_gradient_steps,
                    test_fn=None,
                    **_):
    """
    Supervised MAML. Task needs to implement .proper and .samples methods, where .proper returns the
    full, dense set of data from which the k-shot samples are drawn.

    :param model: the model whose initialization is meta-learned
    :param Task: task constructor; each instance must expose .proper() and .samples(k)
    :param n_epochs: number of meta-training epochs
    :param task_batch_n: number of tasks sampled per meta-update
    :param npts: the total number of samples for the sinusoidal task
    :param k_shot: number of samples drawn for each inner-loop update
    :param n_gradient_steps: number of inner-loop gradient steps per task
    :param test_fn: optional evaluation callback
    :param _: absorbs unused configuration keys
    :return:
    """
    import playground.maml.maml_torch.archive.paper_metrics as metrics

    device = t.device('cuda' if t.cuda.is_available() else 'cpu')
    model.to(device)

    ps = list(model.parameters())

    # for ep_ind in trange(n_epochs, desc='Epochs', ncols=50, leave=False):
    M.tic('epoch')
    for ep_ind in trange(n_epochs):
        M.split('epoch', silent=True)
        meta_grads = defaultdict(lambda: 0)
        theta = copy.deepcopy(model.state_dict())
        tasks = [Task(npts=npts) for _ in range(task_batch_n)]
        for task_ind, task in enumerate(tasks):  # sample a new problem
            # todo: this part is highly-parallelizable
            if task_ind != 0:
                model.load_state_dict(theta)

            task_grads = defaultdict(deque)
            proper = t.tensor(task.proper()).to(device)
            samples = t.tensor(task.samples(k_shot)).to(device)

            for grad_ind in range(n_gradient_steps):
                # done: ready to be repackaged
                loss, _ = metrics.comp_loss(*samples, model)
                model.zero_grad()
                # back-propagate once, retain graph.
                loss.backward(t.ones(1).to(device), retain_graph=True)

                # done: apply a gradient-descent step and build the meta graph.
                U, grad_outputs = [], []
                for p in model.parameters():
                    U.append(p - G.alpha * p.grad)  # inner-loop (fast) update: theta' = theta - alpha * grad
                    grad_outputs.append(t.ones(1).to(device).expand_as(p))

                # t.autograd.grad returns the gradients of the weighted sum of all outputs in U w.r.t. the inputs.
                # note: each entry is the row sum of d(theta_prime)/d(theta), which is a matrix.
                dU = t.autograd.grad(outputs=U,
                                     grad_outputs=grad_outputs,
                                     inputs=model.parameters())

                # Now update the parameters in place.
                for p, updated_p, du in zip(ps, U, dU):
                    p.data = updated_p.data  # these are leaf nodes, so we can directly manipulate the data attribute.
                    task_grads[p].append(du)

                # note: evaluate the loss after (grad_ind + 1) gradient steps
                if G.test_interval and ep_ind % G.test_interval == 0 and grad_ind + 1 in G.test_grad_steps:
                    with t.no_grad():
                        if test_fn is not None:
                            test_fn(grad_ind + 1,
                                    model,
                                    task=task,
                                    epoch=ep_ind)
                        _loss, _ = metrics.comp_loss(*proper, model)
                    logger.log_keyvalue(
                        ep_ind,
                        key=f"{grad_ind + 1:d}-grad-loss-{task_ind:02d}",
                        value=_loss.item(),
                        silent=True)

            # compute Loss_theta_prime
            samples = t.tensor(task.samples(k_shot)).to(
                device)  # sample from this problem
            loss, _ = metrics.comp_loss(*samples, model)
            model.zero_grad()
            loss.backward()

            for i, grad in enumerate(
                    model.gradients()):  # Now accumulate the gradient
                p = ps[i]
                task_grads[p].append(grad)
                meta_grads[p] += t.prod(t.cat(list(
                    map(lambda d: d.unsqueeze(dim=-1), task_grads[p])),
                                              dim=-1),
                                        dim=-1)

        # theta_prime = copy.deepcopy(model.state_dict())
        model.load_state_dict(theta)
        for p in ps:
            p.grad = (meta_grads[p] / task_batch_n).detach().clone().to(device)

        model.meta_step(lr=G.beta)

        with t.no_grad():
            if test_fn is not None:
                test_fn(0, model, epoch=ep_ind)
            _loss, _ = metrics.comp_loss(*proper, model)
        # note: `proper` still holds the last task's dense data; it is very easy to use the wrong loss here.
        logger.log_keyvalue(ep_ind, '0-grad-loss', _loss.item(), silent=True)

        # save model weight
        if G.save_interval and ep_ind % G.save_interval == 0:
            logger.log_module(ep_ind, **{type(model).__name__: model})
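
A note on the Task interface assumed by the docstring above: each task instance needs a .proper() method returning the full, dense dataset and a .samples(k) method returning a k-shot subset of it. A minimal sketch of such a task, assuming a sinusoid-regression setup (the class name, amplitude/phase ranges, and input range below are illustrative, not the repo's actual Sine implementation):

import numpy as np

class SineTask:
    """Hypothetical task exposing the .proper()/.samples(k) interface used above."""

    def __init__(self, npts=100, x_low=-5.0, x_high=5.0):
        self.amp = np.random.uniform(0.1, 5.0)
        self.phase = np.random.uniform(0.0, np.pi)
        self.xs = np.linspace(x_low, x_high, npts, dtype=np.float32)

    def proper(self):
        # the full, dense set of (x, y) pairs that the k-shot samples are drawn from
        return self.xs, self.amp * np.sin(self.xs + self.phase)

    def samples(self, k):
        # a random k-shot subset of the dense set
        idx = np.random.choice(len(self.xs), size=k, replace=False)
        return self.xs[idx], self.amp * np.sin(self.xs[idx] + self.phase)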
Example #2
def maml(model=None, test_fn=None):
    from playground.maml.maml_torch.tasks import Sine

    model = model or (StandardMLP(1, 1) if G.debug else FunctionalMLP(1, 1))
    meta_optimizer = t.optim.Adam(model.parameters(), lr=G.beta)
    mse = t.nn.MSELoss()

    M.tic('start')
    M.tic('epoch')
    for ep_ind in range(G.n_epochs):
        dt = M.split('epoch', silent=True)
        dt_ = M.toc('start', silent=True)
        print(f"epoch {ep_ind} @ {dt:.4f}sec/ep, {dt_:.1f} sec from start")
        original_ps = OrderedDict(model.named_parameters())

        tasks = [Sine() for _ in range(G.task_batch_n)]

        for task_ind, task in enumerate(tasks):

            if task_ind != 0:
                model.params.update(original_ps)
            if G.test_mode:
                _gradient = original_ps['bias_var'].grad
                if task_ind == 0:
                    assert _gradient is None or _gradient.sum().item(
                    ) == 0, f"{_gradient} is not zero or None, epoch {ep_ind}."
                else:
                    assert _gradient.sum().item(
                    ) != 0, f"{_gradient} should be non-zero"
                assert (
                    original_ps['bias_var'] == model.params['bias_var']
                ).all().item() == 1, 'the two parameters should be the same'

            xs, labels = t.DoubleTensor(task.samples(G.k_shot))

            _silent = task_ind != 0

            for grad_ind in range(G.n_gradient_steps):
                if hasattr(model,
                           "is_autoregressive") and model.is_autoregressive:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0,
                                   labels.view(G.k_shot, 1, 1))
                    ys = ys.squeeze(-1)  # ys:Size[5, batch_n: 1, 1]
                elif hasattr(model, "is_recurrent") and model.is_recurrent:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
                    ys = ys.squeeze(-1)  # ys:Size[5, batch_n: 1, 1]
                else:
                    ys = model(xs.unsqueeze(-1))
                    ht = None

                loss = mse(ys, labels.unsqueeze(-1))
                logger.log_keyvalue(ep_ind,
                                    f"{task_ind}-grad-{grad_ind}-loss",
                                    loss.item(),
                                    silent=_silent)
                if callable(test_fn) and \
                        ep_ind % G.test_interval == 0 and grad_ind in G.test_grad_steps:
                    test_fn(model,
                            task=task,
                            task_id=task_ind,
                            epoch=ep_ind,
                            grad_step=grad_ind,
                            silent=_silent,
                            h0=ht)
                dps = t.autograd.grad(loss,
                                      model.parameters(),
                                      create_graph=True,
                                      retain_graph=True)
                # 1. update the parameters to theta' (the fast weights).
                # 2. the next forward pass then differentiates through these updated weights.
                for (name, p), dp in zip(model.named_parameters(), dps):
                    model.params[name] = p - G.alpha * dp

            grad_ind = G.n_gradient_steps
            # meta gradient
            if hasattr(model, "is_autoregressive") and model.is_autoregressive:
                h0 = model.h0_init()
                ys, ht = model(xs.view(G.k_shot, 1, 1), h0,
                               labels.view(G.k_shot, 1, 1))
                ys = ys.squeeze(-1)  # ys:Size[5, batch_n: 1, 1]
            elif hasattr(model, "is_recurrent") and model.is_recurrent:
                h0 = model.h0_init()
                ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
                ys = ys.squeeze(-1)  # ys:Size[5, batch_n: 1, 1]
            else:
                ys = model(xs.unsqueeze(-1))
                ht = None
            loss = mse(ys, labels.unsqueeze(-1))
            logger.log_keyvalue(ep_ind,
                                f"{task_ind}-grad-{grad_ind}-loss",
                                loss.item(),
                                silent=_silent)

            if callable(test_fn) and \
                    ep_ind % G.test_interval == 0 and grad_ind in G.test_grad_steps:
                test_fn(model,
                        task=task,
                        task_id=task_ind,
                        epoch=ep_ind,
                        grad_step=grad_ind,
                        silent=_silent,
                        h0=ht)
            meta_dps = t.autograd.grad(loss, original_ps.values())
            with t.no_grad():
                for (name, p), meta_dp in zip(original_ps.items(), meta_dps):
                    p.grad = (0 if p.grad is None else p.grad) + meta_dp

        # normalize the gradient.
        with t.no_grad():
            for (name, p) in original_ps.items():
                p.grad /= G.task_batch_n

        model.params.update(original_ps)
        meta_optimizer.step()
        meta_optimizer.zero_grad()

        if G.save_interval and ep_ind % G.save_interval == 0:
            logger.log_module(ep_ind, **{type(model).__name__: model})

    logger.flush()
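
maml() above assumes a model whose weights are exposed through a model.params OrderedDict that the inner loop can overwrite with fast weights (model.params[name] = p - G.alpha * dp) while the registered leaf parameters stay untouched for the Adam meta-step. A minimal sketch of such a functional-style module, under the assumption that this is roughly what FunctionalMLP provides (layer sizes and parameter names here are made up):

import torch as t
from collections import OrderedDict

class TinyFunctionalMLP(t.nn.Module):
    """Hypothetical functional-style MLP: forward() reads weights out of self.params,
    so fast weights (non-leaf tensors) can be swapped in during the inner loop."""

    def __init__(self, in_dim, out_dim, hidden=40):
        super().__init__()
        self.params = OrderedDict(
            w0=t.nn.Parameter(t.randn(hidden, in_dim) * 0.1),
            b0=t.nn.Parameter(t.zeros(hidden)),
            w1=t.nn.Parameter(t.randn(out_dim, hidden) * 0.1),
            b1=t.nn.Parameter(t.zeros(out_dim)))
        for name, p in self.params.items():
            self.register_parameter(name, p)  # expose the leaves to named_parameters()

    def forward(self, x):
        h = t.relu(t.nn.functional.linear(x, self.params['w0'], self.params['b0']))
        return t.nn.functional.linear(h, self.params['w1'], self.params['b1'])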
Example #3
def train():
    from moleskin import moleskin as M

    M.tic('Full Run')
    if G.model == "lenet":
        model = Conv2d()
    elif G.model == 'mlp':
        model = Mlp()
    else:
        raise NotImplementedError('only lenet and mlp are allowed')
    model.train()
    print(model)

    G.log_prefix = f"mnist_{type(model).__name__}"
    logger.configure(log_directory=G.log_dir, prefix=G.log_prefix)
    logger.log_params(G=vars(G), Model=dict(architecture=str(model)))

    from torchvision import datasets, transforms

    trans = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, ), (1.0, ))])
    train_set = datasets.MNIST(root=G.data_dir,
                               train=True,
                               transform=trans,
                               download=True)
    test_set = datasets.MNIST(root=G.data_dir,
                              train=False,
                              transform=trans,
                              download=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=G.batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=G.batch_size,
                                              shuffle=False)

    celoss = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=G.learning_rate, momentum=0.9)
    for epoch in range(G.n_epochs):
        for it, (x, target) in enumerate(train_loader):
            optimizer.zero_grad()
            ys = model(x)
            loss = celoss(ys, target)
            loss.backward()
            optimizer.step()

            if it % G.test_interval == 0:
                with h.Eval(model), torch.no_grad():
                    accuracy = h.Average()
                    for x, label in test_loader:
                        acc = h.cast(
                            h.one_hot_to_int(model(x).detach()) == label,
                            float).sum() / len(x)
                        accuracy.add(acc.detach().numpy())
                logger.log(float(epoch) + it / len(train_loader),
                           accuracy=accuracy.value)

        M.split("epoch")
        # logger.log(epoch, it=it, loss=loss.detach().numpy())
    M.toc('Full Run')
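
Conv2d and Mlp are defined elsewhere in the repo. As a point of reference only, a minimal stand-in Mlp for 28x28 MNIST inputs that would work with the loop above (the architecture is an assumption, not the repo's definition) might look like:

import torch.nn as nn

class Mlp(nn.Module):
    """Hypothetical two-layer MNIST classifier; returns raw logits for CrossEntropyLoss."""

    def __init__(self, hidden=128, n_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                  # [B, 1, 28, 28] -> [B, 784]
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes))

    def forward(self, x):
        return self.net(x)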
Example #4
def reptile(model=None, test_fn=None):
    from playground.maml.maml_torch.tasks import Sine

    model = model or FunctionalMLP(1, 1)

    meta_optimizer = t.optim.Adam(model.parameters(), lr=G.beta)
    mse = t.nn.MSELoss()

    M.tic('epoch')
    for ep_ind in range(G.n_epochs):
        M.split('epoch')
        original_ps = OrderedDict(model.named_parameters())

        tasks = [Sine() for _ in range(G.task_batch_n)]

        for task_ind, task in enumerate(tasks):
            if task_ind != 0:
                model.params.update(original_ps)
            xs, labels = t.DoubleTensor(task.samples(G.k_shot))
            _silent = task_ind != 0
            for grad_ind in range(G.n_gradient_steps):
                if hasattr(model,
                           "is_autoregressive") and model.is_autoregressive:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0,
                                   labels.view(G.k_shot, 1, 1))
                    ys = ys.squeeze(-1)  # ys:Size(5, batch_n:1, 1).
                elif hasattr(model, "is_recurrent") and model.is_recurrent:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
                    ys = ys.squeeze(-1)  # ys:Size(5, batch_n:1, 1).
                else:
                    ys = model(xs.unsqueeze(-1))
                    ht = None
                loss = mse(ys, labels.unsqueeze(-1))
                with t.no_grad():
                    logger.log_keyvalue(ep_ind,
                                        f"{task_ind}-grad-{grad_ind}-loss",
                                        loss.item(),
                                        silent=_silent)
                    if callable(
                            test_fn
                    ) and ep_ind % G.test_interval == 0 and grad_ind in G.test_grad_steps:
                        test_fn(model,
                                task,
                                task_id=task_ind,
                                epoch=ep_ind,
                                grad_step=grad_ind,
                                h0=ht,
                                silent=_silent)
                dps = t.autograd.grad(loss, model.parameters())
                with t.no_grad():
                    for (name, p), dp in zip(model.named_parameters(), dps):
                        model.params[name] = p - G.alpha * dp
                        model.params[name].requires_grad = True

            grad_ind = G.n_gradient_steps
            with t.no_grad():
                # evaluate after the final adaptation step (no gradients needed here)
                if hasattr(model,
                           "is_autoregressive") and model.is_autoregressive:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0,
                                   labels.view(G.k_shot, 1, 1))
                    ys = ys.squeeze(-1)  # ys:Size(5, batch_n:1, 1).
                elif hasattr(model, "is_recurrent") and model.is_recurrent:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
                    ys = ys.squeeze(-1)  # ys:Size(5, batch_n:1, 1).
                else:
                    ys = model(xs.unsqueeze(-1))
                    ht = None
                loss = mse(ys, labels.unsqueeze(-1))
                logger.log_keyvalue(ep_ind,
                                    f"{task_ind}-grad-{grad_ind}-loss",
                                    loss.item(),
                                    silent=_silent)

                if callable(test_fn) and \
                        ep_ind % G.test_interval == 0 and grad_ind in G.test_grad_steps:
                    test_fn(model,
                            task,
                            task_id=task_ind,
                            epoch=ep_ind,
                            grad_step=grad_ind,
                            h0=ht,
                            silent=_silent)

            # Compute REPTILE 1st-order gradient
            with t.no_grad():
                for name, p in original_ps.items():
                    # let's do the division at the end.
                    p.grad = (0 if p.grad is None else p.grad) + (
                        p - model.params[name])  # / G.task_batch_n

        with t.no_grad():
            for name, p in original_ps.items():
                # now apply the deferred division over the task batch.
                p.grad /= G.task_batch_n

        model.params.update(original_ps)
        meta_optimizer.step()
        meta_optimizer.zero_grad()

        if G.save_interval and ep_ind % G.save_interval == 0:
            logger.log_module(ep_ind, **{type(model).__name__: model})

    logger.flush()
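
Stripped of logging and the Adam optimizer, the Reptile outer update that reptile() performs reduces to moving the shared initialization toward the task-adapted weights. A compact sketch of that rule with plain SGD (the function and argument names below are illustrative, not part of the repo):

import torch as t

def reptile_outer_step(theta, adapted_thetas, beta):
    """theta: dict of name -> tensor (the shared initialization);
    adapted_thetas: one such dict per task, taken after the inner-loop SGD steps."""
    with t.no_grad():
        for name, p in theta.items():
            # the Reptile meta-gradient is the average of (theta - theta') over the task batch
            meta_grad = sum(p - th[name] for th in adapted_thetas) / len(adapted_thetas)
            p -= beta * meta_grad  # i.e. move theta toward the adapted weights
    return theta
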
def train():
    from linear_schedule import Linear

    ledger = defaultdict(lambda: MovingAverage(Reporting.reward_average))

    M.config(file=os.path.join(RUN.log_directory, RUN.log_file))
    M.diff()

    with U.make_session(
            RUN.num_cpu), Logger(RUN.log_directory) as logger, contextify(
                gym.make(G.env_name)) as env:
        env = ScaledFloatFrame(wrap_dqn(env))

        if G.seed is not None:
            env.seed(G.seed)
        logger.log_params(G=vars(G), RUN=vars(RUN), Reporting=vars(Reporting))
        inputs = TrainInputs(action_space=env.action_space,
                             observation_space=env.observation_space)
        trainer = QTrainer(inputs=inputs,
                           action_space=env.action_space,
                           observation_space=env.observation_space)
        if G.prioritized_replay:
            replay_buffer = PrioritizedReplayBuffer(size=G.buffer_size,
                                                    alpha=G.alpha)
        else:
            replay_buffer = ReplayBuffer(size=G.buffer_size)

        class schedules:
            # note: it is important to have this start from the beginning.
            eps = Linear(G.n_timesteps * G.exploration_fraction, 1,
                         G.final_eps)
            if G.prioritized_replay:
                beta = Linear(G.n_timesteps - G.learning_start, G.beta_start,
                              G.beta_end)

        U.initialize()
        trainer.update_target()
        x = np.array(env.reset())
        ep_ind = 0
        M.tic('episode')
        for t_step in range(G.n_timesteps):
            # schedules
            eps = 0 if G.param_noise else schedules.eps[t_step]
            if G.prioritized_replay:
                beta = schedules.beta[t_step - G.learning_start]

            x0 = x
            M.tic('sample', silent=True)
            (action, *_), action_q, q = trainer.runner.act([x], eps)
            x, rew, done, info = env.step(action)
            ledger['action_q_value'].append(action_q.max())
            ledger['action_q_value/mean'].append(action_q.mean())
            ledger['action_q_value/var'].append(action_q.var())
            ledger['q_value'].append(q.max())
            ledger['q_value/mean'].append(q.mean())
            ledger['q_value/var'].append(q.var())
            ledger['timing/sample'].append(M.toc('sample', silent=True))
            # note: adding a sample to the buffer is identical between the prioritized and the standard replay strategies.
            replay_buffer.add(s0=x0,
                              action=action,
                              reward=rew,
                              s1=x,
                              done=float(done))

            logger.log(
                t_step, {
                    'q_value': ledger['q_value'].latest,
                    'q_value/mean': ledger['q_value/mean'].latest,
                    'q_value/var': ledger['q_value/var'].latest,
                    'q_value/action': ledger['action_q_value'].latest,
                    'q_value/action/mean':
                    ledger['action_q_value/mean'].latest,
                    'q_value/action/var': ledger['action_q_value/var'].latest
                },
                action=action,
                eps=eps,
                silent=True)

            if G.prioritized_replay:
                logger.log(t_step, beta=beta, silent=True)

            if done:
                ledger['timing/episode'].append(M.split('episode',
                                                        silent=True))
                ep_ind += 1
                x = np.array(env.reset())
                ledger['rewards'].append(info['total_reward'])

                silent = (ep_ind % Reporting.print_interval != 0)
                logger.log(t_step,
                           timestep=t_step,
                           episode=green(ep_ind),
                           total_reward=ledger['rewards'].latest,
                           episode_length=info['timesteps'],
                           silent=silent)
                logger.log(t_step, {
                    'total_reward/mean':
                    yellow(ledger['rewards'].mean, lambda v: f"{v:.1f}"),
                    'total_reward/max':
                    yellow(ledger['rewards'].max, lambda v: f"{v:.1f}"),
                    "time_spent_exploring":
                    default(eps, percent),
                    "timing/episode":
                    green(ledger['timing/episode'].latest, sec),
                    "timing/episode/mean":
                    green(ledger['timing/episode'].mean, sec),
                },
                           silent=silent)
                try:
                    logger.log(t_step, {
                        "timing/sample":
                        default(ledger['timing/sample'].latest, sec),
                        "timing/sample/mean":
                        default(ledger['timing/sample'].mean, sec),
                        "timing/train":
                        default(ledger['timing/train'].latest, sec),
                        "timing/train/mean":
                        green(ledger['timing/train'].mean, sec),
                        "timing/log_histogram":
                        default(ledger['timing/log_histogram'].latest, sec),
                        "timing/log_histogram/mean":
                        default(ledger['timing/log_histogram'].mean, sec)
                    },
                               silent=silent)
                    if G.prioritized_replay:
                        logger.log(t_step, {
                            "timing/update_priorities":
                            default(ledger['timing/update_priorities'].latest,
                                    sec),
                            "timing/update_priorities/mean":
                            default(ledger['timing/update_priorities'].mean,
                                    sec)
                        },
                                   silent=silent)
                except Exception:
                    pass
                if G.prioritized_replay:
                    logger.log(
                        t_step,
                        {"replay_beta": default(beta, lambda v: f"{v:.2f}")},
                        silent=silent)

            # note: learn here.
            if t_step >= G.learning_start and t_step % G.learn_interval == 0:
                if G.prioritized_replay:
                    experiences, weights, indices = replay_buffer.sample(
                        G.replay_batch_size, beta)
                    logger.log_histogram(t_step, weights=weights)
                else:
                    experiences, weights = replay_buffer.sample(
                        G.replay_batch_size), None
                M.tic('train', silent=True)
                x0s, actions, rewards, x1s, dones = zip(*experiences)
                td_error_val, loss_val = trainer.train(s0s=x0s,
                                                       actions=actions,
                                                       rewards=rewards,
                                                       s1s=x1s,
                                                       dones=dones,
                                                       sample_weights=weights)
                ledger['timing/train'].append(M.toc('train', silent=True))
                M.tic('log_histogram', silent=True)
                logger.log_histogram(t_step, td_error=td_error_val)
                ledger['timing/log_histogram'].append(
                    M.toc('log_histogram', silent=True))
                if G.prioritized_replay:
                    M.tic('update_priorities', silent=True)
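                    # note: `eps` here is the exploration epsilon from the schedule above;
                    # a small fixed constant (e.g. 1e-6) is the more common choice for priority smoothing.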
                    new_priorities = np.abs(td_error_val) + eps
                    replay_buffer.update_priorities(indices, new_priorities)
                    ledger['timing/update_priorities'].append(
                        M.toc('update_priorities', silent=True))

            if t_step % G.target_network_update_interval == 0:
                trainer.update_target()

            if t_step % Reporting.checkpoint_interval == 0:
                U.save_state(os.path.join(RUN.log_directory, RUN.checkpoint))
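
The Linear class imported from linear_schedule is only used through indexing (schedules.eps[t_step], schedules.beta[...]). A minimal sketch consistent with that call pattern, assuming the constructor arguments are (horizon, start value, end value) as the calls above suggest:

class Linear:
    """Hypothetical linear interpolation schedule: ramps from `start` to `end`
    over `horizon` steps, then stays clamped at `end`."""

    def __init__(self, horizon, start, end):
        self.horizon, self.start, self.end = horizon, start, end

    def __getitem__(self, t_step):
        frac = min(max(t_step, 0) / self.horizon, 1.0)
        return self.start + frac * (self.end - self.start)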