def maml_supervised(model, Task, n_epochs, task_batch_n, npts, k_shot, n_gradient_steps, test_fn=None, **_):
    """
    Supervised MAML. Task needs to implement `.proper` and `.samples` methods, where `proper` is the
    full, dense set of data from which samples are drawn.

    :param model: the model to be meta-trained
    :param Task: task constructor; instances need `.proper()` and `.samples(k)`
    :param n_epochs: number of meta-training epochs
    :param task_batch_n: number of tasks per meta update
    :param npts: the total number of samples for the sinusoidal task
    :param k_shot: number of samples drawn for each inner gradient step
    :param n_gradient_steps: number of inner-loop gradient steps
    :param test_fn: optional callback used for evaluation during training
    :param _: ignored extra keyword arguments
    :return:
    """
    import playground.maml.maml_torch.archive.paper_metrics as metrics

    device = t.device('cuda' if t.cuda.is_available() else 'cpu')
    model.to(device)

    ps = list(model.parameters())

    # for ep_ind in trange(n_epochs, desc='Epochs', ncols=50, leave=False):
    M.tic('epoch')
    for ep_ind in trange(n_epochs):
        M.split('epoch', silent=True)
        meta_grads = defaultdict(lambda: 0)
        theta = copy.deepcopy(model.state_dict())
        tasks = [Task(npts=npts) for _ in range(task_batch_n)]
        for task_ind, task in enumerate(tasks):  # sample a new problem
            # todo: this part is highly-parallelizable
            if task_ind != 0:
                model.load_state_dict(theta)
            task_grads = defaultdict(deque)
            proper = t.tensor(task.proper()).to(device)
            samples = t.tensor(task.samples(k_shot)).to(device)
            for grad_ind in range(n_gradient_steps):
                # done: ready to be repackaged
                loss, _ = metrics.comp_loss(*samples, model)
                model.zero_grad()
                # back-propagate once, retain graph.
                loss.backward(t.ones(1).to(device), retain_graph=True)

                # done: need to use gradient descent, plus creating a meta graph.
                U, grad_outputs = [], []
                for p in model.parameters():
                    U.append(p - G.alpha * p.grad)  # meta update
                    grad_outputs.append(t.ones(1).to(device).expand_as(p))

                # t.autograd.grad returns the sum of gradients between all U and all grad_outputs.
                # note: this is the row sum of \partial theta_prime \partial theta, which is a matrix.
                dU = t.autograd.grad(outputs=U, grad_outputs=grad_outputs, inputs=model.parameters())

                # Now update the params.
                for p, updated_p, du in zip(ps, U, dU):
                    p.data = updated_p.data  # these are leaf nodes, so we can directly manipulate the data attribute.
                    task_grads[p].append(du)

                # note: evaluate the (grad_ind + 1)-grad loss
                if G.test_interval and ep_ind % G.test_interval == 0 and grad_ind + 1 in G.test_grad_steps:
                    with t.no_grad():
                        if test_fn is not None:
                            test_fn(grad_ind + 1, model, task=task, epoch=ep_ind)
                        _loss, _ = metrics.comp_loss(*proper, model)
                        logger.log_keyvalue(ep_ind, key=f"{grad_ind + 1:d}-grad-loss-{task_ind:02d}",
                                            value=_loss.item(), silent=True)

            # compute Loss_theta_prime
            samples = t.tensor(task.samples(k_shot)).to(device)  # sample from this problem
            loss, _ = metrics.comp_loss(*samples, model)
            model.zero_grad()
            loss.backward()

            for i, grad in enumerate(model.gradients()):  # Now accumulate the gradient
                p = ps[i]
                task_grads[p].append(grad)
                meta_grads[p] += t.prod(
                    t.cat(list(map(lambda d: d.unsqueeze(dim=-1), task_grads[p])), dim=-1), dim=-1)

        # theta_prime = copy.deepcopy(model.state_dict())
        model.load_state_dict(theta)
        for p in ps:
            p.grad = t.tensor((meta_grads[p] / task_batch_n).detach()).to(device)
        model.meta_step(lr=G.beta)

        with t.no_grad():
            if test_fn is not None:
                test_fn(0, model, epoch=ep_ind)
            _loss, _ = metrics.comp_loss(*proper, model)  # it is very easy to use the wrong loss here.
            logger.log_keyvalue(ep_ind, '0-grad-loss', _loss.item(), silent=True)

        # save model weight
        if G.save_interval and ep_ind % G.save_interval == 0:
            logger.log_module(ep_ind, **{type(model).__name__: model})
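# Hedged sketch (not part of the original module): a minimal Task satisfying the interface
# `maml_supervised` assumes above -- `.proper()` returns the full, dense dataset and
# `.samples(k)` returns a k-shot subset drawn from it. The `SineExample` name and its
# amplitude/phase ranges are illustrative assumptions, not the repo's actual Sine task.
import numpy as np


class SineExample:
    def __init__(self, npts=100):
        self.amp = np.random.uniform(0.1, 5.0)
        self.phase = np.random.uniform(0, np.pi)
        self.xs = np.linspace(-5, 5, npts, dtype=np.float32)

    def proper(self):
        # full, dense set: stacked (xs, ys), each of shape (npts, 1)
        ys = self.amp * np.sin(self.xs + self.phase)
        return np.stack([self.xs[:, None], ys[:, None].astype(np.float32)])

    def samples(self, k_shot):
        # k-shot subset drawn from the dense set
        xs, ys = self.proper()
        idx = np.random.choice(len(xs), k_shot, replace=False)
        return np.stack([xs[idx], ys[idx]])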
def maml(model=None, test_fn=None):
    from playground.maml.maml_torch.tasks import Sine

    model = model or (StandardMLP(1, 1) if G.debug else FunctionalMLP(1, 1))
    meta_optimizer = t.optim.Adam(model.parameters(), lr=G.beta)
    mse = t.nn.MSELoss()

    M.tic('start')
    M.tic('epoch')
    for ep_ind in range(G.n_epochs):
        dt = M.split('epoch', silent=True)
        dt_ = M.toc('start', silent=True)
        print(f"epoch {ep_ind} @ {dt:.4f}sec/ep, {dt_:.1f} sec from start")
        original_ps = OrderedDict(model.named_parameters())

        tasks = [Sine() for _ in range(G.task_batch_n)]
        for task_ind, task in enumerate(tasks):
            if task_ind != 0:
                model.params.update(original_ps)

            if G.test_mode:
                _gradient = original_ps['bias_var'].grad
                if task_ind == 0:
                    assert _gradient is None or _gradient.sum().item() == 0, \
                        f"{_gradient} is not zero or None, epoch {ep_ind}."
                else:
                    assert _gradient.sum().item() != 0, f"{_gradient} should be non-zero"
                assert (original_ps['bias_var'] == model.params['bias_var']).all().item() == 1, \
                    'the two parameters should be the same'

            xs, labels = t.DoubleTensor(task.samples(G.k_shot))
            _silent = task_ind != 0

            for grad_ind in range(G.n_gradient_steps):
                if hasattr(model, "is_autoregressive") and model.is_autoregressive:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0, labels.view(G.k_shot, 1, 1))
                    ys = ys.squeeze(-1)  # ys: Size[5, batch_n: 1, 1]
                elif hasattr(model, "is_recurrent") and model.is_recurrent:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
                    ys = ys.squeeze(-1)  # ys: Size[5, batch_n: 1, 1]
                else:
                    ys = model(xs.unsqueeze(-1))
                    ht = None

                loss = mse(ys, labels.unsqueeze(-1))
                logger.log_keyvalue(ep_ind, f"{task_ind}-grad-{grad_ind}-loss", loss.item(), silent=_silent)
                if callable(test_fn) and ep_ind % G.test_interval == 0 and grad_ind in G.test_grad_steps:
                    test_fn(model, task=task, task_id=task_ind, epoch=ep_ind, grad_step=grad_ind,
                            silent=_silent, h0=ht)

                dps = t.autograd.grad(loss, model.parameters(), create_graph=True, retain_graph=True)
                # 1. update parameters, use updated theta'.
                # 2. run forward, get direct gradient to update the network
                for (name, p), dp in zip(model.named_parameters(), dps):
                    model.params[name] = p - G.alpha * dp

            grad_ind = G.n_gradient_steps

            # meta gradient
            if hasattr(model, "is_autoregressive") and model.is_autoregressive:
                h0 = model.h0_init()
                ys, ht = model(xs.view(G.k_shot, 1, 1), h0, labels.view(G.k_shot, 1, 1))
                ys = ys.squeeze(-1)  # ys: Size[5, batch_n: 1, 1]
            elif hasattr(model, "is_recurrent") and model.is_recurrent:
                h0 = model.h0_init()
                ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
                ys = ys.squeeze(-1)  # ys: Size[5, batch_n: 1, 1]
            else:
                ys = model(xs.unsqueeze(-1))
                ht = None

            loss = mse(ys, labels.unsqueeze(-1))
            logger.log_keyvalue(ep_ind, f"{task_ind}-grad-{grad_ind}-loss", loss.item(), silent=_silent)
            if callable(test_fn) and ep_ind % G.test_interval == 0 and grad_ind in G.test_grad_steps:
                test_fn(model, task=task, task_id=task_ind, epoch=ep_ind, grad_step=grad_ind,
                        silent=_silent, h0=ht)

            meta_dps = t.autograd.grad(loss, original_ps.values())
            with t.no_grad():
                for (name, p), meta_dp in zip(original_ps.items(), meta_dps):
                    p.grad = (0 if p.grad is None else p.grad) + meta_dp

        # normalize the gradient.
        with t.no_grad():
            for name, p in original_ps.items():
                p.grad /= G.task_batch_n

        model.params.update(original_ps)
        meta_optimizer.step()
        meta_optimizer.zero_grad()

        if G.save_interval and ep_ind % G.save_interval == 0:
            logger.log_module(ep_ind, **{type(model).__name__: model})

        logger.flush()
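# Hedged usage sketch (not part of the original module): running MAML on the Sine task family.
# It assumes the module-level G config is already populated (alpha, beta, n_epochs, task_batch_n,
# k_shot, n_gradient_steps, test_interval, test_grad_steps, save_interval) and that `logger` is
# configured. `eval_on_task` is an illustrative helper, not an API of this repo, and only handles
# the plain feed-forward (non-recurrent) case.
def eval_on_task(model, task=None, task_id=None, epoch=None, grad_step=None, silent=True, h0=None, **_):
    # evaluate the adapted model on a fresh k-shot draw from the same task
    xs, labels = t.DoubleTensor(task.samples(G.k_shot))
    with t.no_grad():
        loss = t.nn.functional.mse_loss(model(xs.unsqueeze(-1)), labels.unsqueeze(-1))
    logger.log_keyvalue(epoch, f"eval-{task_id}-grad-{grad_step}-loss", loss.item(), silent=silent)


if __name__ == "__main__":
    maml(model=FunctionalMLP(1, 1), test_fn=eval_on_task)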
def train():
    from moleskin import moleskin as M

    M.tic('Full Run')
    if G.model == "lenet":
        model = Conv2d()
    elif G.model == 'mlp':
        model = Mlp()
    else:
        raise NotImplementedError('only lenet and mlp are allowed')
    model.train()
    print(model)

    G.log_prefix = f"mnist_{type(model).__name__}"
    logger.configure(log_directory=G.log_dir, prefix=G.log_prefix)
    logger.log_params(G=vars(G), Model=dict(architecture=str(model)))

    from torchvision import datasets, transforms
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    train_set = datasets.MNIST(root=G.data_dir, train=True, transform=trans, download=True)
    test_set = datasets.MNIST(root=G.data_dir, train=False, transform=trans, download=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=G.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=G.batch_size, shuffle=False)

    celoss = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=G.learning_rate, momentum=0.9)  # plain SGD with momentum

    for epoch in range(G.n_epochs):
        for it, (x, target) in enumerate(train_loader):
            optimizer.zero_grad()
            ys = model(x)
            loss = celoss(ys, target)
            loss.backward()
            optimizer.step()

            if it % G.test_interval == 0:
                with h.Eval(model), torch.no_grad():
                    accuracy = h.Average()
                    for x, label in test_loader:
                        acc = h.cast(h.one_hot_to_int(model(x).detach()) == label, float).sum() / len(x)
                        accuracy.add(acc.detach().numpy())
                    logger.log(float(epoch) + it / len(train_loader), accuracy=accuracy.value)

        M.split("epoch")
        # logger.log(epoch, it=it, loss=loss.detach().numpy())

    M.toc('Full Run')
def reptile(model=None, test_fn=None):
    from playground.maml.maml_torch.tasks import Sine

    model = model or FunctionalMLP(1, 1)
    meta_optimizer = t.optim.Adam(model.parameters(), lr=G.beta)
    mse = t.nn.MSELoss()

    M.tic('epoch')
    for ep_ind in range(G.n_epochs):
        M.split('epoch')
        original_ps = OrderedDict(model.named_parameters())

        tasks = [Sine() for _ in range(G.task_batch_n)]
        for task_ind, task in enumerate(tasks):
            if task_ind != 0:
                model.params.update(original_ps)

            xs, labels = t.DoubleTensor(task.samples(G.k_shot))
            _silent = task_ind != 0

            for grad_ind in range(G.n_gradient_steps):
                if hasattr(model, "is_autoregressive") and model.is_autoregressive:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0, labels.view(G.k_shot, 1, 1))
                    ys = ys.squeeze(-1)  # ys: Size(5, batch_n: 1, 1).
                elif hasattr(model, "is_recurrent") and model.is_recurrent:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
                    ys = ys.squeeze(-1)  # ys: Size(5, batch_n: 1, 1).
                else:
                    ys = model(xs.unsqueeze(-1))
                    ht = None

                loss = mse(ys, labels.unsqueeze(-1))
                with t.no_grad():
                    logger.log_keyvalue(ep_ind, f"{task_ind}-grad-{grad_ind}-loss", loss.item(), silent=_silent)
                    if callable(test_fn) and ep_ind % G.test_interval == 0 and grad_ind in G.test_grad_steps:
                        test_fn(model, task, task_id=task_ind, epoch=ep_ind, grad_step=grad_ind,
                                h0=ht, silent=_silent)

                dps = t.autograd.grad(loss, model.parameters())
                with t.no_grad():
                    for (name, p), dp in zip(model.named_parameters(), dps):
                        model.params[name] = p - G.alpha * dp
                        model.params[name].requires_grad = True

            grad_ind = G.n_gradient_steps
            with t.no_grad():  # domain adaptation
                if hasattr(model, "is_autoregressive") and model.is_autoregressive:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0, labels.view(G.k_shot, 1, 1))
                    ys = ys.squeeze(-1)  # ys: Size(5, batch_n: 1, 1).
                elif hasattr(model, "is_recurrent") and model.is_recurrent:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
                    ys = ys.squeeze(-1)  # ys: Size(5, batch_n: 1, 1).
                else:
                    ys = model(xs.unsqueeze(-1))
                    ht = None

                loss = mse(ys, labels.unsqueeze(-1))
                logger.log_keyvalue(ep_ind, f"{task_ind}-grad-{grad_ind}-loss", loss.item(), silent=_silent)
                if callable(test_fn) and ep_ind % G.test_interval == 0 and grad_ind in G.test_grad_steps:
                    test_fn(model, task, task_id=task_ind, epoch=ep_ind, grad_step=grad_ind,
                            h0=ht, silent=_silent)

            # Compute REPTILE 1st-order gradient
            with t.no_grad():
                for name, p in original_ps.items():
                    # let's do the division at the end.
                    p.grad = (0 if p.grad is None else p.grad) + (p - model.params[name])  # / G.task_batch_n

        with t.no_grad():
            for name, p in original_ps.items():
                # let's do the division at the end.
                p.grad /= G.task_batch_n

        model.params.update(original_ps)
        meta_optimizer.step()
        meta_optimizer.zero_grad()

        if G.save_interval and ep_ind % G.save_interval == 0:
            logger.log_module(ep_ind, **{type(model).__name__: model})

        logger.flush()
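# Hedged illustration (not part of the original module): the Reptile meta-update computed above,
# stripped to its core. After a few inner SGD steps on each task, the per-parameter meta-gradient
# is simply (theta - theta_prime), averaged over the task batch, which the meta optimizer then
# applies. `toy_model`, `inner_lr`, and `inner_steps` are illustrative names; the sketch assumes a
# double-precision feed-forward model and a task with `.samples(k)`, mirroring the loop above.
def reptile_meta_grad_sketch(toy_model, tasks, inner_lr=0.01, inner_steps=5, k_shot=10):
    theta = {k: v.clone() for k, v in toy_model.state_dict().items()}
    meta_grad = {k: t.zeros_like(v) for k, v in theta.items()}
    mse = t.nn.MSELoss()
    for task in tasks:
        toy_model.load_state_dict(theta)  # start every task from the meta parameters theta
        opt = t.optim.SGD(toy_model.parameters(), lr=inner_lr)
        xs, labels = t.DoubleTensor(task.samples(k_shot))
        for _ in range(inner_steps):  # plain inner-loop SGD on this task
            opt.zero_grad()
            loss = mse(toy_model(xs.unsqueeze(-1)), labels.unsqueeze(-1))
            loss.backward()
            opt.step()
        for k, v in toy_model.state_dict().items():  # accumulate theta - theta_prime
            meta_grad[k] += theta[k] - v
    return {k: g / len(tasks) for k, g in meta_grad.items()}  # divide once at the end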
def train():
    from linear_schedule import Linear

    ledger = defaultdict(lambda: MovingAverage(Reporting.reward_average))

    M.config(file=os.path.join(RUN.log_directory, RUN.log_file))
    M.diff()

    with U.make_session(RUN.num_cpu), Logger(RUN.log_directory) as logger, \
            contextify(gym.make(G.env_name)) as env:
        env = ScaledFloatFrame(wrap_dqn(env))
        if G.seed is not None:
            env.seed(G.seed)
        logger.log_params(G=vars(G), RUN=vars(RUN), Reporting=vars(Reporting))
        inputs = TrainInputs(action_space=env.action_space, observation_space=env.observation_space)
        trainer = QTrainer(inputs=inputs, action_space=env.action_space, observation_space=env.observation_space)
        if G.prioritized_replay:
            replay_buffer = PrioritizedReplayBuffer(size=G.buffer_size, alpha=G.alpha)
        else:
            replay_buffer = ReplayBuffer(size=G.buffer_size)

        class schedules:
            # note: it is important to have this start from the beginning.
            eps = Linear(G.n_timesteps * G.exploration_fraction, 1, G.final_eps)
            if G.prioritized_replay:
                beta = Linear(G.n_timesteps - G.learning_start, G.beta_start, G.beta_end)

        U.initialize()
        trainer.update_target()
        x = np.array(env.reset())
        ep_ind = 0
        M.tic('episode')
        for t_step in range(G.n_timesteps):
            # schedules
            eps = 0 if G.param_noise else schedules.eps[t_step]
            if G.prioritized_replay:
                beta = schedules.beta[t_step - G.learning_start]

            x0 = x
            M.tic('sample', silent=True)
            (action, *_), action_q, q = trainer.runner.act([x], eps)
            x, rew, done, info = env.step(action)
            ledger['action_q_value'].append(action_q.max())
            ledger['action_q_value/mean'].append(action_q.mean())
            ledger['action_q_value/var'].append(action_q.var())
            ledger['q_value'].append(q.max())
            ledger['q_value/mean'].append(q.mean())
            ledger['q_value/var'].append(q.var())
            ledger['timing/sample'].append(M.toc('sample', silent=True))
            # note: adding a sample to the buffer is identical between the prioritized and the standard replay strategies.
            replay_buffer.add(s0=x0, action=action, reward=rew, s1=x, done=float(done))

            logger.log(t_step, {
                'q_value': ledger['q_value'].latest,
                'q_value/mean': ledger['q_value/mean'].latest,
                'q_value/var': ledger['q_value/var'].latest,
                'q_value/action': ledger['action_q_value'].latest,
                'q_value/action/mean': ledger['action_q_value/mean'].latest,
                'q_value/action/var': ledger['action_q_value/var'].latest
            }, action=action, eps=eps, silent=True)
            if G.prioritized_replay:
                logger.log(t_step, beta=beta, silent=True)

            if done:
                ledger['timing/episode'].append(M.split('episode', silent=True))
                ep_ind += 1
                x = np.array(env.reset())
                ledger['rewards'].append(info['total_reward'])

                silent = (ep_ind % Reporting.print_interval != 0)
                logger.log(t_step, timestep=t_step, episode=green(ep_ind), total_reward=ledger['rewards'].latest,
                           episode_length=info['timesteps'], silent=silent)
                logger.log(t_step, {
                    'total_reward/mean': yellow(ledger['rewards'].mean, lambda v: f"{v:.1f}"),
                    'total_reward/max': yellow(ledger['rewards'].max, lambda v: f"{v:.1f}"),
                    "time_spent_exploring": default(eps, percent),
                    "timing/episode": green(ledger['timing/episode'].latest, sec),
                    "timing/episode/mean": green(ledger['timing/episode'].mean, sec),
                }, silent=silent)
                try:
                    logger.log(t_step, {
                        "timing/sample": default(ledger['timing/sample'].latest, sec),
                        "timing/sample/mean": default(ledger['timing/sample'].mean, sec),
                        "timing/train": default(ledger['timing/train'].latest, sec),
                        "timing/train/mean": green(ledger['timing/train'].mean, sec),
                        "timing/log_histogram": default(ledger['timing/log_histogram'].latest, sec),
                        "timing/log_histogram/mean": default(ledger['timing/log_histogram'].mean, sec)
                    }, silent=silent)
                    if G.prioritized_replay:
                        logger.log(t_step, {
                            "timing/update_priorities": default(ledger['timing/update_priorities'].latest, sec),
                            "timing/update_priorities/mean": default(ledger['timing/update_priorities'].mean, sec)
                        }, silent=silent)
                except Exception:
                    pass
                if G.prioritized_replay:
                    logger.log(t_step, {"replay_beta": default(beta, lambda v: f"{v:.2f}")}, silent=silent)

            # note: learn here.
            if t_step >= G.learning_start and t_step % G.learn_interval == 0:
                if G.prioritized_replay:
                    experiences, weights, indices = replay_buffer.sample(G.replay_batch_size, beta)
                    logger.log_histogram(t_step, weights=weights)
                else:
                    experiences, weights = replay_buffer.sample(G.replay_batch_size), None

                M.tic('train', silent=True)
                x0s, actions, rewards, x1s, dones = zip(*experiences)
                td_error_val, loss_val = trainer.train(s0s=x0s, actions=actions, rewards=rewards, s1s=x1s,
                                                       dones=dones, sample_weights=weights)
                ledger['timing/train'].append(M.toc('train', silent=True))

                M.tic('log_histogram', silent=True)
                logger.log_histogram(t_step, td_error=td_error_val)
                ledger['timing/log_histogram'].append(M.toc('log_histogram', silent=True))

                if G.prioritized_replay:
                    M.tic('update_priorities', silent=True)
                    new_priorities = np.abs(td_error_val) + eps
                    replay_buffer.update_priorities(indices, new_priorities)
                    ledger['timing/update_priorities'].append(M.toc('update_priorities', silent=True))

            if t_step % G.target_network_update_interval == 0:
                trainer.update_target()

            if t_step % Reporting.checkpoint_interval == 0:
                U.save_state(os.path.join(RUN.log_directory, RUN.checkpoint))