# Step 3: Sample the joint distribution after an intervention: draw a new
# marginal P(A) while the conditional P(B | A) stays fixed
pi_A_2 = np.random.dirichlet(np.ones(N))

start = time.time()

# Step 4: Do k steps of gradient descent for adaptation on the
# distribution after intervention
model.zero_grad()
loss = torch.tensor(0., dtype=torch.float64)
for _ in range(num_gradient_steps):
    x_train = torch.from_numpy(
        generate_data_categorical(transfer_batch_size, pi_A_2, pi_B_A))
    # Accumulate the online negative log-likelihood of the mixture; this is
    # the meta-objective used to update the structural parameter below
    loss += -torch.mean(model(x_train))
    # Adapt the parameters of both candidate models on the transfer data
    optimizer.zero_grad()
    inner_loss_A_B = -torch.mean(model.model_A_B(x_train))
    inner_loss_B_A = -torch.mean(model.model_B_A(x_train))
    inner_loss = inner_loss_A_B + inner_loss_B_A
    inner_loss.backward()
    optimizer.step()

# Step 5: Update the structural parameter alpha
meta_optimizer.zero_grad()
loss.backward()
meta_optimizer.step()

end = time.time()

# Log the values of alpha = sigmoid(w), the belief that A -> B is the
# correct causal direction
with torch.no_grad():
    alpha = torch.sigmoid(model.w).item()
    if alpha > 0.5:
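# The helper generate_data_categorical is not defined in this excerpt. Below
# is a minimal, hypothetical sketch of what it plausibly does, assuming pi_A
# is a length-N marginal over A and pi_B_A is an N x N table of conditionals
# P(B | A); it returns an integer array of shape (num_samples, 2) with
# columns (A, B), matching the torch.split(x, 1, dim=1) usage later on:
import numpy as np

def generate_data_categorical(num_samples, pi_A, pi_B_A):
    """Sample (A, B) pairs from P(A) * P(B | A) (assumed sketch, not the
    original implementation)."""
    N = pi_A.shape[0]
    # Draw A from the marginal, then B from the conditional row pi_B_A[a]
    a = np.random.choice(N, size=num_samples, p=pi_A)
    b = np.array([np.random.choice(N, p=pi_B_A[a_i]) for a_i in a])
    return np.stack([a, b], axis=1).astype(np.int64)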
# Baseline: fit a marginal model of A alone on the training distribution
a_optimizer = torch.optim.Adam(a_model.parameters(), lr=0.1)
for i in range(num_episodes):
    x_train = torch.from_numpy(
        generate_data_categorical(batch_size, pi_A_1, pi_B_A))
    inputs_A, inputs_B = torch.split(x_train, 1, dim=1)
    a_model.zero_grad()
    a_loss = -torch.mean(a_model(inputs_A))
    a_loss.backward()
    a_optimizer.step()

# Sample a new marginal P(A) (the intervention) and draw a held-out set
# from the post-intervention distribution
pi_A_2 = np.random.dirichlet(np.ones(N))
x_val = torch.from_numpy(
    generate_data_categorical(num_test, pi_A_2, pi_B_A))

# Adapt both candidate models on the transfer distribution, recording the
# held-out loss of each model after every adaptation step
for i in range(num_episodes):
    x_transfer = torch.from_numpy(
        generate_data_categorical(batch_size, pi_A_2, pi_B_A))
    model.zero_grad()
    loss_A_B = -torch.mean(model.model_A_B(x_transfer))
    loss_B_A = -torch.mean(model.model_B_A(x_transfer))
    loss = loss_A_B + loss_B_A
    with torch.no_grad():
        val_loss_A_B = -torch.mean(model.model_A_B(x_val))
        val_loss_B_A = -torch.mean(model.model_B_A(x_val))
        losses[:, k, j, i] = [val_loss_A_B.item(), val_loss_B_A.item()]
    loss.backward()
    optimizer.step()

# Summarize the adaptation curves: flip the sign back to log-likelihoods
# and take the 25th/50th/75th percentiles across runs
flat_losses = -losses.reshape((2, -1, num_episodes))
losses_25, losses_50, losses_75 = np.percentile(flat_losses, (25, 50, 75), axis=1)

plt.figure(figsize=(18, 12))
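# The plotting code is truncated after plt.figure. One plausible way to
# render the adaptation curves, assuming the goal is the median held-out
# log-likelihood per episode with a 25-75 percentile band for each of the
# two candidate models (labels and colors below are illustrative):
import matplotlib.pyplot as plt
import numpy as np

steps = np.arange(num_episodes)
for idx, (label, color) in enumerate([(r'$A \to B$', 'C0'), (r'$B \to A$', 'C1')]):
    plt.plot(steps, losses_50[idx], color=color, label=label)
    plt.fill_between(steps, losses_25[idx], losses_75[idx], color=color, alpha=0.2)
plt.xlabel('Number of adaptation episodes')
plt.ylabel('Held-out log-likelihood')
plt.legend()
plt.show()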