def test_update_parameters(model):
    """
    The loss function (with respect to the weights of the model w) is defined
    as
        f(w) = 0.5 * (1 * w_1 + 2 * w_2 + 3 * w_3) ** 2
    with w = [2, 3, 5]. The gradient of f with respect to w, evaluated at
    w = [2, 3, 5], is:
        df / dw_1 = 1 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 23
        df / dw_2 = 2 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 46
        df / dw_3 = 3 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 69
    The updated parameter w' is then given by one step of gradient descent,
    with step size 0.5:
        w'_1 = w_1 - 0.5 * df / dw_1 = 2 - 0.5 * 23 = -9.5
        w'_2 = w_2 - 0.5 * df / dw_2 = 3 - 0.5 * 46 = -20
        w'_3 = w_3 - 0.5 * df / dw_3 = 5 - 0.5 * 69 = -29.5
    """
    train_inputs = torch.tensor([[1., 2., 3.]])
    train_loss = 0.5 * (model(train_inputs) ** 2)
    params = gradient_update_parameters(model,
                                        train_loss,
                                        params=None,
                                        step_size=0.5,
                                        first_order=False)

    assert train_loss.item() == 264.5
    assert list(params.keys()) == ['weight']
    assert torch.all(params['weight'].data
                     == torch.tensor([[-9.5, -20., -29.5]]))
    """
    The new loss function (still with respect to the weights of the model w)
    is defined as:
        g(w) = 0.5 * (4 * w'_1 + 5 * w'_2 + 6 * w'_3) ** 2
             = 0.5 * (4 * (w_1 - 0.5 * df / dw_1)
                      + 5 * (w_2 - 0.5 * df / dw_2)
                      + 6 * (w_3 - 0.5 * df / dw_3)) ** 2
             = 0.5 * (4 * (w_1 - 0.5 * 1 * (1 * w_1 + 2 * w_2 + 3 * w_3))
                      + 5 * (w_2 - 0.5 * 2 * (1 * w_1 + 2 * w_2 + 3 * w_3))
                      + 6 * (w_3 - 0.5 * 3 * (1 * w_1 + 2 * w_2 + 3 * w_3))) ** 2
             = 0.5 * ((4 - 4 * 0.5 - 5 * 1.0 - 6 * 1.5) * w_1
                      + (5 - 4 * 1.0 - 5 * 2.0 - 6 * 3.0) * w_2
                      + (6 - 4 * 1.5 - 5 * 3.0 - 6 * 4.5) * w_3) ** 2
             = 0.5 * (-12 * w_1 - 27 * w_2 - 42 * w_3) ** 2
    Therefore the gradient of g with respect to w (evaluated at
    w = [2, 3, 5]) is:
        dg / dw_1 = -12 * (-12 * w_1 - 27 * w_2 - 42 * w_3) = 3780
        dg / dw_2 = -27 * (-12 * w_1 - 27 * w_2 - 42 * w_3) = 8505
        dg / dw_3 = -42 * (-12 * w_1 - 27 * w_2 - 42 * w_3) = 13230
    """
    test_inputs = torch.tensor([[4., 5., 6.]])
    test_loss = 0.5 * (model(test_inputs, params=params) ** 2)
    grads = torch.autograd.grad(test_loss, model.parameters())

    assert test_loss.item() == 49612.5
    assert len(grads) == 1
    assert torch.all(grads[0].data == torch.tensor([[3780., 8505., 13230.]]))
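# The `model` fixture used by the tests in this file is not shown here. A
# minimal sketch of what the assertions assume -- a Torchmeta MetaModule with
# a single bias-free parameter named 'weight' initialized to [[2., 3., 5.]].
# The class and fixture names below are illustrative, not the original code:
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torchmeta.modules import MetaModule

class LinearModel(MetaModule):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.weight = nn.Parameter(torch.tensor([[2., 3., 5.]]))

    def forward(self, inputs, params=None):
        # A MetaModule forward accepts an optional OrderedDict of parameters,
        # falling back to the module's own parameters when none is given
        if params is None:
            params = OrderedDict(self.named_parameters())
        return F.linear(inputs, params['weight'], bias=None)

@pytest.fixture
def model():
    return LinearModel()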
def update_policy(self, trajectorys):
    # self.optimizer_p.zero_grad()
    self.policy_net.zero_grad()
    states = trajectorys[self.id].get_state()
    actions = trajectorys[self.id].get_action()
    rewards_env = trajectorys[self.id].get_reward_env()
    rewards_given = trajectorys[self.id].get_reward_from()

    loss_policy = []
    loss_critic = []
    loss_entropy = []
    R = 0
    returns_env = []
    returns_given = []

    # Compute policy loss
    # Get the V value of each timestep from the critic
    logits, V_s = self.policy_net(states)
    prob = F.softmax(logits, dim=-1)
    log_prob = F.log_softmax(logits, dim=-1)
    V_s = V_s.view(-1)

    # Discounted returns of the environment rewards
    for r in rewards_env[::-1]:
        R = r + gamma * R
        returns_env.insert(0, R)
    # Reset the accumulator before computing the returns of the given
    # rewards, so the two reward streams do not leak into each other
    R = 0
    for r in rewards_given[::-1]:
        R = r + gamma * R
        returns_given.insert(0, R)

    returns_env = torch.Tensor(returns_env).detach()
    returns_given = torch.cat(returns_given, dim=0)
    returns = returns_env + returns_given
    # returns = returns_env
    Q_s_a = returns
    A_s_a = Q_s_a - V_s

    # Compute the policy loss
    loss_entropy_p = -log_prob * prob
    loss_entropy_p = loss_entropy_p.mean()
    loss_entropy.append(loss_entropy_p)
    log_prob_act = torch.stack(
        [log_prob[i][actions[i]] for i in range(len(actions))], dim=0)
    loss_policy_p = -torch.dot(A_s_a, log_prob_act).view(1) / len(prob)
    loss_policy.append(loss_policy_p)

    # Compute the critic loss
    # loss_critic_p = (returns - V_s).pow(2).mean()
    loss_critic_p = A_s_a.pow(2).mean()
    loss_critic.append(loss_critic_p)

    loss_policy = torch.stack(loss_policy).mean()
    loss_critic = torch.stack(loss_critic).mean()
    loss_entropy = torch.stack(loss_entropy).mean()
    loss = loss_policy + 0.5 * loss_critic + 0.01 * loss_entropy
    # loss.backward(retain_graph=True)
    # self.optimizer_p.step()
    self.new_params = gradient_update_parameters(self.policy_net,
                                                 loss,
                                                 step_size=step_size)
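# `update_to_new_params` is called by the test below but its body is not part
# of this snippet. A minimal sketch of what such a method presumably does --
# copying the adapted parameters returned by `gradient_update_parameters`
# back into the live network (an assumption, not the verified implementation):
def update_to_new_params(self):
    with torch.no_grad():
        for name, param in self.policy_net.named_parameters():
            # self.new_params is the OrderedDict set in update_policy above
            param.copy_(self.new_params[name])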
def test_policy_update(config):
    agents = []
    for i in range(config.env.n_agents):
        agents.append(Actor(i, 7, config.env.n_agents))
    input = torch.Tensor([1, 1, 1, 1, 1, 1, 1])
    agent0 = agents[0]
    agent1 = agents[1]

    output0 = agent0.policy_net(input)
    output1 = agent1.policy_net(input)
    loss0 = 1 - output0.sum()
    loss1 = 2 - output1.sum()
    print(loss0)
    print(loss1)

    agent0.new_params = gradient_update_parameters(agent0.policy_net, loss0,
                                                   step_size=0.5)
    agent1.new_params = gradient_update_parameters(agent1.policy_net, loss1,
                                                   step_size=0.5)
    output0 = agent0.policy_net(input, agent0.new_params)
    output1 = agent1.policy_net(input, agent1.new_params)
    loss0 = 1 - output0.sum()
    loss1 = 2 - output1.sum()
    print(loss0)
    print(loss1)

    agent0.update_to_new_params()
    agent1.update_to_new_params()
    output0 = agent0.policy_net(input)
    output1 = agent1.policy_net(input)
    loss0 = 1 - output0.sum()
    loss1 = 2 - output1.sum()
    print(loss0)
    print(loss1)
def test_update_parameters_first_order(model):
    """
    The loss function (with respect to the weights of the model w) is defined
    as
        f(w) = 0.5 * (4 * w_1 + 5 * w_2 + 6 * w_3) ** 2
    with w = [2, 3, 5]. The gradient of f with respect to w, evaluated at
    w = [2, 3, 5], is:
        df / dw_1 = 4 * (4 * w_1 + 5 * w_2 + 6 * w_3) = 212
        df / dw_2 = 5 * (4 * w_1 + 5 * w_2 + 6 * w_3) = 265
        df / dw_3 = 6 * (4 * w_1 + 5 * w_2 + 6 * w_3) = 318
    The updated parameter w' is then given by one step of gradient descent,
    with step size 0.5:
        w'_1 = w_1 - 0.5 * df / dw_1 = 2 - 0.5 * 212 = -104
        w'_2 = w_2 - 0.5 * df / dw_2 = 3 - 0.5 * 265 = -129.5
        w'_3 = w_3 - 0.5 * df / dw_3 = 5 - 0.5 * 318 = -154
    """
    train_inputs = torch.tensor([[4., 5., 6.]])
    train_loss = 0.5 * (model(train_inputs) ** 2)
    params = gradient_update_parameters(model,
                                        train_loss,
                                        params=None,
                                        step_size=0.5,
                                        first_order=True)

    assert train_loss.item() == 1404.5
    assert list(params.keys()) == ['weight']
    assert torch.all(params['weight'].data
                     == torch.tensor([[-104., -129.5, -154.]]))
    """
    The new loss function (still with respect to the weights of the model w)
    is defined as:
        g(w) = 0.5 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) ** 2
    Since w' was computed with the first-order approximation, the inner
    gradient df / dw is treated as a constant, so dw'_i / dw_j reduces to the
    identity and the gradient of g with respect to w, evaluated at
    w = [2, 3, 5], is:
        dg / dw_1 = 1 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -825
        dg / dw_2 = 2 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -1650
        dg / dw_3 = 3 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -2475
    """
    test_inputs = torch.tensor([[1., 2., 3.]])
    test_loss = 0.5 * (model(test_inputs, params=params) ** 2)
    grads = torch.autograd.grad(test_loss, model.parameters())

    assert test_loss.item() == 340312.5
    assert len(grads) == 1
    assert torch.all(grads[0].data == torch.tensor([[-825., -1650., -2475.]]))
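# A standalone check of the first-order arithmetic above with plain PyTorch,
# no MetaModule needed. Because the inner gradient is taken without
# `create_graph=True`, it is a constant with respect to w, exactly as in the
# first-order update. Variable names are illustrative:
w = torch.tensor([2., 3., 5.], requires_grad=True)
x_train = torch.tensor([4., 5., 6.])
inner_loss = 0.5 * torch.dot(w, x_train) ** 2       # 1404.5
grad, = torch.autograd.grad(inner_loss, w)          # [212., 265., 318.]
w_prime = w - 0.5 * grad                            # [-104., -129.5, -154.]
x_test = torch.tensor([1., 2., 3.])
outer_loss = 0.5 * torch.dot(w_prime, x_test) ** 2  # 340312.5
meta_grad, = torch.autograd.grad(outer_loss, w)     # [-825., -1650., -2475.]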
def get_adapted_params(model, test_batch):
    test_inputs, test_targets = test_batch['train']
    test_in, test_target = test_inputs[0], test_targets[0]
    test_out = model(test_in)
    inner_loss = get_loss(test_out, test_target)
    model.zero_grad()
    # The update runs under `torch.no_grad()`, so the adapted parameters do
    # not track the meta-graph; this helper is meant for evaluation only
    with torch.no_grad():
        params = gradient_update_parameters(model,
                                            inner_loss,
                                            step_size=step_size,
                                            first_order=first_order)
    test_out = model(test_in, params=params)
    outer_loss = get_loss(test_out, test_target)
    return params, outer_loss.item()
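# A quick illustration of the point above (hypothetical call): parameters
# adapted inside `torch.no_grad()` come back detached, so an outer
# `backward()` through them would carry no meta-gradient.
params, adapted_loss = get_adapted_params(model, test_batch)
assert not any(p.requires_grad for p in params.values())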
def test_dataparallel_params_maml(model):
    device = torch.device('cuda:0')
    model = DataParallel(model)
    model.to(device=device)

    train_inputs = torch.rand(5, 2).to(device=device)
    train_outputs = model(train_inputs)
    inner_loss = train_outputs.sum()  # Dummy loss
    params = gradient_update_parameters(model, inner_loss)

    test_inputs = torch.rand(5, 2).to(device=device)
    test_outputs = model(test_inputs, params=params)
    assert test_outputs.shape == (5, 1)
    assert test_outputs.device == device

    outer_loss = test_outputs.sum()  # Dummy loss
    outer_loss.backward()
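# This test hard-codes 'cuda:0'; in a pytest suite it would typically be
# guarded so CPU-only machines skip it. A sketch of such a guard (the marker
# name is illustrative):
requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(),
                                   reason='CUDA is required for DataParallel')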
def meta_train(args, metaDataloader):
    model = RegressionNeuralNetwork(args['in_channels'],
                                    hidden1_size=args['hidden1_size'],
                                    hidden2_size=args['hidden2_size'])
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=args['beta'])
    loss_record = []

    # Training loop
    for it_outer in range(args['num_it_outer']):
        model.zero_grad()
        train_dataloader = metaDataloader['train']
        test_dataloader = metaDataloader['test']
        outer_loss = torch.tensor(0., dtype=torch.float)

        for task in train_dataloader:
            iterator = iter(train_dataloader[task])
            train_sample = next(iterator)
            # Get the true h value
            # h_value = torch.tensor(train_sample[:,-1], dtype=torch.float)
            h_value = train_sample[:, -1].clone().detach().to(
                dtype=torch.float)
            # Get the input
            # input_value = torch.tensor(train_sample[:,:-1], dtype=torch.float)
            input_value = train_sample[:, :-1].clone().detach().to(
                dtype=torch.float)

            train_h_value = model(input_value)
            inner_loss = F.mse_loss(train_h_value.view(-1, 1),
                                    h_value.view(-1, 1))
            model.zero_grad()
            # print('It {}, task {}, Start updating parameters'.format(it_outer, task))
            params = gradient_update_parameters(
                model,
                inner_loss,
                step_size=args['alpha'],
                first_order=args['first_order'])

            # Adaptation: get a test sample
            test_iterator = iter(test_dataloader[task])
            test_sample = next(test_iterator)
            # h_value2 = torch.tensor(test_sample[:,-1], dtype=torch.float)
            h_value2 = test_sample[:, -1].clone().detach().to(
                dtype=torch.float)
            # test_input_value = torch.tensor(test_sample[:,:-1], dtype=torch.float)
            test_input_value = test_sample[:, :-1].clone().detach().to(
                dtype=torch.float)
            test_h_value = model(test_input_value, params=params)
            outer_loss += F.mse_loss(test_h_value.view(-1, 1),
                                     h_value2.view(-1, 1))

        outer_loss.div_(args['num_tasks'])
        outer_loss.backward()
        meta_optimizer.step()
        loss_record.append(outer_loss.detach())

        if it_outer % 50 == 0:
            print('It {}, outer training loss: {}'.format(it_outer,
                                                          outer_loss))

    # Plot the loss curve
    plt.plot(loss_record)
    plt.title('Outer Training Loss (MSE Loss) in MAML')
    plt.xlabel('Iteration number')
    plt.show()

    # Save model
    if args['output_model'] is not None:
        with open(args['output_model'], 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
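# An illustrative call; the hyper-parameter values below are assumptions, and
# `metaDataloader` is expected to map 'train'/'test' to dicts of per-task
# DataLoaders, as the loop above iterates them:
args = {'in_channels': 4, 'hidden1_size': 40, 'hidden2_size': 40,
        'beta': 1e-3, 'alpha': 0.01, 'first_order': False,
        'num_it_outer': 500, 'num_tasks': 10, 'output_model': 'maml_reg.th'}
meta_train(args, metaDataloader)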
def train(args):
    logger.warning('This script is an example to showcase the MetaModule and '
                   'data-loading features of Torchmeta, and as such has been '
                   'very lightly tested. For a better tested implementation of '
                   'Model-Agnostic Meta-Learning (MAML) using Torchmeta with '
                   'more features (including multi-step adaptation and '
                   'different datasets), please check `https://github.com/'
                   'tristandeleu/pytorch-maml`.')

    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets,
                                   test_inputs, test_targets)):
                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(
                    model,
                    inner_loss,
                    step_size=args.step_size,
                    first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder,
                                'maml_omniglot_{0}shot_{1}way.th'.format(
                                    args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
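# Loading the checkpoint saved above for later evaluation -- a sketch that
# mirrors the model construction and save path used in `train`:
model = ConvolutionalNeuralNetwork(1, args.num_ways,
                                   hidden_size=args.hidden_size)
with open(filename, 'rb') as f:
    model.load_state_dict(torch.load(f))
model.eval()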
def train():
    transform = transforms.Compose([transforms.Resize(84),
                                    transforms.ToTensor()])
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=5,
                                      num_test_per_class=5)
    dataset = MiniImagenet('',
                           transform=transform,
                           num_classes_per_task=5,
                           target_transform=Categorical(num_classes=5),
                           meta_split="train",
                           dataset_transform=dataset_transform)
    dataloader = BatchMetaDataLoader(dataset, batch_size=1, shuffle=True)

    model = ModelConvMiniImagenet(5)
    model.to(device='cuda')
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    accuracy_l = list()
    with tqdm(dataloader, total=1000) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device='cuda')
            train_targets = train_targets.to(device='cuda')

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device='cuda')
            test_targets = test_targets.to(device='cuda')

            outer_loss = torch.tensor(0., device='cuda')
            accuracy = torch.tensor(0., device='cuda')
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets,
                                   test_inputs, test_targets)):
                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(model, inner_loss)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            # batch_size is 1, so these divisions are no-ops; they are kept
            # for symmetry with the batched version of this loop
            outer_loss.div_(1)
            accuracy.div_(1)
            outer_loss.backward()
            meta_optimizer.step()

            accuracy_l.append(accuracy.item())
            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= 1000:
                break

    plt.plot(accuracy_l)
    plt.show()
def test_multiple_update_parameters(model):
    """
    The loss function (with respect to the weights of the model w) is defined
    as
        f(w) = 0.5 * (1 * w_1 + 2 * w_2 + 3 * w_3) ** 2
    with w = [2, 3, 5]. The gradient of f with respect to w is:
        df / dw_1 = 1 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 23
        df / dw_2 = 2 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 46
        df / dw_3 = 3 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 69
    The updated parameters are given by:
        w'_1 = w_1 - 1. * df / dw_1 = 2 - 1. * 23 = -21
        w'_2 = w_2 - 1. * df / dw_2 = 3 - 1. * 46 = -43
        w'_3 = w_3 - 1. * df / dw_3 = 5 - 1. * 69 = -64
    """
    train_inputs = torch.tensor([[1., 2., 3.]])

    train_loss_1 = 0.5 * (model(train_inputs) ** 2)
    params_1 = gradient_update_parameters(model,
                                          train_loss_1,
                                          params=None,
                                          step_size=1.,
                                          first_order=False)

    assert train_loss_1.item() == 264.5
    assert list(params_1.keys()) == ['weight']
    assert torch.all(params_1['weight'].data
                     == torch.tensor([[-21., -43., -64.]]))
    """
    The new loss function is defined as
        g(w') = 0.5 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) ** 2
    with w' = [-21, -43, -64]. The gradient of g with respect to w' is:
        dg / dw'_1 = 1 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -299
        dg / dw'_2 = 2 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -598
        dg / dw'_3 = 3 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -897
    The updated parameters are given by:
        w''_1 = w'_1 - 1. * dg / dw'_1 = -21 - 1. * -299 = 278
        w''_2 = w'_2 - 1. * dg / dw'_2 = -43 - 1. * -598 = 555
        w''_3 = w'_3 - 1. * dg / dw'_3 = -64 - 1. * -897 = 833
    """
    train_loss_2 = 0.5 * (model(train_inputs, params=params_1) ** 2)
    params_2 = gradient_update_parameters(model,
                                          train_loss_2,
                                          params=params_1,
                                          step_size=1.,
                                          first_order=False)

    assert train_loss_2.item() == 44700.5
    assert list(params_2.keys()) == ['weight']
    assert torch.all(params_2['weight'].data
                     == torch.tensor([[278., 555., 833.]]))
    """
    The new loss function is defined as
        h(w'') = 0.5 * (1 * w''_1 + 2 * w''_2 + 3 * w''_3) ** 2
    with w'' = [278, 555, 833]. The gradient of h with respect to w'' is:
        dh / dw''_1 = 1 * (1 * w''_1 + 2 * w''_2 + 3 * w''_3) = 3887
        dh / dw''_2 = 2 * (1 * w''_1 + 2 * w''_2 + 3 * w''_3) = 7774
        dh / dw''_3 = 3 * (1 * w''_1 + 2 * w''_2 + 3 * w''_3) = 11661
    The updated parameters are given by:
        w'''_1 = w''_1 - 1. * dh / dw''_1 = 278 - 1. * 3887 = -3609
        w'''_2 = w''_2 - 1. * dh / dw''_2 = 555 - 1. * 7774 = -7219
        w'''_3 = w''_3 - 1. * dh / dw''_3 = 833 - 1. * 11661 = -10828
    """
    train_loss_3 = 0.5 * (model(train_inputs, params=params_2) ** 2)
    params_3 = gradient_update_parameters(model,
                                          train_loss_3,
                                          params=params_2,
                                          step_size=1.,
                                          first_order=False)

    assert train_loss_3.item() == 7554384.5
    assert list(params_3.keys()) == ['weight']
    assert torch.all(params_3['weight'].data
                     == torch.tensor([[-3609., -7219., -10828.]]))
    """
    The new loss function is defined as
        l(w) = 4 * w'''_1 + 5 * w'''_2 + 6 * w'''_3
    with w = [2, 3, 5] and w''' = [-3609, -7219, -10828]. The gradient of l
    with respect to w is:
        dl / dw_1 = 4 * dw'''_1 / dw_1 + 5 * dw'''_2 / dw_1
                    + 6 * dw'''_3 / dw_1 = ... = -5020
        dl / dw_2 = 4 * dw'''_1 / dw_2 + 5 * dw'''_2 / dw_2
                    + 6 * dw'''_3 / dw_2 = ... = -10043
        dl / dw_3 = 4 * dw'''_1 / dw_3 + 5 * dw'''_2 / dw_3
                    + 6 * dw'''_3 / dw_3 = ... = -15066
    """
    test_inputs = torch.tensor([[4., 5., 6.]])
    test_loss = model(test_inputs, params=params_3)
    grads = torch.autograd.grad(test_loss, model.parameters())

    assert test_loss.item() == -115499.
    assert len(grads) == 1
    assert torch.all(grads[0].data
                     == torch.tensor([[-5020., -10043., -15066.]]))
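# For reference, a minimal sketch of the single-step update these tests
# exercise: `gradient_update_parameters` essentially performs one SGD step
# via `torch.autograd.grad`, keeping the graph (`create_graph=True`) unless
# `first_order` is set. Keeping the graph is what lets the chained calls
# above backpropagate through all three inner steps. The helper name is
# illustrative and this omits extras of the library function (e.g.
# per-parameter step sizes):
from collections import OrderedDict
import torch

def manual_gradient_update(model, loss, params=None, step_size=0.5,
                           first_order=False):
    if params is None:
        params = OrderedDict(model.meta_named_parameters())
    grads = torch.autograd.grad(loss, params.values(),
                                create_graph=not first_order)
    return OrderedDict((name, param - step_size * grad)
                       for (name, param), grad in zip(params.items(), grads))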
def train(inputs=[],
          adapt_inputs={},
          exp_name="maml",
          batch_size=16,
          num_workers=1,
          use_cuda=False,
          num_batches=100,
          step_size=0.4,
          shots=1000,
          test_shots=200,
          save_per=-1,
          eval_per=1,
          learning_rate=0.01,
          first_order=False,
          save_dir=".",
          date="210101",
          seed=None,
          logger_kwargs={},
          exp_params=DotMap()):
    device = torch.device(
        'cuda' if use_cuda and torch.cuda.is_available() else 'cpu')

    def get_loss(output, targets):
        return -1 * output.log_prob(targets).sum(dim=1).mean()

    def get_adapted_params(model, test_batch):
        test_inputs, test_targets = test_batch['train']
        test_in, test_target = test_inputs[0], test_targets[0]
        test_out = model(test_in)
        inner_loss = get_loss(test_out, test_target)
        model.zero_grad()
        with torch.no_grad():
            params = gradient_update_parameters(model,
                                                inner_loss,
                                                step_size=step_size,
                                                first_order=first_order)
        test_out = model(test_in, params=params)
        outer_loss = get_loss(test_out, test_target)
        return params, outer_loss.item()

    from torch.utils.tensorboard import SummaryWriter
    import datetime

    env_name = "BedBathingBaxterHuman-v0217_0-v1"
    env = gym.make('assistive_gym:' + env_name)

    dataset = behaviour(inputs, shots=shots, test_shots=test_shots)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=num_workers)
    adapt_datasets, adapt_loaders = {}, {}
    for key, adapt_dir in adapt_inputs.items():
        adapt_datasets[key] = behaviour([adapt_dir], shots=shots,
                                        test_shots=test_shots)
        adapt_loaders[key] = BatchMetaDataLoader(adapt_datasets[key],
                                                 batch_size=1,
                                                 shuffle=True,
                                                 num_workers=num_workers)

    model = PolicyNetwork(env.observation_space_human.shape[0],
                          env.action_space_human.shape[0])
    model.to(device=device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    now = datetime.datetime.now()
    hour = '{:02d}'.format(now.hour)
    minute = '{:02d}'.format(now.minute)
    second = '{:02d}'.format(now.second)
    timestamp = '{}-{}-{}'.format(hour, minute, second)

    # new_data/date/MAML_assistive_gym_sz0-1_lr0-01_s50_ts200/MAML_assistive_gym_sz0-1_lr0-01_s50_ts200_s0
    output_dir = logger_kwargs['output_dir']
    log_folder = os.path.join(save_dir, date, "runs",
                              f"{exp_name}_{timestamp}")
    print(f"Saving logs to {log_folder}")
    os.makedirs(log_folder, exist_ok=True)
    writer = SummaryWriter(log_dir=log_folder, comment=f"{exp_name}")
    rllib_saver = RLLibSaver()
    adapt_losses = {key: [] for key in adapt_loaders.keys()}

    # Training loop
    with tqdm(dataloader, total=num_batches, disable=True) as pbar:
        for batch_idx, batches in enumerate(
                zip(pbar, *list(adapt_loaders.values()))):
            model.zero_grad()
            main_batch = batches[0]

            train_inputs, train_targets = main_batch['train']
            train_inputs = train_inputs.to(device=device).float()
            train_targets = train_targets.to(device=device).float()

            test_inputs, test_targets = main_batch['test']
            test_inputs = test_inputs.to(device=device).float()
            test_targets = test_targets.to(device=device).float()

            outer_loss = torch.tensor(0., device=device)
            loss = torch.tensor(0., device=device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets,
                                   test_inputs, test_targets)):
                train_output = model(train_input)
                inner_loss = get_loss(train_output, train_target)

                model.zero_grad()
                params = gradient_update_parameters(model,
                                                    inner_loss,
                                                    step_size=step_size,
                                                    first_order=first_order)

                test_output = model(test_input, params=params)
                outer_loss += get_loss(test_output, test_target)
                with torch.no_grad():
                    loss += get_loss(test_output, test_target)

            outer_loss.div_(batch_size)
            loss.div_(batch_size)
            outer_loss.backward()
            meta_optimizer.step()

            # Report progress
            pbar.set_postfix(loss='{0:.4f}'.format(loss.item()))
            print(f"Iter {batch_idx} train loss: {loss.item():.3f}")
            writer.add_scalar(f"train/maml_loss", loss.item(), batch_idx)
            writer.flush()

            # Eval & save model
            do_eval = batch_idx % eval_per == 0
            do_save = output_dir is not None and save_per > 0 and (
                (batch_idx % save_per == 0) or (batch_idx == num_batches))

            if do_eval:
                all_pre_params = []
                all_post_params = []
                all_inputs = []
                for adapt_key, adapt_batch in zip(list(adapt_loaders.keys()),
                                                  batches[1:]):
                    pre_params = OrderedDict(model.meta_named_parameters())
                    post_params, outer_loss = get_adapted_params(
                        model, adapt_batch)
                    all_pre_params.append(pre_params)
                    all_post_params.append(post_params)
                    all_inputs.append(adapt_batch['train'][0][0])  # inputs, idx=1
                    print(f"Save inner loss {adapt_key}: {outer_loss:.04f}")
                    if do_save:
                        rllib_saver.save(params=post_params,
                                         save_path=output_dir,
                                         key=adapt_key,
                                         iteration=batch_idx,
                                         exp_params=exp_params)
                    adapt_losses[adapt_key].append(outer_loss)

                adapt_keys = list(adapt_loaders.keys())
                if len(adapt_keys) > 1:
                    _, _, fig = cluster_activation(model, all_inputs,
                                                   all_pre_params,
                                                   adapt_keys, "fc3")
                    writer.add_figure(f"train/fc3_before", fig, batch_idx)
                    writer.flush()
                    _, _, fig = cluster_activation(model, all_inputs,
                                                   all_post_params,
                                                   adapt_keys, "fc3")
                    writer.add_figure(f"train/fc3_after", fig, batch_idx)
                    writer.flush()

                with open(os.path.join(output_dir, "adapt_losses.txt"),
                          "w+") as f:
                    yaml.dump(adapt_losses, f)

            if batch_idx >= num_batches:
                break
def train(ml_custom, policies, args):
    # Prepare to log info
    writer = SummaryWriter()

    # Define model
    model = MIL()
    model.to(device=args.device)
    # load_model(model, "./models/mil_499.th")
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    print("Start training")
    for batch in range(args.num_batches):
        outer_loss = torch.tensor(0., device=args.device)
        accuracy = torch.tensor(0., device=args.device)
        for name in np.random.choice(list(ml_custom.keys()), 3,
                                     replace=False):
            env_cls = ml_custom[name]
            print("Task: %s" % name)
            policy = policies[name]
            all_tasks = [task for task in ml45.train_tasks
                         if task.env_name == name]

            # Adapt in support task
            env = env_cls()
            support_task = random.choice(all_tasks[:25])
            env.set_task(support_task)
            batches_imgs, batches_configs, batches_actions = get_data(
                env, policy, args)

            inner_loss = torch.tensor(0., device=args.device)
            number_batches = len(batches_imgs)
            while len(batches_imgs) > 0:
                pred_actions = model(
                    batches_imgs.pop().to(device=args.device),
                    batches_configs.pop().to(device=args.device))
                inner_loss += F.mse_loss(
                    pred_actions,
                    batches_actions.pop().to(device=args.device))
            inner_loss.div_(number_batches)

            model.zero_grad()
            params = gradient_update_parameters(model,
                                                inner_loss,
                                                step_size=args.step_size,
                                                first_order=args.first_order)

            # Evaluate in query task
            env = env_cls()
            query_task = random.choice(all_tasks[25:])
            env.set_task(query_task)
            batches_imgs, batches_configs, batches_actions = get_data(
                env, policy, args)

            aux_loss = torch.tensor(0., device=args.device)
            aux_accuracy = torch.tensor(0., device=args.device)
            number_batches = len(batches_imgs)
            while len(batches_imgs) > 0:
                pred_actions = model(
                    batches_imgs.pop().to(device=args.device),
                    batches_configs.pop().to(device=args.device))
                batch_actions = batches_actions.pop().to(device=args.device)
                aux_loss += F.mse_loss(pred_actions, batch_actions)
                with torch.no_grad():
                    aux_accuracy += get_accuracy(pred_actions, batch_actions)
            aux_loss.div_(number_batches)
            aux_accuracy.div_(number_batches)

            outer_loss += aux_loss
            accuracy += aux_accuracy

        outer_loss.div_(3)
        accuracy.div_(3)

        meta_optimizer.zero_grad()
        outer_loss.backward()
        meta_optimizer.step()

        # Log info
        writer.add_scalar('meta_train/loss', outer_loss.item(), batch)
        writer.add_scalar('meta_train/accuracy', accuracy.item(), batch)
        print("batch: %d loss: %.4f accuracy: %.4f" %
              (batch, outer_loss.item(), accuracy.item()))

        # Save model
        save_model(model, args.output_folder, 'mil_%d.th' % batch)
def train(args):
    dataset = clinic(shots=args.num_shots,
                     ways=args.num_ways,
                     shuffle=True,
                     test_shots=15,
                     meta_train=True,
                     download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets,
                                   test_inputs, test_targets)):
                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(
                    model,
                    inner_loss,
                    step_size=args.step_size,
                    first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder,
                                'maml_clinic_{0}shot_{1}way.th'.format(
                                    args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
support_inputs, support_targets = [
    _.cuda(non_blocking=True) for _ in train_batch['train']
] if args.use_cuda else [_ for _ in train_batch['train']]
query_inputs, query_targets = [
    _.cuda(non_blocking=True) for _ in train_batch['test']
] if args.use_cuda else [_ for _ in train_batch['test']]

train_loss = torch.tensor(0., device=support_inputs.device)
train_acc = torch.tensor(0., device=support_inputs.device)
for _, (support_input, support_target, query_input,
        query_target) in enumerate(zip(support_inputs, support_targets,
                                       query_inputs, query_targets)):
    # Meta inner loop
    support_logit = model(support_input)
    train_inner_loss = F.cross_entropy(support_logit, support_target)
    model.zero_grad()
    params = gradient_update_parameters(model,
                                        train_inner_loss,
                                        step_size=args.step_size,
                                        first_order=args.first_order)

    # Meta outer loop: distill from the teacher on the last batch of tasks,
    # otherwise use the plain cross-entropy on the query set
    if train_batch_i == int(args.train_tasks / args.batch_tasks) - 1:
        teacher_model.eval()
        teacher_query_logit = teacher_model(query_input)
        query_logit = model(query_input, params=params)
        train_loss += get_loss(args, query_logit, query_target,
                               teacher_query_logit)
    else:
        query_logit = model(query_input, params=params)
        train_loss += F.cross_entropy(query_logit, query_target)
    with torch.no_grad():
        train_acc += count_acc(query_logit, query_target)
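# `get_loss` is not shown in this snippet. A common form for its signature
# (args, student logits, targets, teacher logits) is a distillation loss
# mixing cross-entropy with a temperature-softened KL term; this is a sketch,
# and the hyper-parameter names `args.kd_T` and `args.kd_alpha` are
# assumptions, not the original code:
def get_loss(args, logit, target, teacher_logit):
    ce = F.cross_entropy(logit, target)
    T = args.kd_T
    # The T*T factor keeps gradient magnitudes comparable across temperatures
    kd = F.kl_div(F.log_softmax(logit / T, dim=-1),
                  F.softmax(teacher_logit / T, dim=-1),
                  reduction='batchmean') * (T * T)
    return (1. - args.kd_alpha) * ce + args.kd_alpha * kd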
def train(args):
    perturb_mock, sgRNA_list_mock = makedata.json_to_perturb_data(
        path="/home/member/xywang/WORKSPACE/MaryGUO/one-shot/MOCK_MON_crispr_combine/crispr_analysis")
    total = sc.read_h5ad(
        "/home/member/xywang/WORKSPACE/MaryGUO/one-shot/mock_one_perturbed.h5ad")
    trainset, testset = preprocessing.make_total_data(total, sgRNA_list_mock)

    TrainSet = perturbdataloader(trainset,
                                 ways=args.num_ways,
                                 support_shots=args.num_shots,
                                 query_shots=15)
    TrainLoader = DataLoader(TrainSet,
                             batch_size=args.batch_size_train,
                             shuffle=False,
                             num_workers=args.num_workers)

    model = MLP(out_features=args.num_ways)
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(TrainLoader, total=args.num_batches) as pbar:
        for batch_idx, (inputs_support, inputs_query, target_support,
                        target_query) in enumerate(pbar):
            model.zero_grad()

            inputs_support = inputs_support.to(device=args.device)
            target_support = target_support.to(device=args.device)
            inputs_query = inputs_query.to(device=args.device)
            target_query = target_query.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(inputs_support, target_support,
                                   inputs_query, target_query)):
                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(
                    model,
                    inner_loss,
                    step_size=args.step_size,
                    first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size_train)
            accuracy.div_(args.batch_size_train)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches or accuracy.item() > 0.95:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder,
                                'maml_perturb_{0}shot_{1}way.th'.format(
                                    args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)

    # Start test
    test_support, test_query, test_target_support, test_target_query = \
        helpfuntions.sample_once(testset,
                                 support_shot=args.num_shots,
                                 shuffle=False,
                                 plus=len(trainset))
    test_query = torch.from_numpy(test_query).to(device=args.device)
    test_target_query = torch.from_numpy(test_target_query).to(
        device=args.device)

    TrainSet = perturbdataloader_test(test_support, test_target_support)
    TrainLoader = DataLoader(TrainSet, args.batch_size_test)
    meta_optimizer.zero_grad()

    inner_losses = []
    accuracy_test = []
    for epoch in range(args.num_epoch):
        model.to(device=args.device)
        model.train()
        for _, (inputs_support, target_support) in enumerate(TrainLoader):
            inputs_support = inputs_support.to(device=args.device)
            target_support = target_support.to(device=args.device)
            train_logit = model(inputs_support)
            loss = F.cross_entropy(train_logit, target_support)
            # Record a detached value so the graph is not kept alive
            inner_losses.append(loss.item())
            loss.backward()
            meta_optimizer.step()
            meta_optimizer.zero_grad()

        test_logit = model(test_query)
        with torch.no_grad():
            accuracy = get_accuracy(test_logit, test_target_query)
        accuracy_test.append(accuracy)
        if (epoch + 1) % 3 == 0:
            print('Epoch [{}/{}], Loss: {:.4f}, accuracy: {:.4f}'.format(
                epoch + 1, args.num_epoch, loss, accuracy))