def __init__(self, dataset, num_classes, num_inst, meta_batch_size, meta_step_size,
             inner_batch_size, inner_step_size, num_updates, num_inner_updates, loss_fn):
    super(self.__class__, self).__init__()
    self.dataset = dataset
    self.num_classes = num_classes
    self.num_inst = num_inst
    self.meta_batch_size = meta_batch_size
    self.meta_step_size = meta_step_size
    self.inner_batch_size = inner_batch_size
    self.inner_step_size = inner_step_size
    self.num_updates = num_updates
    self.num_inner_updates = num_inner_updates
    self.loss_fn = loss_fn

    # Make the nets
    # TODO: don't actually need two nets
    # Note this code base has two nets - one for the general loop, one for the inner loop
    num_input_channels = 1 if self.dataset == 'mnist' else 3
    self.net = OmniglotNet(num_classes, self.loss_fn, num_input_channels)
    self.net.cuda()
    self.fast_net = InnerLoop(num_classes, self.loss_fn, self.num_inner_updates,
                              self.inner_step_size, self.inner_batch_size,
                              self.meta_batch_size, num_input_channels)
    self.fast_net.cuda()
    # This optimizer is only over the net, not the fast net.
    self.opt = Adam(self.net.parameters(), lr=meta_step_size)
def __init__(self, dataset, num_classes, num_inst, meta_batch_size, meta_step_size,
             inner_batch_size, inner_step_size, num_updates, num_inner_updates, loss_fn):
    super(self.__class__, self).__init__()
    self.dataset = dataset
    self.num_classes = num_classes
    self.num_inst = num_inst
    self.meta_batch_size = meta_batch_size
    self.meta_step_size = meta_step_size
    self.inner_batch_size = inner_batch_size
    self.inner_step_size = inner_step_size
    self.num_updates = num_updates
    self.num_inner_updates = num_inner_updates
    self.loss_fn = loss_fn

    # Make the nets
    # TODO: don't actually need two nets
    num_input_channels = 1 if self.dataset == 'mnist' else 3
    self.net = OmniglotNet(num_classes, self.loss_fn, num_input_channels)
    self.net.cuda()
    self.fast_net = InnerLoop(num_classes, self.loss_fn, self.num_inner_updates,
                              self.inner_step_size, self.inner_batch_size,
                              self.meta_batch_size, num_input_channels)
    self.fast_net.cuda()
    self.opt = Adam(self.net.parameters())
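# A minimal construction sketch for the initializer above, assuming it belongs to a
# MetaLearner class like the full listing later in this section and that OmniglotNet,
# InnerLoop, and a CUDA device are available. Hyperparameter values are illustrative,
# not taken from the source.
from torch.nn import CrossEntropyLoss

learner = MetaLearner(dataset='omniglot',
                      num_classes=5,        # N-way
                      num_inst=1,           # instances per class
                      meta_batch_size=32,
                      meta_step_size=1e-3,
                      inner_batch_size=5,
                      inner_step_size=0.1,
                      num_updates=15000,
                      num_inner_updates=5,
                      loss_fn=CrossEntropyLoss())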
def Evaluate(self, individual):
    """
    Adapted inner loop of Model Agnostic Meta-Learning to be Baldwinian
    a la Fernando et al. 2018.
    Source: https://github.com/katerakelly/pytorch-maml/blob/master/src/maml.py
    """
    inner_net = InnerLoop(self.num_classes, self.loss_fn, self.num_updates,
                          self.inner_step_size, self.inner_batch_size,
                          self.meta_batch_size, self.num_input_channels)
    #tasks = {}
    #tasks['t1'] = self.get_task("../data/{}".format(self.dataset), self.num_classes, self.num_inst)
    for t in range(self.meta_batch_size):
        # Outer-loop is completed by NES for G generations
        inner_net.copy_weights(individual['network'])
        metrics = inner_net.forward(self.tasks[str(t)])
        # Want validation accuracy for fitness (tr_loss, tr_acc, val_loss, val_acc):
        individual['fitness'] += metrics[-1]
        keys = ("_tr_loss", "_tr_acc", "_val_loss", "_val_acc")
        idx = 0
        for k in keys:
            key = "t" + str(t) + k
            individual[key] = metrics[idx]
            idx += 1
    self.Record_Performance(individual)
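# Hedged sketch of the `individual` structure that Evaluate() above appears to expect:
# a dict carrying the evolvable network plus an accumulating fitness. The OmniglotNet
# constructor arguments here are illustrative assumptions (5-way, 3 input channels).
from torch.nn import CrossEntropyLoss

individual = {
    'network': OmniglotNet(5, CrossEntropyLoss(), 3),  # weights are evolved by NES
    'fitness': 0.0,                                    # accumulates validation accuracy
}
# After Evaluate(individual), per-task metrics are stored under keys such as
# 't0_tr_loss', 't0_tr_acc', 't0_val_loss', 't0_val_acc', one group per meta-batch task.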
def __init__(self, meta_dataset, fs_dataset, K, meta_lr, inner_lr, layer, hidden,
             tissue_num, meta_batch_size, inner_batch_size, num_updates,
             num_inner_updates, tissue_index_list, patience=3, num_trials=10):
    super(self.__class__, self).__init__()
    self.meta_dataset = meta_dataset
    self.fs_dataset = fs_dataset
    self.meta_batch_size = meta_batch_size
    self.inner_batch_size = inner_batch_size
    self.num_updates = num_updates
    self.num_inner_updates = num_inner_updates
    self.num_trials = num_trials
    self.hidden = hidden
    self.patience = patience
    self.feature_num = self.fs_dataset.feature.shape[1]
    self.K = K
    self.meta_lr = meta_lr
    self.inner_lr = inner_lr
    self.layer = layer
    self.hidden = hidden
    self.tissue_index_list = tissue_index_list
    self.tissue_num = tissue_num

    self.observed_tissue_model = mlp(self.feature_num, layer, hidden)
    self.observed_opt = torch.optim.Adam(self.observed_tissue_model.parameters(),
                                         lr=self.meta_lr, betas=(0.9, 0.99), eps=1e-05)
    self.inner_net = InnerLoop(self.num_inner_updates, self.inner_lr,
                               self.feature_num, layer, hidden)
    #torch.cuda.manual_seed(args.seed)
    self.observed_tissue_model.cuda()
    self.inner_net.cuda()
def __init__(self, log, tb_writer, args):
    super(self.__class__, self).__init__()
    self.log = log
    self.tb_writer = tb_writer
    self.args = args

    self.loss_fn = MSELoss()
    self.net = OmniglotNet(self.loss_fn, args).to(device)
    self.fast_net = InnerLoop(self.loss_fn, args).to(device)
    self.opt = Adam(self.net.parameters(), lr=args.meta_lr)

    self.sampler = BatchSampler(args)
    self.memory = ReplayBuffer()
def Evaluate(individual):
    """
    Adapted inner loop of Model Agnostic Meta-Learning to be Baldwinian
    a la Fernando et al. 2018.
    Source: https://github.com/katerakelly/pytorch-maml/blob/master/src/maml.py
    """
    #tasks['t1'] = self.get_task("../data/{}".format(self.dataset), self.num_classes, self.num_inst)
    inner_net = InnerLoop(5, CrossEntropyLoss(), 3, 0.01, 100, 10, 3)
    for t in range(10):
        task = Get_Task("../data/{}".format('omniglot'), 5, 10)
        # Outer-loop is completed by NES for G generations
        inner_net.copy_weights(individual['network'])
        metrics = inner_net.forward(task)
        # Want validation accuracy for fitness (tr_loss, tr_acc, val_loss, val_acc):
        print(metrics)
def __init__(self, num_way, num_shot, classifier_param, loss_fn, inner_step_size,
             num_inner_updates, has_bias=False, alpha=[5, 1, 1], num_layer=1):
    super(self.__class__, self).__init__()
    self.num_way = num_way
    self.num_shot = num_shot
    self.num_inner_updates = num_inner_updates
    self.inner_step_size = inner_step_size
    self.alpha = alpha

    # Make the nets
    # TODO: don't actually need two nets
    net_param_opt = {
        'userelu': False,
        'in_planes': 3,
        'out_planes': [64, 64, 64, 64],
        'num_stages': 4,
        'num_dim': 64 * 5 * 5,
        'num_Kall': 64
    }
    num_dim = net_param_opt['num_dim']
    self.model = {}
    self.model['embedding_net'] = ConvNet(net_param_opt)
    self.model['embedding_net'].cuda()
    self.model['classifier_net'] = classifier_param
    self.model['test_net'] = InnerLoop(num_dim, 5, loss_fn, self.num_inner_updates,
                                       self.inner_step_size, 1, classifier_param,
                                       has_bias, self.alpha, num_layer)
    self.model['test_net'].cuda()
class MetaLearner(object):
    def __init__(self, dataset, num_classes, num_inst, meta_batch_size, meta_step_size,
                 inner_batch_size, inner_step_size, num_updates, num_inner_updates, loss_fn):
        super(self.__class__, self).__init__()
        self.dataset = dataset
        self.num_classes = num_classes
        self.num_inst = num_inst
        self.meta_batch_size = meta_batch_size
        self.meta_step_size = meta_step_size
        self.inner_batch_size = inner_batch_size
        self.inner_step_size = inner_step_size
        self.num_updates = num_updates
        self.num_inner_updates = num_inner_updates
        self.loss_fn = loss_fn

        # Make the nets
        # TODO: don't actually need two nets
        num_input_channels = 1 if self.dataset == 'mnist' else 3
        self.net = OmniglotNet(num_classes, self.loss_fn, num_input_channels)
        self.net.cuda()
        self.fast_net = InnerLoop(num_classes, self.loss_fn, self.num_inner_updates,
                                  self.inner_step_size, self.inner_batch_size,
                                  self.meta_batch_size, num_input_channels)
        self.fast_net.cuda()
        self.opt = Adam(self.net.parameters())

    def get_task(self, root, n_cl, n_inst, split='train'):
        if 'mnist' in root:
            return MNISTTask(root, n_cl, n_inst, split)
        elif 'omniglot' in root:
            return OmniglotTask(root, n_cl, n_inst, split)
        else:
            print('Unknown dataset')
            raise Exception

    def meta_update(self, task, ls):
        print('\n Meta update \n')
        loader = get_data_loader(task, self.inner_batch_size, split='val')
        in_, target = next(iter(loader))
        # We use a dummy forward / backward pass to get the correct grads into self.net
        loss, out = forward_pass(self.net, in_, target)
        # Unpack the list of grad dicts
        gradients = {k: sum(d[k] for d in ls) for k in ls[0].keys()}
        # Register a hook on each parameter in the net that replaces the current dummy grad
        # with our grads accumulated across the meta-batch
        hooks = []
        for (k, v) in self.net.named_parameters():
            def get_closure():
                key = k
                def replace_grad(grad):
                    return gradients[key]
                return replace_grad
            hooks.append(v.register_hook(get_closure()))
        # Compute grads for current step, replace with summed gradients as defined by hook
        self.opt.zero_grad()
        loss.backward()
        # Update the net parameters with the accumulated gradient according to optimizer
        self.opt.step()
        # Remove the hooks before next training phase
        for h in hooks:
            h.remove()

    def test(self):
        num_in_channels = 1 if self.dataset == 'mnist' else 3
        test_net = OmniglotNet(self.num_classes, self.loss_fn, num_in_channels)
        mtr_loss, mtr_acc, mval_loss, mval_acc = 0.0, 0.0, 0.0, 0.0
        # Select ten tasks randomly from the test set to evaluate on
        for _ in range(10):
            # Make a test net with same parameters as our current net
            test_net.copy_weights(self.net)
            test_net.cuda()
            test_opt = SGD(test_net.parameters(), lr=self.inner_step_size)
            task = self.get_task('../data/{}'.format(self.dataset), self.num_classes,
                                 self.num_inst, split='test')
            # Train on the train examples, using the same number of updates as in training
            train_loader = get_data_loader(task, self.inner_batch_size, split='train')
            for i in range(self.num_inner_updates):
                in_, target = next(iter(train_loader))
                loss, _ = forward_pass(test_net, in_, target)
                test_opt.zero_grad()
                loss.backward()
                test_opt.step()
            # Evaluate the trained model on train and val examples
            tloss, tacc = evaluate(test_net, train_loader)
            val_loader = get_data_loader(task, self.inner_batch_size, split='val')
            vloss, vacc = evaluate(test_net, val_loader)
            mtr_loss += tloss
            mtr_acc += tacc
            mval_loss += vloss
            mval_acc += vacc
        mtr_loss = mtr_loss / 10
        mtr_acc = mtr_acc / 10
        mval_loss = mval_loss / 10
        mval_acc = mval_acc / 10
        print('-------------------------')
        print('Meta train:', mtr_loss, mtr_acc)
        print('Meta val:', mval_loss, mval_acc)
        print('-------------------------')
        del test_net
        return mtr_loss, mtr_acc, mval_loss, mval_acc

    def _train(self, exp):
        ''' debugging function: learn two tasks '''
        task1 = self.get_task('../data/{}'.format(self.dataset), self.num_classes, self.num_inst)
        task2 = self.get_task('../data/{}'.format(self.dataset), self.num_classes, self.num_inst)
        for it in range(self.num_updates):
            grads = []
            for task in [task1, task2]:
                # Make sure fast net always starts with base weights
                self.fast_net.copy_weights(self.net)
                _, g = self.fast_net.forward(task)
                grads.append(g)
            self.meta_update(task, grads)

    def train(self, exp):
        tr_loss, tr_acc, val_loss, val_acc = [], [], [], []
        mtr_loss, mtr_acc, mval_loss, mval_acc = [], [], [], []
        for it in range(self.num_updates):
            # Evaluate on test tasks
            mt_loss, mt_acc, mv_loss, mv_acc = self.test()
            mtr_loss.append(mt_loss)
            mtr_acc.append(mt_acc)
            mval_loss.append(mv_loss)
            mval_acc.append(mv_acc)
            # Collect a meta batch update
            grads = []
            tloss, tacc, vloss, vacc = 0.0, 0.0, 0.0, 0.0
            for i in range(self.meta_batch_size):
                task = self.get_task('../data/{}'.format(self.dataset), self.num_classes, self.num_inst)
                self.fast_net.copy_weights(self.net)
                metrics, g = self.fast_net.forward(task)
                (trl, tra, vall, vala) = metrics
                grads.append(g)
                tloss += trl
                tacc += tra
                vloss += vall
                vacc += vala
            # Perform the meta update
            print('Meta update', it)
            self.meta_update(task, grads)
            # Save a model snapshot every now and then
            if it % 500 == 0:
                torch.save(self.net.state_dict(), '../output/{}/train_iter_{}.pth'.format(exp, it))
            # Save stuff
            tr_loss.append(tloss / self.meta_batch_size)
            tr_acc.append(tacc / self.meta_batch_size)
            val_loss.append(vloss / self.meta_batch_size)
            val_acc.append(vacc / self.meta_batch_size)
            np.save('../output/{}/tr_loss.npy'.format(exp), np.array(tr_loss))
            np.save('../output/{}/tr_acc.npy'.format(exp), np.array(tr_acc))
            np.save('../output/{}/val_loss.npy'.format(exp), np.array(val_loss))
            np.save('../output/{}/val_acc.npy'.format(exp), np.array(val_acc))
            np.save('../output/{}/meta_tr_loss.npy'.format(exp), np.array(mtr_loss))
            np.save('../output/{}/meta_tr_acc.npy'.format(exp), np.array(mtr_acc))
            np.save('../output/{}/meta_val_loss.npy'.format(exp), np.array(mval_loss))
            np.save('../output/{}/meta_val_acc.npy'.format(exp), np.array(mval_acc))
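# Standalone sketch of the gradient-replacement trick used in meta_update() above:
# a dummy backward pass populates .grad, while per-parameter hooks swap in the
# gradients summed across the meta-batch before the optimizer step. The model and
# data here are toy placeholders, not part of the original code.
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Pretend these were accumulated across a meta-batch by the inner loop
gradients = {k: torch.ones_like(v) for k, v in model.named_parameters()}

hooks = []
for k, v in model.named_parameters():
    def get_closure(key=k):               # bind the parameter name at definition time
        def replace_grad(grad):
            return gradients[key]         # returned value replaces the incoming grad
        return replace_grad
    hooks.append(v.register_hook(get_closure()))

loss = model(torch.randn(8, 4)).pow(2).mean()   # dummy forward pass
opt.zero_grad()
loss.backward()                                  # hooks substitute the meta-gradients
opt.step()
for h in hooks:
    h.remove()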
class MetaLearner(object):
    def __init__(self, meta_dataset, fs_dataset, K, meta_lr, inner_lr, layer, hidden,
                 tissue_num, meta_batch_size, inner_batch_size, num_updates,
                 num_inner_updates, tissue_index_list, patience=3, num_trials=10):
        super(self.__class__, self).__init__()
        self.meta_dataset = meta_dataset
        self.fs_dataset = fs_dataset
        self.meta_batch_size = meta_batch_size
        self.inner_batch_size = inner_batch_size
        self.num_updates = num_updates
        self.num_inner_updates = num_inner_updates
        self.num_trials = num_trials
        self.hidden = hidden
        self.patience = patience
        self.feature_num = self.fs_dataset.feature.shape[1]
        self.K = K
        self.meta_lr = meta_lr
        self.inner_lr = inner_lr
        self.layer = layer
        self.hidden = hidden
        self.tissue_index_list = tissue_index_list
        self.tissue_num = tissue_num

        self.observed_tissue_model = mlp(self.feature_num, layer, hidden)
        self.observed_opt = torch.optim.Adam(self.observed_tissue_model.parameters(),
                                             lr=self.meta_lr, betas=(0.9, 0.99), eps=1e-05)
        self.inner_net = InnerLoop(self.num_inner_updates, self.inner_lr,
                                   self.feature_num, layer, hidden)
        #torch.cuda.manual_seed(args.seed)
        self.observed_tissue_model.cuda()
        self.inner_net.cuda()

    def zero_shot_test(self, unseen_train_loader, unseen_vali_loader, unseen_test_loader):
        unseen_tissue_model = mlp(self.feature_num, self.layer, self.hidden)
        # First need to copy the original meta learning model
        unseen_tissue_model.copy_weights(self.observed_tissue_model)
        unseen_tissue_model.cuda()
        unseen_tissue_model.eval()
        train_performance = evaluate_cv(unseen_tissue_model, unseen_train_loader)
        vali_performance = evaluate_cv(unseen_tissue_model, unseen_vali_loader)
        test_performance = evaluate_cv(unseen_tissue_model, unseen_test_loader)
        # NOTE: tissue_loss is not defined in this scope in the original code
        return train_performance, vali_performance, test_performance, np.mean(tissue_loss)

    def meta_update(self, test_loader, ls):
        # Works like the 'meta update' function of MAML
        in_, target = next(iter(test_loader))
        # We use a dummy forward / backward pass to get the correct grads into the model
        loss, out = forward_pass(self.observed_tissue_model, in_, target)
        # Unpack the list of grad dicts
        gradients = {k: sum(d[k] for d in ls) for k in ls[0].keys()}
        #for k, val in gradients.items():
        #    gradients[k] = val / args.meta_batch_size
        #    print(k, ':', gradients[k])
        # Register a hook on each parameter in the net that replaces the current dummy grad
        # with our grads accumulated across the meta-batch
        hooks = []
        for (k, v) in self.observed_tissue_model.named_parameters():
            def get_closure():
                key = k
                def replace_grad(grad):
                    return gradients[key]
                return replace_grad
            hooks.append(v.register_hook(get_closure()))
        # Compute grads for current step, replace with summed gradients as defined by hook
        self.observed_opt.zero_grad()
        loss.backward()
        # Update the net parameters with the accumulated gradient according to optimizer
        self.observed_opt.step()
        # Remove the hooks before next training phase
        for h in hooks:
            h.remove()

    def unseen_tissue_learn(self, unseen_train_loader, unseen_test_loader):
        unseen_tissue_model = mlp(self.feature_num, self.layer, self.hidden)
        # First need to copy the original meta learning model
        unseen_tissue_model.copy_weights(self.observed_tissue_model)
        unseen_tissue_model.cuda()
        #unseen_tissue_model.train()
        unseen_tissue_model.eval()
        unseen_opt = torch.optim.SGD(unseen_tissue_model.parameters(), lr=self.inner_lr)
        #unseen_opt = torch.optim.Adam(unseen_tissue_model.parameters(), lr=self.inner_lr, betas=(0.9, 0.99), eps=1e-05)
        # Here test_feature and test_label contain only one tissue's info
        #unseen_train_loader, unseen_test_loader = get_unseen_data_loader(test_feature, test_label, K, args.inner_batch_size)
        for i in range(self.num_inner_updates):
            in_, target = next(iter(unseen_train_loader))
            loss, _ = forward_pass(unseen_tissue_model, in_, target)
            unseen_opt.zero_grad()
            loss.backward()
            unseen_opt.step()
        # Test on the rest of cell lines in this tissue (unseen_test_loader)
        mtrain_loss, mtrain_pear_corr, mtrain_spearman_corr, _, _ = evaluate_new(
            unseen_tissue_model, unseen_train_loader, 1)
        mtest_loss, mtest_pear_corr, mtest_spearman_corr, test_prediction, test_true_label = evaluate_new(
            unseen_tissue_model, unseen_test_loader, 0)
        return (mtrain_loss, mtrain_pear_corr, mtrain_spearman_corr,
                mtest_loss, mtest_pear_corr, mtest_spearman_corr,
                test_prediction, test_true_label)

    def train(self):
        best_train_loss_test_corr, best_train_corr_test_corr = 0, 0
        best_train_corr_test_scorr, best_train_scorr_test_scorr = 0, 0
        best_train_corr_model = ''
        unseen_train_loader, unseen_test_loader = get_unseen_data_loader(
            self.fs_dataset.feature, self.fs_dataset.label, self.K)
        #unseen_train_loader, unseen_test_loader = get_unseen_data_loader(self.fs_dataset.feature, self.fs_dataset.label, self.fs_catdata, self.K)
        # Here the training process starts
        best_fewshot_train_corr, best_fewshot_train_loss = -2, 1000
        best_fewshot_test_corr, best_fewshot_test_loss = -2, 1000
        best_train_loss_epoch, best_train_corr_epoch = 0, 0
        best_fewshot_train_spearman_corr, best_fewshot_test_spearman_corr = -2, -2
        best_train_spearman_corr_epoch = 0

        train_loss, train_corr = np.zeros((self.num_updates,)), np.zeros((self.num_updates,))
        test_loss, test_corr = np.zeros((self.num_updates,)), np.zeros((self.num_updates,))
        train_spearman_corr, test_spearman_corr = np.zeros((self.num_updates,)), np.zeros((self.num_updates,))

        for epoch in range(self.num_updates):
            # Collect a meta batch update
            grads = []
            meta_train_loss = np.zeros((self.meta_batch_size,))
            meta_train_corr = np.zeros((self.meta_batch_size,))
            meta_val_loss = np.zeros((self.meta_batch_size,))
            meta_val_corr = np.zeros((self.meta_batch_size,))
            self.inner_net.copy_weights(self.observed_tissue_model)
            for i in range(self.meta_batch_size):
                observed_train_loader, observed_test_loader = get_observed_data_loader(
                    self.meta_dataset.feature, self.meta_dataset.label,
                    self.tissue_index_list, self.K, self.inner_batch_size, self.tissue_num)
                #self.inner_net.copy_weights(self.observed_tissue_model)
                metrics, g = self.inner_net.forward(observed_train_loader, observed_test_loader)
                grads.append(g)
                meta_train_loss[i], meta_train_corr[i], meta_val_loss[i], meta_val_corr[i] = metrics
            # Perform the meta update
            self.meta_update(observed_test_loader, grads)
            #meta_train_loss_mean, meta_train_corr_mean, meta_val_loss_mean, meta_val_corr_mean = meta_train_loss.mean(), meta_train_corr.mean(), meta_val_loss.mean(), meta_val_corr.mean()

            # Evaluate K-shot test tasks
            (train_loss[epoch], train_corr[epoch], train_spearman_corr[epoch],
             test_loss[epoch], test_corr[epoch], test_spearman_corr[epoch],
             _, _) = self.unseen_tissue_learn(unseen_train_loader, unseen_test_loader)

            if test_loss[epoch] < best_fewshot_test_loss:
                best_fewshot_test_loss = test_loss[epoch]
            if test_corr[epoch] > best_fewshot_test_corr:
                best_fewshot_test_corr = test_corr[epoch]
            if train_loss[epoch] < best_fewshot_train_loss:
                best_fewshot_train_loss = train_loss[epoch]
                best_train_loss_epoch = epoch
            if train_corr[epoch] > best_fewshot_train_corr:
                best_fewshot_train_corr = train_corr[epoch]
                best_train_corr_epoch = epoch
                best_train_corr_model = self.observed_tissue_model
            if train_spearman_corr[epoch] > best_fewshot_train_spearman_corr:
                best_fewshot_train_spearman_corr = train_spearman_corr[epoch]
                best_train_spearman_corr_epoch = epoch
            if test_spearman_corr[epoch] > best_fewshot_test_spearman_corr:
                best_fewshot_test_spearman_corr = test_spearman_corr[epoch]
                best_test_spearman_epoch = epoch

            print('Few shot', epoch,
                  'train_loss:', float('%.3f' % train_loss[epoch]),
                  'train_pearson:', float('%.3f' % train_corr[epoch]),
                  'train_spearman:', float('%.3f' % train_spearman_corr[epoch]),
                  'test_loss:', float('%.3f' % test_loss[epoch]),
                  'test_pearson:', float('%.3f' % test_corr[epoch]),
                  'test_spearman:', float('%.3f' % test_spearman_corr[epoch]))

        best_train_loss_test_corr = test_corr[best_train_loss_epoch]
        best_train_corr_test_corr = test_corr[best_train_corr_epoch]
        best_train_corr_test_scorr = test_spearman_corr[best_train_corr_epoch]
        best_train_scorr_test_scorr = test_spearman_corr[best_train_spearman_corr_epoch]

        print('--trial summary--',
              'best_train_loss_test_corr:', float('%.3f' % best_train_loss_test_corr),
              'best_train_corr_test_corr', float('%.3f' % best_train_corr_test_corr),
              'best_train_corr_test_scorr', float('%.3f' % best_train_corr_test_scorr),
              'best_train_scorr_test_corr', float('%.3f' % best_train_scorr_test_scorr))

        return (best_train_loss_test_corr, best_train_corr_test_corr,
                best_train_corr_test_scorr, best_train_scorr_test_scorr,
                best_train_corr_model)
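# A minimal driver sketch for the tissue MetaLearner above, assuming meta/few-shot
# dataset objects exposing .feature (N x F array) and .label, and a per-tissue index
# list, all prepared elsewhere. Hyperparameter values are illustrative assumptions.
learner = MetaLearner(meta_dataset, fs_dataset,
                      K=10, meta_lr=1e-3, inner_lr=1e-2,
                      layer=1, hidden=20, tissue_num=len(tissue_index_list),
                      meta_batch_size=10, inner_batch_size=10,
                      num_updates=30, num_inner_updates=10,
                      tissue_index_list=tissue_index_list)
(best_loss_test_corr, best_corr_test_corr,
 best_corr_test_scorr, best_scorr_test_scorr, best_model) = learner.train()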
def init_proc(self):
    global net
    net = InnerLoop(self.num_classes, self.loss_fn, self.num_inner_updates,
                    self.inner_step_size, self.inner_batch_size,
                    self.meta_batch_size, self.num_input_channels)
    net.cuda()
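# init_proc() above installs a module-global InnerLoop in each worker process, which
# suggests use as a pool initializer. A hedged sketch with torch.multiprocessing;
# `learner`, `tasks`, and `evaluate_task` are hypothetical, and evaluate_task is
# assumed to live in the same module as init_proc so the global `net` is visible.
import torch.multiprocessing as mp

def evaluate_task(task):
    # `net` is the process-global InnerLoop created by init_proc
    return net.forward(task)

if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)   # spawn is required for CUDA in workers
    with mp.Pool(processes=4, initializer=learner.init_proc) as pool:
        results = pool.map(evaluate_task, tasks)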
class MetaLearner(object):
    def __init__(self, log, tb_writer, args):
        super(self.__class__, self).__init__()
        self.log = log
        self.tb_writer = tb_writer
        self.args = args

        self.loss_fn = MSELoss()
        self.net = OmniglotNet(self.loss_fn, args).to(device)
        self.fast_net = InnerLoop(self.loss_fn, args).to(device)
        self.opt = Adam(self.net.parameters(), lr=args.meta_lr)

        self.sampler = BatchSampler(args)
        self.memory = ReplayBuffer()

    def meta_update(self, episode_i, ls):
        in_ = episode_i.observations[:, :, 0]
        target = episode_i.rewards[:, :, 0]
        # We use a dummy forward / backward pass to get the correct grads into self.net
        loss, out = forward_pass(self.net, in_, target)
        # Unpack the list of grad dicts
        gradients = {k: sum(d[k] for d in ls) for k in ls[0].keys()}
        # Register a hook on each parameter in the net that replaces the current dummy grad
        # with our grads accumulated across the meta-batch
        hooks = []
        for (k, v) in self.net.named_parameters():
            def get_closure():
                key = k
                def replace_grad(grad):
                    return gradients[key]
                return replace_grad
            hooks.append(v.register_hook(get_closure()))
        # Compute grads for current step, replace with summed gradients as defined by hook
        self.opt.zero_grad()
        loss.backward()
        # Update the net parameters with the accumulated gradient according to optimizer
        self.opt.step()
        # Remove the hooks before next training phase
        for h in hooks:
            h.remove()

    def test(self, i_task, episode_i_):
        predictions_ = []
        for i_agent in range(self.args.n_agent):
            # Make a test net with same parameters as our current net
            test_net = OmniglotNet(self.loss_fn, self.args).to(device)
            test_net.copy_weights(self.net)
            test_opt = SGD(test_net.parameters(), lr=self.args.fast_lr)
            episode_i = self.memory.storage[i_task - 1]
            # Train on the train examples, using the same number of updates as in training
            for i in range(self.args.fast_num_update):
                in_ = episode_i.observations[:, :, i_agent]
                target = episode_i.rewards[:, :, i_agent]
                loss, _ = forward_pass(test_net, in_, target)
                print("loss {} at {}".format(loss, i_task))
                test_opt.zero_grad()
                loss.backward()
                test_opt.step()
            # Evaluate the trained model on train and val examples
            tloss, _ = evaluate(test_net, episode_i, i_agent)
            vloss, prediction_ = evaluate(test_net, episode_i_, i_agent)
            mtr_loss = tloss / 10.
            mval_loss = vloss / 10.
            print('-------------------------')
            print('Meta train:', mtr_loss)
            print('Meta val:', mval_loss)
            print('-------------------------')
            del test_net
            predictions_.append(prediction_)
        visualize(episode_i, episode_i_, predictions_, i_task, self.args)

    def train(self):
        for i_task in range(10000):
            # Sample episode from current task
            self.sampler.reset_task(i_task)
            episodes = self.sampler.sample()
            # Add to memory
            self.memory.add(i_task, episodes)
            # Evaluate on test tasks
            if len(self.memory) > 1:
                self.test(i_task, episodes)
            # Collect a meta batch update
            if len(self.memory) > 2:
                meta_grads = []
                for i in range(self.args.meta_batch_size):
                    if i == 0:
                        episodes_i = self.memory.storage[i_task - 1]
                        episodes_i_ = self.memory.storage[i_task]
                    else:
                        episodes_i, episodes_i_ = self.memory.sample()
                    self.fast_net.copy_weights(self.net)
                    for i_agent in range(self.args.n_agent):
                        meta_grad = self.fast_net.forward(episodes_i, episodes_i_, i_agent)
                        meta_grads.append(meta_grad)
                # Perform the meta update
                self.meta_update(episodes_i, meta_grads)
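# Sketch of the argparse-style namespace the class above reads from, with fields
# inferred from its attribute accesses (meta_lr, fast_lr, fast_num_update, n_agent,
# meta_batch_size). OmniglotNet, InnerLoop, and BatchSampler may require additional
# fields in the real code base; the values below are illustrative assumptions.
from argparse import Namespace

args = Namespace(meta_lr=1e-3, fast_lr=0.1, fast_num_update=5,
                 n_agent=2, meta_batch_size=5)
learner = MetaLearner(log=None, tb_writer=None, args=args)
learner.train()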