# Esempio n. 1 ("Example no. 1") — scraped-snippet header, not executable code
0
        # NOTE(review): fragment — the enclosing `def` (and the loop over the
        # model's parameters binding `p`) starts before this excerpt.
        param_norm = p.grad.data.norm(2)  # L2 norm of this parameter's gradient
        total_norm += param_norm.item()**2  # accumulate squared per-parameter norms
    total_norm = total_norm**(1. / 2)  # sqrt of the sum: overall L2 gradient norm
    return total_norm


# Evaluate, without any parameter updates, the gap between the gradient norms
# of the A->B and B->A models on transfer (post-intervention) data.
# Axes: [record/zero, training distribution, transfer, episode].
losses = np.zeros((2, num_training, num_transfers, num_episodes))

for k in tnrange(num_training):
    # Sample a ground-truth joint distribution as P(A) and P(B|A).
    pi_A_1 = np.random.dirichlet(np.ones(N))
    pi_B_A = np.random.dirichlet(np.ones(N), size=N)
    for j in tnrange(num_transfers, leave=False):
        model.set_ground_truth(pi_A_1, pi_B_A)
        # Intervention: resample the marginal P(A); P(B|A) is unchanged.
        pi_A_2 = np.random.dirichlet(np.ones(N))
        all_x_transfer = torch.from_numpy(
            generate_data_categorical(batch_size * num_episodes, pi_A_2,
                                      pi_B_A))
        for i in range(num_episodes):
            # Growing prefix: episode i evaluates on the first (i+1) batches.
            x_transfer = all_x_transfer[:(batch_size * (i + 1))]

            # Negative log-likelihood under each causal direction.
            val_loss_A_B = -torch.mean(model.model_A_B(x_transfer))
            val_loss_B_A = -torch.mean(model.model_B_A(x_transfer))

            grad_norm_A_B = get_gradient_norm(model, val_loss_A_B)
            grad_norm_B_A = get_gradient_norm(model, val_loss_B_A)

            # Positive record => the A->B model has the larger gradient norm.
            record = grad_norm_A_B - grad_norm_B_A

            # Second channel is an unused placeholder (always 0).
            losses[:, k, j, i] = [record, 0]

# Collapse the (training, transfer) axes before taking percentiles.
flat_losses = -losses.reshape((2, -1, num_episodes))
losses_25, losses_50, losses_75 = np.percentile(flat_losses, (25, 50, 75),
# NOTE(review): the `np.percentile` call above is truncated in this excerpt —
# its remaining argument(s) (presumably `axis=1`) and the closing parenthesis
# are missing; the file is not syntactically complete here.
# Meta-learning experiment: adapt both directional models on post-intervention
# data, then update the structural parameter via a meta-optimizer.
train_batch_size = 1000
transfer_batch_size = 10

# Per-(run, training distribution, transfer) logs.
alphas = np.zeros((num_runs, num_training, num_transfer))
accs = np.zeros((num_runs, num_training, num_transfer))
times = np.zeros((num_runs, num_training, num_transfer))

for j in tnrange(num_runs):
    # Reset the structural parameter at the start of each run.
    model.w.data.zero_()
    for i in tnrange(num_training, leave=False):
        # Step 1: Sample a joint distribution before intervention
        pi_A_1 = np.random.dirichlet(np.ones(N))
        pi_B_A = np.random.dirichlet(np.ones(N), size=N)

        model.set_ground_truth(pi_A_1, pi_B_A)
        x_train_original = torch.from_numpy(generate_data_categorical(
            train_batch_size, pi_A_1, pi_B_A))
        # Baseline losses on pre-intervention data; no gradients needed.
        with torch.no_grad():
            original_loss_A_B = -torch.mean(model.model_A_B(x_train_original))
            original_loss_B_A = -torch.mean(model.model_B_A(x_train_original))

        cum_AtoB = 0
        cum_BtoA = 0
        # NOTE(review): the loop below (through its `start = time.time()`) is
        # an accidental partial duplicate of the loop that follows it — it
        # only re-sets the ground truth, samples pi_A_2, and advances the RNG.
        # It is kept as-is here because deleting it would change the random
        # stream consumed by the second loop; confirm intent before removing.
        transfers = tnrange(num_transfer, leave=False)
        for k in transfers:
            # Step 2: Train the modules on the training distribution
            model.set_ground_truth(pi_A_1, pi_B_A)

            # Step 3: Sample a joint distribution after intervention
            pi_A_2 = np.random.dirichlet(np.ones(N))

            start = time.time()
        transfers = tnrange(num_transfer, leave=False)
        for k in transfers:
            # Step 2: Train the modules on the training distribution
            model.set_ground_truth(pi_A_1, pi_B_A)

            # Step 3: Sample a joint distribution after intervention
            pi_A_2 = np.random.dirichlet(np.ones(N))

            start = time.time()
            # Step 4: Do k steps of gradient descent for adaptation on the
            # distribution after intervention
            model.zero_grad()
            loss = torch.tensor(0., dtype=torch.float64)
            for _ in range(num_gradient_steps):
                x_train = torch.from_numpy(
                    generate_data_categorical(transfer_batch_size, pi_A_2,
                                              pi_B_A))
                # Accumulate the meta-loss across adaptation steps; it is
                # backpropagated once below to update the structural parameter.
                loss += -torch.mean(model(x_train))
                optimizer.zero_grad()
                # Inner (module) update: fit both directions on this batch.
                inner_loss_A_B = -torch.mean(model.model_A_B(x_train))
                inner_loss_B_A = -torch.mean(model.model_B_A(x_train))
                inner_loss = inner_loss_A_B + inner_loss_B_A
                inner_loss.backward()
                optimizer.step()

            # Step 5: Update the structural parameter alpha
            meta_optimizer.zero_grad()
            loss.backward()
            meta_optimizer.step()
            end = time.time()

            # Log the values of alpha
            # NOTE(review): the excerpt is cut here — the logging code this
            # comment announces (filling alphas/accs/times) is not visible.
# Variant: pre-train the marginal P(A) on the original distribution, then
# evaluate both directional models on post-intervention transfer data.
optimizer = torch.optim.SGD(model.modules_parameters(), lr=1.)

losses = np.zeros((2, num_training, num_transfers, num_episodes))

for k in tnrange(num_training):
    # Ground-truth joint distribution: P(A) and P(B|A).
    pi_A_1 = np.random.dirichlet(np.ones(N))
    pi_B_A = np.random.dirichlet(np.ones(N), size=N)
    for j in tnrange(num_transfers, leave=False):
        model.set_ground_truth(pi_A_1, pi_B_A)

        # train P(A)
        # Fit only the marginal sub-model of the A->B direction via Adam.
        a_model = model.model_A_B.p_A
        a_optimizer = torch.optim.Adam(a_model.parameters(), lr=0.1)
        for i in range(num_episodes):
            x_train = torch.from_numpy(generate_data_categorical(batch_size, pi_A_1, pi_B_A))
            # Split the (batch, 2) samples into A and B columns; only the A
            # column is needed to fit the marginal P(A).
            inputs_A, inputs_B = torch.split(x_train, 1, dim=1)
            a_model.zero_grad()
            a_loss = -torch.mean(a_model(inputs_A))  # NLL of the marginal
            a_loss.backward()
            a_optimizer.step()

        # Intervention: new marginal P(A), same conditional P(B|A).
        pi_A_2 = np.random.dirichlet(np.ones(N))
        x_val = torch.from_numpy(generate_data_categorical(num_test, pi_A_2, pi_B_A))
        for i in range(num_episodes):
            x_transfer = torch.from_numpy(generate_data_categorical(batch_size, pi_A_2, pi_B_A))
            model.zero_grad()
            # NLL of the transfer batch under each causal direction.
            loss_A_B = -torch.mean(model.model_A_B(x_transfer))
            loss_B_A = -torch.mean(model.model_B_A(x_transfer))
            loss = loss_A_B + loss_B_A
            # NOTE(review): excerpt is cut here — the backward/step on `loss`
            # and the bookkeeping into `losses` presumably follow but are not
            # visible in this fragment.