def sgd_baseline(lr=0.001):
    from playground.maml.maml_torch.tasks import Sine

    task = Sine()
    model = StandardMLP(1, 1) if G.debug else FunctionalMLP(1, 1)
    adam = t.optim.Adam([p for p in model.parameters()], lr=lr)
    mse = t.nn.MSELoss()
    for ep_ind in range(1000):
        xs, labels = h.const(task.proper())
        ys = model(xs.unsqueeze(-1))
        loss = mse(ys, labels.unsqueeze(-1))
        logger.log(ep_ind, loss=loss.item(), silent=ep_ind % 50)
        adam.zero_grad()
        loss.backward()
        adam.step()
    logger.flush()
def run_e_maml(_G=None):
    import baselines.common.tf_util as U

    if _G is not None:
        config.G.update(_G)

    for k, v in [*vars(config.RUN).items(), *vars(config.G).items(),
                 *vars(config.Reporting).items(), *vars(config.DEBUG).items()]:
        comet_logger.log_parameter(k, v)

    # todo: let's take control of the log directory away from the train script. It should all be set from outside.
    logger.configure(log_directory=config.RUN.log_dir, prefix=config.RUN.log_prefix)
    logger.log_params(RUN=vars(config.RUN), G=vars(config.G),
                      Reporting=vars(config.Reporting), DEBUG=vars(config.DEBUG))
    logger.log_file(__file__)

    tasks = MetaRLTasks(env_name=config.G.env_name,
                        batch_size=config.G.n_parallel_envs,
                        start_seed=config.G.start_seed,
                        log_directory=(config.RUN.log_directory + "/{seed}") if config.G.render else None,
                        max_steps=config.G.env_max_timesteps)

    # sess_config = tf.ConfigProto(log_device_placement=config.Reporting.log_device_placement)
    # with tf.Session(config=sess_config), tf.device('/gpu:0'), tasks:
    graph = tf.Graph()
    with graph.as_default(), U.make_session(num_cpu=config.G.n_cpu), tasks:
        maml = E_MAML(ob_space=tasks.envs.observation_space, act_space=tasks.envs.action_space)
        comet_logger.set_model_graph(tf.get_default_graph())

        # writer = tf.summary.FileWriter(logdir='/opt/project/debug-graph', graph=graph)
        # writer.flush()
        # exit()

        trainer = Trainer()
        U.initialize()
        trainer.train(tasks=tasks, maml=maml)
        logger.flush()

    tf.reset_default_graph()
def thunk(*args, **kwargs):
    import traceback
    from ml_logger import logger

    assert not (args and ARGS), \
        f"cannot use positional arguments at both thunk creation and run time.\n" \
        f"_args: {args}\n" \
        f"ARGS: {ARGS}\n"

    logger.configure(root_dir=RUN.server, prefix=PREFIX, register_experiment=False, max_workers=10)
    logger.log_params(host=dict(hostname=logger.hostname),
                      run=dict(status="running", startTime=logger.now(), job_id=logger.job_id))

    import time
    try:
        _KWARGS = {**KWARGS}
        _KWARGS.update(**kwargs)

        results = fn(*(args or ARGS), **_KWARGS)

        logger.log_line("========== execution is complete ==========")
        logger.log_params(run=dict(status="completed", completeTime=logger.now()))
        logger.flush()
        time.sleep(3)
    except Exception as e:
        tb = traceback.format_exc()
        with logger.SyncContext():  # Make sure uploads finish before termination.
            logger.print(tb, color="red")
            logger.log_text(tb, filename="traceback.err")
            logger.log_params(run=dict(status="error", exitTime=logger.now()))
            logger.flush()
            time.sleep(3)
        raise e

    return results
def _(*args, **kwargs):
    import traceback
    from ml_logger import logger

    assert not (args and ARGS), f"cannot use positional arguments at both thunk creation and " \
                                f"run time.\n_args: {args}\nARGS: {ARGS}"

    logger.configure(log_directory=RUN.server, prefix=PREFIX, register_experiment=False, max_workers=10)
    logger.log_params(host=dict(hostname=logger.hostname),
                      run=dict(status="running", startTime=logger.now()))

    try:
        _KWARGS = KWARGS.copy()
        _KWARGS.update(kwargs)

        fn(*(args or ARGS), **_KWARGS)

        logger.log_line("========== execution is complete ==========")
        logger.log_params(run=dict(status="completed", completeTime=logger.now()))
    except Exception as e:
        import time
        time.sleep(1)
        tb = traceback.format_exc()
        with logger.SyncContext():  # Make sure uploads finish before termination.
            logger.log_text(tb, filename="traceback.err")
            logger.log_params(run=dict(status="error", exitTime=logger.now()))
            logger.log_line(tb)
            logger.flush()
        time.sleep(30)
        raise e

    import time
    time.sleep(30)
def train_maml(*, n_tasks: int, tasks: MetaRLTasks, maml: E_MAML):
    if not G.inner_alg.startswith("BC"):
        path_gen = path_gen_fn(env=tasks.envs, policy=maml.runner.policy, start_reset=G.reset_on_start)
        next(path_gen)
        meta_path_gen = path_gen_fn(env=tasks.envs, policy=maml.meta_runner.policy, start_reset=G.reset_on_start)
        next(meta_path_gen)

    if G.load_from_checkpoint:
        # todo: add variable to checkpoint
        # todo: set the epoch_ind starting point here.
        logger.load_variables(G.load_from_checkpoint)

    if G.meta_sgd:
        assert maml.alpha is not None, "Coding mistake if meta_sgd is truthy but maml.alpha is None."

    max_episode_length = tasks.spec.max_episode_steps
    sess = tf.get_default_session()

    epoch_ind, prefix = G.epoch_init - 1, ""
    while epoch_ind < G.epoch_init + G.n_epochs:
        logger.flush()
        logger.split()

        is_bc_test = (prefix != "test/" and G.eval_interval and epoch_ind % G.eval_interval == 0)
        prefix = "test/" if is_bc_test else ""
        epoch_ind += 0 if is_bc_test else 1

        if G.meta_sgd:
            alpha_lr = sess.run(maml.alpha)  # only used in the runner.
            logger.log(metrics={f"alpha_{i}/{stem(t.name, 2)}": a
                                for i, a_ in enumerate(alpha_lr)
                                for t, a in zip(maml.runner.trainables, a_)}, silent=True)
        else:
            alpha_lr = G.alpha.send(epoch_ind) if isinstance(G.alpha, Schedule) else np.array(G.alpha)
            logger.log(alpha=metrify(alpha_lr), epoch=epoch_ind, silent=True)

        beta_lr = G.beta.send(epoch_ind) if isinstance(G.beta, Schedule) else np.array(G.beta)
        clip_range = G.clip_range.send(epoch_ind) if isinstance(G.clip_range, Schedule) else np.array(G.clip_range)
        logger.log(beta=metrify(beta_lr), clip_range=metrify(clip_range), epoch=epoch_ind, silent=True)

        batch_timesteps = G.batch_timesteps.send(epoch_ind) \
            if isinstance(G.batch_timesteps, Schedule) else G.batch_timesteps

        # Compute updates for each task in the batch
        # 0. save value of variables
        # 1. sample
        # 2. gradient descent
        # 3. repeat step 1., 2. until all gradient steps are exhausted.
        batch_data = defaultdict(list)
        maml.save_weight_cache()
        load_ops = [] if DEBUG.no_weight_reset else [maml.cache.load]

        if G.checkpoint_interval and epoch_ind % G.checkpoint_interval == 0 \
                and not is_bc_test and epoch_ind >= G.start_checkpoint_after_epoch:
            cp_path = f"checkpoints/variables_{epoch_ind:04d}.pkl"
            logger.log_line(f'saving checkpoint {cp_path}')
            # note: of course I don't know whether these are all of the trainables at the moment.
            logger.save_variables(tf.trainable_variables(), path=cp_path)

        feed_dict = {}
        for task_ind in range(n_tasks if is_bc_test else G.n_tasks):
            graph_branch = maml.graphs[0] if G.n_graphs == 1 else maml.graphs[task_ind]
            if G.n_graphs == 1:
                gradient_sum_op = maml.gradient_sum.set_op if task_ind == 0 else maml.gradient_sum.add_op

            print(f"task_ind {task_ind}...")
            if not DEBUG.no_task_resample:
                if not is_bc_test:
                    print(f'L250: sampling task')
                    tasks.sample()
                elif task_ind < n_tasks:
                    task_spec = dict(index=task_ind % n_tasks)
                    print(f'L254: sampling task {task_spec}')
                    tasks.sample(**task_spec)
                else:
                    raise RuntimeError('should never hit here.')

            for k in range(G.n_grad_steps + 1):  # 0 - 10 <== last one being the maml policy.
                _is_new = False
                # for imitation inner loss, we still sample trajectory for evaluation purposes, but
                # replace it with the demonstration data for learning
                if k < G.n_grad_steps:
                    if G.inner_alg.startswith("BC"):
                        p = p if G.single_sampling and k > 0 else \
                            bc.sample_demonstration_data(tasks.task_spec, key=("eval" if is_bc_test else None))
                    else:
                        p, _is_new = path_gen.send(batch_timesteps), True
                elif k == G.n_grad_steps:
                    if G.meta_alg.startswith("BC"):
                        # note: use meta bc samples.
                        p = bc.sample_demonstration_data(tasks.task_spec, key="meta")
                    else:
                        p, _is_new = meta_path_gen.send(batch_timesteps), True
                else:
                    raise Exception('Implementation error. Should never reach this line.')

                if k in G.eval_grad_steps:
                    _ = path_gen if k < G.n_grad_steps else meta_path_gen
                    p_eval = p if _is_new else _.send(G.eval_timesteps)
                    # reporting on new trajectory samples
                    avg_r = p_eval['ep_info']['reward'] if G.normalize_env else np.mean(p_eval['rewards'])
                    episode_r = avg_r * max_episode_length  # default horizon for HalfCheetah

                    if episode_r < G.term_reward_threshold:  # todo: make this batch-based instead of on single episode
                        logger.log_line("episode reward is too low: ", episode_r, "terminating training.", flush=True)
                        raise RuntimeError('AVERAGE REWARD TOO LOW. Terminating the experiment.')

                    batch_data[prefix + f"grad_{k}_step_reward"].append(avg_r if Reporting.report_mean else episode_r)

                if k in G.eval_grad_steps:
                    logger.log_key_value(prefix + f"task_{task_ind}_grad_{k}_reward", episode_r, silent=True)

                _p = {k: v for k, v in p.items() if k != "ep_info"}

                if k < G.n_grad_steps:
                    # note: under meta-SGD mode, the runner needs the k^th learning rate.
                    _lr = alpha_lr[k] if G.meta_sgd else alpha_lr
                    # clip_range is not used in BC mode, but is still passed in.
                    runner_feed_dict = \
                        path_to_feed_dict(inputs=maml.runner.inputs, paths=_p, lr=_lr, baseline=G.baseline,
                                          gamma=G.gamma, use_gae=G.use_gae, lam=G.lam,
                                          horizon=max_episode_length, clip_range=clip_range)
                    # todo: optimize `maml.meta_runner` if k >= G.n_grad_steps.
                    loss, *_, __ = maml.runner.optim.run_optimize(feed_dict=runner_feed_dict)
                    runner_feed_dict.clear()
                    for key, value in zip(maml.runner.model.reports.keys(), [loss, *_]):
                        batch_data[prefix + f"grad_{k}_step_{key}"].append(value)
                        logger.log_key_value(prefix + f"task_{task_ind}_grad_{k}_{key}", value, silent=True)

                    if loss > G.term_loss_threshold:  # todo: make this batch-based instead of on single episode
                        logger.log_line(prefix + "episode loss blew up:", loss, "terminating training.", flush=True)
                        raise RuntimeError('loss is TOO HIGH. Terminating the experiment.')

                    # done: has bug when using fixed learning rate. Needs the learning rate as input.
                    feed_dict.update(  # do NOT pass in the learning rate because the graph already includes those.
                        path_to_feed_dict(inputs=graph_branch.workers[k].inputs, paths=_p,
                                          lr=None if G.meta_sgd else alpha_lr,  # but do with fixed alpha
                                          horizon=max_episode_length, baseline=G.baseline, gamma=G.gamma,
                                          use_gae=G.use_gae, lam=G.lam, clip_range=clip_range))
                elif k == G.n_grad_steps:
                    yield_keys = dict(
                        movie=epoch_ind >= G.start_movie_after_epoch and epoch_ind % G.record_movie_interval == 0,
                        eval=is_bc_test
                    )
                    if np.fromiter(yield_keys.values(), bool).any():
                        yield yield_keys, epoch_ind, tasks.task_spec
                    if is_bc_test:
                        if load_ops:  # we need to reset the weights. Otherwise the world would be on fire.
                            tf.get_default_session().run(load_ops)
                        continue  # do NOT meta learn from test samples.

                    # we don't treat the meta_input the same way even though we could. This is more clear to read.
                    # note: feed in the learning rate only later.
                    feed_dict.update(  # do NOT need learning rate
                        path_to_feed_dict(inputs=graph_branch.meta.inputs, paths=_p, horizon=max_episode_length,
                                          baseline=G.baseline, gamma=G.gamma, use_gae=G.use_gae, lam=G.lam,
                                          clip_range=clip_range))

                    if G.n_graphs == 1:
                        # load from checkpoint before computing the meta gradient,
                        # then run the gradient sum operation.
                        if load_ops:
                            tf.get_default_session().run(load_ops)
                        # note: meta reporting should be run here. Not supported for simplicity. (need to reduce across
                        # note: tasks, and can not be done outside individual task graphs.)
                        if G.meta_sgd is None:  # note: copied from train_supervised_maml, not tested
                            feed_dict[maml.alpha] = alpha_lr
                        tf.get_default_session().run(gradient_sum_op, feed_dict)
                        feed_dict.clear()

        if load_ops:
            tf.get_default_session().run(load_ops)

        if is_bc_test:
            continue  # do NOT meta learn from test samples.

        # note: copied from train_supervised_maml, not tested
        if G.meta_sgd is None:
            feed_dict[maml.alpha] = alpha_lr

        if G.n_graphs == 1:
            assert G.meta_n_grad_steps == 1, "ERROR: Can only run 1 meta gradient step with a single graph."
            # note: remove meta reporting b/c meta report should be in each task in this case.
            tf.get_default_session().run(maml.meta_update_ops[0], {maml.beta: beta_lr})
        else:
            assert feed_dict, "ERROR: It is likely that you jumped here from L:178."
            feed_dict[maml.beta] = beta_lr
            for i in range(G.meta_n_grad_steps):
                update_op = maml.meta_update_ops[0 if G.reuse_meta_optimizer else i]
                *reports, _ = tf.get_default_session().run(maml.meta_reporting + [update_op], feed_dict)
                if i not in (0, G.meta_n_grad_steps - 1):
                    continue
                for key, v in zip(maml.meta_reporting_keys, reports):
                    logger.log_key_value(prefix + f"grad_{G.n_grad_steps + i}_step_{key}", v, silent=True)
            feed_dict.clear()

        tf.get_default_session().run(maml.cache.save)

        # Now compute the meta gradients.
        # note: runner shares variables with the MAML graph. Reload from state_dict.
        # note: if max_grad_step is the same as n_grad_steps then no need here.
        dt = logger.split()
        logger.log_line('Timer Starts...' if dt is None else f'{dt:0.2f} sec/epoch')
        logger.log(dt_epoch=dt or np.nan, epoch=epoch_ind)

        for key, arr in batch_data.items():
            reduced = np.array(arr).mean()
            logger.log_key_value(key, reduced)

    logger.flush()
def train_supervised_maml(*, k_tasks=1, maml: E_MAML):
    # env used for evaluation purposes only.
    if G.meta_sgd:
        assert maml.alpha is not None, "Coding mistake if meta_sgd is truthy but maml.alpha is None."
    assert G.n_tasks >= k_tasks, f"Is this intended? You probably want to have " \
                                 f"meta-batch({G.n_tasks}) >= k_tasks({k_tasks})."

    sess = tf.get_default_session()

    epoch_ind, pref = -1, ""
    while epoch_ind < G.n_epochs:  # for epoch_ind in range(G.n_epochs + 1):
        logger.flush()
        logger.split()

        is_bc_test = (pref != "test/" and G.eval_interval and epoch_ind % G.eval_interval == 0)
        pref = "test/" if is_bc_test else ""
        epoch_ind += 0 if is_bc_test else 1

        if G.meta_sgd:
            alpha_lr = sess.run(maml.alpha)  # only used in the runner.
            logger.log(metrics={f"alpha_{i}/{stem(t.name, 2)}": a
                                for i, a_ in enumerate(alpha_lr)
                                for t, a in zip(maml.runner.trainables, a_)}, silent=True)
        else:
            alpha_lr = G.alpha.send(epoch_ind) if isinstance(G.alpha, Schedule) else np.array(G.alpha)
            logger.log(alpha=metrify(alpha_lr), epoch=epoch_ind, silent=True)

        beta_lr = G.beta.send(epoch_ind) if isinstance(G.beta, Schedule) else np.array(G.beta)
        logger.log(beta=metrify(beta_lr), epoch=epoch_ind, silent=True)

        if G.checkpoint_interval and epoch_ind % G.checkpoint_interval == 0:
            yield "pre-update-checkpoint", epoch_ind

        # Compute updates for each task in the batch
        # 0. save value of variables
        # 1. sample
        # 2. gradient descent
        # 3. repeat step 1., 2. until all gradient steps are exhausted.
        batch_data = defaultdict(list)
        maml.save_weight_cache()
        load_ops = [] if DEBUG.no_weight_reset else [maml.cache.load]

        feed_dict = {}
        for task_ind in range(k_tasks if is_bc_test else G.n_tasks):
            graph_branch = maml.graphs[0] if G.n_graphs == 1 else maml.graphs[task_ind]
            if G.n_graphs == 1:
                gradient_sum_op = maml.gradient_sum.set_op if task_ind == 0 else maml.gradient_sum.add_op

            """
            In BC mode, we don't have an environment. The sampling is handled here then fed to the sampler.
            > task_spec = dict(index=0)
            Here we make the testing more efficient.
            """
            if not DEBUG.no_task_resample:
                if not is_bc_test:
                    task_spec = dict(index=np.random.randint(0, k_tasks))
                elif task_ind < k_tasks:
                    task_spec = dict(index=task_ind % k_tasks)
                else:
                    raise RuntimeError('should never hit here.')

            for k in range(G.n_grad_steps + 1):  # 0 - 10 <== last one being the maml policy.
                # for imitation inner loss, we still sample trajectory for evaluation purposes, but
                # replace it with the demonstration data for learning
                if k < G.n_grad_steps:
                    p = p if G.single_sampling and k > 0 else \
                        bc.sample_demonstration_data(task_spec, key=("eval" if is_bc_test else None))
                elif k == G.n_grad_steps:
                    # note: use meta bc samples.
                    p = bc.sample_demonstration_data(task_spec, key="meta")
                else:
                    raise Exception('Implementation error. Should never reach this line.')

                _p = {k: v for k, v in p.items() if k != "ep_info"}

                if k < G.n_grad_steps:
                    # note: under meta-SGD mode, the runner needs the k^th learning rate.
                    _lr = alpha_lr[k] if G.meta_sgd else alpha_lr
                    runner_feed_dict = \
                        path_to_feed_dict(inputs=maml.runner.inputs, paths=_p, lr=_lr)
                    # todo: optimize `maml.meta_runner` if k >= G.n_grad_steps.
                    loss, *_, __ = maml.runner.optim.run_optimize(feed_dict=runner_feed_dict)
                    runner_feed_dict.clear()
                    for key, value in zip(maml.runner.model.reports.keys(), [loss, *_]):
                        batch_data[pref + f"grad_{k}_step_{key}"].append(value)
                        logger.log_key_value(pref + f"task_{task_ind}_grad_{k}_{key}", value, silent=True)

                    if loss > G.term_loss_threshold:  # todo: make this batch-based instead of on single episode
                        err = f"{pref}episode loss blew up: {loss}, terminating training."
                        logger.log_line(colored(err, "red"), flush=True)
                        raise RuntimeError('loss is TOO HIGH. Terminating the experiment.')

                    # fixit: has bug when using fixed learning rate. Still needs to get learning rate from placeholder
                    feed_dict.update(path_to_feed_dict(inputs=graph_branch.workers[k].inputs, paths=_p))
                elif k == G.n_grad_steps:
                    yield_keys = dict(
                        movie=G.record_movie_interval and epoch_ind >= G.start_movie_after_epoch
                              and epoch_ind % G.record_movie_interval == 0,
                        eval=is_bc_test
                    )
                    if np.fromiter(yield_keys.values(), bool).any():
                        yield yield_keys, epoch_ind, task_spec
                    if is_bc_test:
                        if load_ops:
                            tf.get_default_session().run(load_ops)
                        continue  # do NOT meta learn from test samples.

                    # we don't treat the meta_input the same way even though we could. This is more clear to read.
                    # note: feed in the learning rate only later.
                    feed_dict.update(path_to_feed_dict(inputs=graph_branch.meta.inputs, paths=_p))

                    if G.n_graphs == 1:
                        # load from checkpoint before computing the meta gradient,
                        # then run the gradient sum operation.
                        if load_ops:
                            tf.get_default_session().run(load_ops)
                        # note: meta reporting should be run here. Not supported for simplicity. (need to reduce across
                        # note: tasks, and can not be done outside individual task graphs.)
                        if G.meta_sgd is None:
                            feed_dict[maml.alpha] = alpha_lr
                        tf.get_default_session().run(gradient_sum_op, feed_dict)
                        feed_dict.clear()

        if load_ops:
            tf.get_default_session().run(load_ops)

        if is_bc_test:
            continue  # do NOT meta learn from test samples.

        if G.meta_sgd is None:
            feed_dict[maml.alpha] = alpha_lr

        if G.n_graphs == 1:
            assert G.meta_n_grad_steps == 1, "ERROR: Can only run 1 meta gradient step with a single graph."
            # note: remove meta reporting b/c meta report should be in each task in this case.
            tf.get_default_session().run(maml.meta_update_ops[0], {maml.beta: beta_lr})
        else:
            assert feed_dict, "ERROR: It is likely that you jumped here from L:178."
            feed_dict[maml.beta] = beta_lr
            for i in range(G.meta_n_grad_steps):
                update_op = maml.meta_update_ops[0 if G.reuse_meta_optimizer else i]
                *reports, _ = tf.get_default_session().run(maml.meta_reporting + [update_op], feed_dict)
                if i not in (0, G.meta_n_grad_steps - 1):
                    continue
                for key, v in zip(maml.meta_reporting_keys, reports):
                    logger.log_key_value(pref + f"grad_{G.n_grad_steps + i}_step_{key}", v, silent=True)
            feed_dict.clear()

        tf.get_default_session().run(maml.cache.save)

        # Now compute the meta gradients.
        # note: runner shares variables with the MAML graph. Reload from state_dict.
        # note: if max_grad_step is the same as n_grad_steps then no need here.
        dt = logger.split()
        logger.log_line('Timer Starts...' if dt is None else f'{dt:0.2f} sec/epoch')
        logger.log(dt_epoch=dt or np.nan, epoch=epoch_ind)

        for key, arr in batch_data.items():
            reduced = np.array(arr).mean()
            logger.log_key_value(key, reduced)
def _run_single_task(self, i, task):
    start_time = time.time()
    try:
        task_hash = _hash_task_dict(task)  # generate SHA256 hash of task dict as identifier

        # skip task if it has already been completed
        if task_hash in self.gof_single_res_collection.keys():
            logger.log("Task {:<1} {:<63} {:<10} {:<1} {:<1} {:<1}".format(
                i + 1, "has already been completed:", "Estimator:", task['estimator_name'],
                " Simulator: ", task["simulator_name"]))
            return None

        # run task when it has not been completed
        else:
            logger.log("Task {:<1} {:<63} {:<10} {:<1} {:<1} {:<1}".format(
                i + 1, "running:", "Estimator:", task['estimator_name'],
                " Simulator: ", task["simulator_name"]))

            tf.reset_default_graph()

            ''' build simulator and estimator model given the specified configurations '''
            simulator = globals()[task['simulator_name']](**task['simulator_config'])

            t = time.time()
            estimator = globals()[task['estimator_name']](
                task['task_name'], simulator.ndim_x, simulator.ndim_y, **task['estimator_config'])
            time_to_initialize = time.time() - t

            # if desired, hide gpu devices
            if not self.use_gpu:
                os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())

                ''' train the model '''
                gof = GoodnessOfFit(estimator=estimator, probabilistic_model=simulator,
                                    X=task['X'], Y=task['Y'], n_observations=task['n_obs'],
                                    n_mc_samples=task['n_mc_samples'], x_cond=task['x_cond'],
                                    task_name=task['task_name'], tail_measures=self.tail_measures)

                t = time.time()
                gof.fit_estimator(print_fit_result=True)
                time_to_fit = time.time() - t

                if self.dump_models:
                    logger.dump_pkl(data=gof.estimator,
                                    path="model_dumps/{}.pkl".format(task['task_name']))
                    logger.dump_pkl(data=gof.probabilistic_model,
                                    path="model_dumps/{}.pkl".format(task['task_name'] + "_simulator"))

                ''' perform tests with the fitted model '''
                t = time.time()
                gof_results = gof.compute_results()
                time_to_evaluate = time.time() - t

                gof_results.task_name = task['task_name']
                gof_results.hash = task_hash

            logger.log_pkl(data=(task_hash, gof_results), path=RESULTS_FILE)
            logger.flush(file_name=RESULTS_FILE)
            del gof_results

            task_duration = time.time() - start_time
            logger.log(
                "Finished task {:<1} in {:<1.4f} {:<43} {:<10} {:<1} {:<1} {:<2} | {:<1} {:<1.2f} {:<1} {:<1.2f} {:<1} {:<1.2f}"
                .format(i + 1, task_duration, "sec:", "Estimator:", task['estimator_name'],
                        " Simulator: ", task["simulator_name"], "t_init:", time_to_initialize,
                        "t_fit:", time_to_fit, "t_eval:", time_to_evaluate))

    except Exception as e:
        logger.log("error in task: ", str(i + 1))
        logger.log(str(e))
        traceback.print_exc()
def test_configuration(log_dir):
    logger.configure(log_dir, prefix='main_test_script', color='green')
    logger.log("This is a unittest")
    logger.log("Some stats", reward=0.05, kl=0.001)
    logger.flush()
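# A minimal sketch for invoking the unit test above; the log directory path below is an
# assumption for illustration and not part of the original test suite.
if __name__ == '__main__':
    test_configuration(log_dir='/tmp/ml-logger-test')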
### First configure the logger to log to a directory (or a server)
logger.configure('/tmp/ml-logger-debug')
# outputs ~>
# logging data to /tmp/ml-logger-debug

# We can log individual keys
for i in range(1):
    logger.log(metrics={'some_val/smooth': 10, 'status': f"step ({i})"}, reward=20, timestep=i)
    ### flush the data, otherwise the value would be overwritten with new values in the next iteration.
    logger.flush()
# outputs ~>
# ╒════════════════════╤════════════════════════════╕
# │ reward             │ 20                         │
# ├────────────────────┼────────────────────────────┤
# │ timestep           │ 0                          │
# ├────────────────────┼────────────────────────────┤
# │ some val/smooth    │ 10                         │
# ├────────────────────┼────────────────────────────┤
# │ status             │ step (0)                   │
# ├────────────────────┼────────────────────────────┤
# │ timestamp          │'2018-11-04T11:37:03.324824'│
# ╘════════════════════╧════════════════════════════╛

for i in range(100):
    logger.store_metrics(metrics={'some_val/smooth': 10}, some=20, timestep=i)
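# The loop above only caches the metrics in the logger's summary buffer; nothing is written
# out until the cached statistics are reduced and dumped. A minimal sketch of that step,
# assuming ml_logger's `log_metrics_summary` API (verify the name against your installed version):
logger.log_metrics_summary()
logger.flush()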
def maml(model=None, test_fn=None):
    from playground.maml.maml_torch.tasks import Sine

    model = model or (StandardMLP(1, 1) if G.debug else FunctionalMLP(1, 1))
    meta_optimizer = t.optim.Adam(model.parameters(), lr=G.beta)
    mse = t.nn.MSELoss()

    M.tic('start')
    M.tic('epoch')
    for ep_ind in range(G.n_epochs):
        dt = M.split('epoch', silent=True)
        dt_ = M.toc('start', silent=True)
        print(f"epoch {ep_ind} @ {dt:.4f}sec/ep, {dt_:.1f} sec from start")

        original_ps = OrderedDict(model.named_parameters())

        tasks = [Sine() for _ in range(G.task_batch_n)]
        for task_ind, task in enumerate(tasks):
            if task_ind != 0:
                model.params.update(original_ps)

            if G.test_mode:
                _gradient = original_ps['bias_var'].grad
                if task_ind == 0:
                    assert _gradient is None or _gradient.sum().item() == 0, \
                        f"{_gradient} is not zero or None, epoch {ep_ind}."
                else:
                    assert _gradient.sum().item() != 0, f"{_gradient} should be non-zero"
                assert (original_ps['bias_var'] == model.params['bias_var']).all().item() == 1, \
                    'the two parameters should be the same'

            xs, labels = t.DoubleTensor(task.samples(G.k_shot))
            _silent = task_ind != 0

            for grad_ind in range(G.n_gradient_steps):
                if hasattr(model, "is_autoregressive") and model.is_autoregressive:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0, labels.view(G.k_shot, 1, 1))
                    ys = ys.squeeze(-1)  # ys:Size[5, batch_n: 1, 1]
                elif hasattr(model, "is_recurrent") and model.is_recurrent:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
                    ys = ys.squeeze(-1)  # ys:Size[5, batch_n: 1, 1]
                else:
                    ys = model(xs.unsqueeze(-1))
                    ht = None

                loss = mse(ys, labels.unsqueeze(-1))
                logger.log_keyvalue(ep_ind, f"{task_ind}-grad-{grad_ind}-loss", loss.item(), silent=_silent)
                if callable(test_fn) and \
                        ep_ind % G.test_interval == 0 and grad_ind in G.test_grad_steps:
                    test_fn(model, task=task, task_id=task_ind, epoch=ep_ind,
                            grad_step=grad_ind, silent=_silent, h0=ht)

                dps = t.autograd.grad(loss, model.parameters(), create_graph=True, retain_graph=True)
                # 1. update parameters, use updated theta'.
                # 2. run forward, get direct gradient to update the network
                for (name, p), dp in zip(model.named_parameters(), dps):
                    model.params[name] = p - G.alpha * dp

            grad_ind = G.n_gradient_steps

            # meta gradient
            if hasattr(model, "is_autoregressive") and model.is_autoregressive:
                h0 = model.h0_init()
                ys, ht = model(xs.view(G.k_shot, 1, 1), h0, labels.view(G.k_shot, 1, 1))
                ys = ys.squeeze(-1)  # ys:Size[5, batch_n: 1, 1]
            elif hasattr(model, "is_recurrent") and model.is_recurrent:
                h0 = model.h0_init()
                ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
                ys = ys.squeeze(-1)  # ys:Size[5, batch_n: 1, 1]
            else:
                ys = model(xs.unsqueeze(-1))
                ht = None

            loss = mse(ys, labels.unsqueeze(-1))
            logger.log_keyvalue(ep_ind, f"{task_ind}-grad-{grad_ind}-loss", loss.item(), silent=_silent)
            if callable(test_fn) and \
                    ep_ind % G.test_interval == 0 and grad_ind in G.test_grad_steps:
                test_fn(model, task=task, task_id=task_ind, epoch=ep_ind,
                        grad_step=grad_ind, silent=_silent, h0=ht)

            meta_dps = t.autograd.grad(loss, original_ps.values())
            with t.no_grad():
                for (name, p), meta_dp in zip(original_ps.items(), meta_dps):
                    p.grad = (0 if p.grad is None else p.grad) + meta_dp

        # normalize the gradient.
        with t.no_grad():
            for (name, p) in original_ps.items():
                p.grad /= G.task_batch_n

        model.params.update(original_ps)
        meta_optimizer.step()
        meta_optimizer.zero_grad()

        if G.save_interval and ep_ind % G.save_interval == 0:
            logger.log_module(ep_ind, **{type(model).__name__: model})

        logger.flush()
def reptile(model=None, test_fn=None):
    from playground.maml.maml_torch.tasks import Sine

    model = model or FunctionalMLP(1, 1)
    meta_optimizer = t.optim.Adam(model.parameters(), lr=G.beta)
    mse = t.nn.MSELoss()

    M.tic('epoch')
    for ep_ind in range(G.n_epochs):
        M.split('epoch')
        original_ps = OrderedDict(model.named_parameters())

        tasks = [Sine() for _ in range(G.task_batch_n)]
        for task_ind, task in enumerate(tasks):
            if task_ind != 0:
                model.params.update(original_ps)

            xs, labels = t.DoubleTensor(task.samples(G.k_shot))
            _silent = task_ind != 0

            for grad_ind in range(G.n_gradient_steps):
                if hasattr(model, "is_autoregressive") and model.is_autoregressive:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0, labels.view(G.k_shot, 1, 1))
                    ys = ys.squeeze(-1)  # ys:Size(5, batch_n:1, 1).
                elif hasattr(model, "is_recurrent") and model.is_recurrent:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
                    ys = ys.squeeze(-1)  # ys:Size(5, batch_n:1, 1).
                else:
                    ys = model(xs.unsqueeze(-1))
                    ht = None

                loss = mse(ys, labels.unsqueeze(-1))
                with t.no_grad():
                    logger.log_keyvalue(ep_ind, f"{task_ind}-grad-{grad_ind}-loss", loss.item(), silent=_silent)
                    if callable(test_fn) and ep_ind % G.test_interval == 0 and grad_ind in G.test_grad_steps:
                        test_fn(model, task, task_id=task_ind, epoch=ep_ind,
                                grad_step=grad_ind, h0=ht, silent=_silent)

                dps = t.autograd.grad(loss, model.parameters())
                with t.no_grad():
                    for (name, p), dp in zip(model.named_parameters(), dps):
                        model.params[name] = p - G.alpha * dp
                        model.params[name].requires_grad = True

            grad_ind = G.n_gradient_steps
            with t.no_grad():  # domain adaptation
                if hasattr(model, "is_autoregressive") and model.is_autoregressive:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0, labels.view(G.k_shot, 1, 1))
                    ys = ys.squeeze(-1)  # ys:Size(5, batch_n:1, 1).
                elif hasattr(model, "is_recurrent") and model.is_recurrent:
                    h0 = model.h0_init()
                    ys, ht = model(xs.view(G.k_shot, 1, 1), h0)
                    ys = ys.squeeze(-1)  # ys:Size(5, batch_n:1, 1).
                else:
                    ys = model(xs.unsqueeze(-1))
                    ht = None

                loss = mse(ys, labels.unsqueeze(-1))
                logger.log_keyvalue(ep_ind, f"{task_ind}-grad-{grad_ind}-loss", loss.item(), silent=_silent)
                if callable(test_fn) and \
                        ep_ind % G.test_interval == 0 and grad_ind in G.test_grad_steps:
                    test_fn(model, task, task_id=task_ind, epoch=ep_ind,
                            grad_step=grad_ind, h0=ht, silent=_silent)

            # Compute REPTILE 1st-order gradient
            with t.no_grad():
                for name, p in original_ps.items():
                    # let's do the division at the end.
                    p.grad = (0 if p.grad is None else p.grad) + (p - model.params[name])  # / G.task_batch_n

        with t.no_grad():
            for name, p in original_ps.items():
                # let's do the division at the end.
                p.grad /= G.task_batch_n

        model.params.update(original_ps)
        meta_optimizer.step()
        meta_optimizer.zero_grad()

        if G.save_interval and ep_ind % G.save_interval == 0:
            logger.log_module(ep_ind, **{type(model).__name__: model})

        logger.flush()