def test_update_parameters(model):
    """
    The loss function (with respect to the weights of the model w) is defined as
        f(w) = 0.5 * (1 * w_1 + 2 * w_2 + 3 * w_3) ** 2
    with w = [2, 3, 5].

    The gradient of the function f with respect to w, and evaluated
    at w = [2, 3, 5], is:
        df / dw_1 = 1 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 23
        df / dw_2 = 2 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 46
        df / dw_3 = 3 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 69

    The updated parameter w' is then given by one step of gradient descent,
    with step size 0.5:
        w'_1 = w_1 - 0.5 * df / dw_1 = 2 - 0.5 * 23 = -9.5
        w'_2 = w_2 - 0.5 * df / dw_2 = 3 - 0.5 * 46 = -20
        w'_3 = w_3 - 0.5 * df / dw_3 = 5 - 0.5 * 69 = -29.5
    """
    train_inputs = torch.tensor([[1., 2., 3.]])
    train_loss = 0.5 * (model(train_inputs)**2)

    params = gradient_update_parameters(model,
                                        train_loss,
                                        params=None,
                                        step_size=0.5,
                                        first_order=False)

    assert train_loss.item() == 264.5
    assert list(params.keys()) == ['weight']
    assert torch.all(
        params['weight'].data == torch.tensor([[-9.5, -20., -29.5]]))
    """
    The new loss function (still with respect to the weights of the model w) is
    defined as:
        g(w) = 0.5 * (4 * w'_1 + 5 * w'_2 + 6 * w'_3) ** 2
             = 0.5 * (4 * (w_1 - 0.5 * df / dw_1)
                    + 5 * (w_2 - 0.5 * df / dw_2)
                    + 6 * (w_3 - 0.5 * df / dw_3)) ** 2
             = 0.5 * (4 * (w_1 - 0.5 * 1 * (1 * w_1 + 2 * w_2 + 3 * w_3))
                    + 5 * (w_2 - 0.5 * 2 * (1 * w_1 + 2 * w_2 + 3 * w_3))
                    + 6 * (w_3 - 0.5 * 3 * (1 * w_1 + 2 * w_2 + 3 * w_3))) ** 2
             = 0.5 * ((4 - 4 * 0.5 - 5 * 1.0 - 6 * 1.5) * w_1
                    + (5 - 4 * 1.0 - 5 * 2.0 - 6 * 3.0) * w_2
                    + (6 - 4 * 1.5 - 5 * 3.0 - 6 * 4.5) * w_3) ** 2
             = 0.5 * (-12 * w_1 - 27 * w_2 - 42 * w_3) ** 2

    Therefore the gradient of the function g with respect to w (and evaluated
    at w = [2, 3, 5]) is:
        dg / dw_1 = -12 * (-12 * w_1 - 27 * w_2 - 42 * w_3) =  3780
        dg / dw_2 = -27 * (-12 * w_1 - 27 * w_2 - 42 * w_3) =  8505
        dg / dw_3 = -42 * (-12 * w_1 - 27 * w_2 - 42 * w_3) = 13230
    """
    test_inputs = torch.tensor([[4., 5., 6.]])
    test_loss = 0.5 * (model(test_inputs, params=params)**2)

    grads = torch.autograd.grad(test_loss, model.parameters())

    assert test_loss.item() == 49612.5
    assert len(grads) == 1
    assert torch.all(grads[0].data == torch.tensor([[3780., 8505., 13230.]]))
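The arithmetic in the docstring above can be checked without Torchmeta. Below is a minimal sketch that reproduces the same second-order chain with plain PyTorch autograd; it assumes the `model` fixture is a single bias-free linear layer whose weight is initialised to [[2., 3., 5.]], which is what the docstring values imply.

import torch

# Assumed fixture: a linear layer without bias, weight initialised to [[2., 3., 5.]].
w = torch.tensor([[2., 3., 5.]], requires_grad=True)

train_inputs = torch.tensor([[1., 2., 3.]])
train_loss = 0.5 * (train_inputs @ w.t()) ** 2        # f(w) = 264.5

# One gradient step; create_graph=True keeps the update differentiable w.r.t. w.
(grad,) = torch.autograd.grad(train_loss, w, create_graph=True)
w_prime = w - 0.5 * grad                              # [[-9.5, -20., -29.5]]

test_inputs = torch.tensor([[4., 5., 6.]])
test_loss = 0.5 * (test_inputs @ w_prime.t()) ** 2    # g(w) = 49612.5

# Differentiating through the update recovers the second-order gradient.
(grad_w,) = torch.autograd.grad(test_loss, w)
print(grad_w)                                         # tensor([[ 3780.,  8505., 13230.]])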
Example #2
    def update_policy(self, trajectorys):
        # self.optimizer_p.zero_grad()
        self.policy_net.zero_grad()

        states = trajectorys[self.id].get_state()
        actions = trajectorys[self.id].get_action()
        rewards_env = trajectorys[self.id].get_reward_env()
        rewards_given = trajectorys[self.id].get_reward_from()
        loss_policy = []
        loss_critic = []
        loss_entropy = []

        R = 0
        returns_env = []
        returns_given = []

        # Compute policy loss
        # Get the V value of timestep from critic
        logits, V_s = self.policy_net(states)
        prob = F.softmax(logits, dim=-1)
        log_prob = F.log_softmax(logits, dim=-1)
        V_s = V_s.view(-1)

        for r in rewards_env[::-1]:
            R = r + gamma * R
            returns_env.insert(0, R)
        for r in rewards_given[::-1]:
            R = r + gamma * R
            returns_given.insert(0, R)

        returns_env = torch.Tensor(returns_env).detach()
        returns_given = torch.cat(returns_given, dim=0)
        returns = returns_env + returns_given
        # returns = returns_env

        Q_s_a = returns
        A_s_a = Q_s_a - V_s

        # compute policy loss
        loss_entropy_p = - log_prob * prob
        loss_entropy_p = loss_entropy_p.mean()
        loss_entropy.append(loss_entropy_p)

        log_prob_act = torch.stack([log_prob[i][actions[i]] for i in range(len(actions))], dim=0)
        loss_policy_p = - torch.dot(A_s_a, log_prob_act).view(1) / len(prob)
        loss_policy.append(loss_policy_p)

        # Compute critic loss
        # loss_critic_p = (returns - V_s).pow(2).mean()
        loss_critic_p = A_s_a.pow(2).mean()
        loss_critic.append(loss_critic_p)

        loss_policy = torch.stack(loss_policy).mean()
        loss_critic = torch.stack(loss_critic).mean()
        loss_entropy = torch.stack(loss_entropy).mean()
        loss = loss_policy + 0.5 * loss_critic + 0.01 * loss_entropy

        # loss.backward(retain_graph=True)
        # self.optimizer_p.step()
        self.new_params = gradient_update_parameters(self.policy_net, loss, step_size=step_size)
Example #3
def test_policy_update(config):
    agents = []
    for i in range(config.env.n_agents):
        agents.append(Actor(i, 7, config.env.n_agents))

    input = torch.Tensor([1, 1, 1, 1, 1, 1, 1])

    agent0 = agents[0]
    agent1 = agents[1]

    output0 = agent0.policy_net(input)
    output1 = agent1.policy_net(input)
    loss0 = 1 - output0.sum()
    loss1 = 2 - output1.sum()

    print(loss0)
    print(loss1)

    agent0.new_params = gradient_update_parameters(agent0.policy_net, loss0, step_size=0.5)
    agent1.new_params = gradient_update_parameters(agent1.policy_net, loss1, step_size=0.5)

    output0 = agent0.policy_net(input, agent0.new_params)
    output1 = agent1.policy_net(input, agent1.new_params)
    loss0 = 1 - output0.sum()
    loss1 = 2 - output1.sum()

    print(loss0)
    print(loss1)

    agent0.update_to_new_params()
    agent1.update_to_new_params()

    output0 = agent0.policy_net(input)
    output1 = agent1.policy_net(input)
    loss0 = 1 - output0.sum()
    loss1 = 2 - output1.sum()

    print(loss0)
    print(loss1)
Example #4
def test_update_parameters_first_order(model):
    """
    The loss function (with respect to the weights of the model w) is defined as
        f(w) = 0.5 * (4 * w_1 + 5 * w_2 + 6 * w_3) ** 2
    with w = [2, 3, 5].

    The gradient of the function f with respect to w, and evaluated
    at w = [2, 3, 5] is:
        df / dw_1 = 4 * (4 * w_1 + 5 * w_2 + 6 * w_3) = 212
        df / dw_2 = 5 * (4 * w_1 + 5 * w_2 + 6 * w_3) = 265
        df / dw_3 = 6 * (4 * w_1 + 5 * w_2 + 6 * w_3) = 318

    The updated parameter w' is then given by one step of gradient descent,
    with step size 0.5:
        w'_1 = w_1 - 0.5 * df / dw_1 = 2 - 0.5 * 212 = -104
        w'_2 = w_2 - 0.5 * df / dw_2 = 3 - 0.5 * 265 = -129.5
        w'_3 = w_3 - 0.5 * df / dw_3 = 5 - 0.5 * 318 = -154
    """
    train_inputs = torch.tensor([[4., 5., 6.]])
    train_loss = 0.5 * (model(train_inputs)**2)

    params = gradient_update_parameters(model,
                                        train_loss,
                                        params=None,
                                        step_size=0.5,
                                        first_order=True)

    assert train_loss.item() == 1404.5
    assert list(params.keys()) == ['weight']
    assert torch.all(
        params['weight'].data == torch.tensor([[-104., -129.5, -154.]]))
    """
    The new loss function (still with respect to the weights of the model w) is
    defined as:
        g(w) = 0.5 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) ** 2

    Since we computed w' with the first order approximation, the gradient of the
    function g with respect to w, and evaluated at w = [2, 3, 5], is:
        dg / dw_1 = 1 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) =  -825
        dg / dw_2 = 2 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -1650
        dg / dw_3 = 3 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -2475
    """
    test_inputs = torch.tensor([[1., 2., 3.]])
    test_loss = 0.5 * (model(test_inputs, params=params)**2)

    grads = torch.autograd.grad(test_loss, model.parameters())

    assert test_loss.item() == 340312.5
    assert len(grads) == 1
    assert torch.all(grads[0].data == torch.tensor([[-825., -1650., -2475.]]))
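The first-order shortcut can be mimicked in plain PyTorch by detaching the inner gradient, so the update is treated as a constant shift of w and no second-order terms appear. A minimal sketch, again assuming the [[2., 3., 5.]] weight fixture:

import torch

w = torch.tensor([[2., 3., 5.]], requires_grad=True)

train_inputs = torch.tensor([[4., 5., 6.]])
train_loss = 0.5 * (train_inputs @ w.t()) ** 2        # f(w) = 1404.5

# No create_graph here; detaching the gradient mirrors first_order=True.
(grad,) = torch.autograd.grad(train_loss, w)
w_prime = w - 0.5 * grad.detach()                     # [[-104., -129.5, -154.]]

test_inputs = torch.tensor([[1., 2., 3.]])
test_loss = 0.5 * (test_inputs @ w_prime.t()) ** 2    # g(w) = 340312.5

# Only the identity term of dw'/dw survives, so the gradient is x * (x . w').
(grad_w,) = torch.autograd.grad(test_loss, w)
print(grad_w)                                         # tensor([[ -825., -1650., -2475.]])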
Example #5
def get_adapted_params(model, test_batch):
    test_inputs, test_targets = test_batch['train']
    test_in, test_target = test_inputs[0], test_targets[0]
    test_out = model(test_in)
    inner_loss = get_loss(test_out, test_target)
    model.zero_grad()
    with torch.no_grad():
        params = gradient_update_parameters(model,
                                            inner_loss,
                                            step_size=step_size,
                                            first_order=first_order)
        test_out = model(test_in, params=params)
        outer_loss = get_loss(test_out, test_target)
    return params, outer_loss.item()
Example #6
def test_dataparallel_params_maml(model):
    device = torch.device('cuda:0')
    model = DataParallel(model)
    model.to(device=device)

    train_inputs = torch.rand(5, 2).to(device=device)
    train_outputs = model(train_inputs)

    inner_loss = train_outputs.sum()  # Dummy loss
    params = gradient_update_parameters(model, inner_loss)

    test_inputs = torch.rand(5, 2).to(device=device)
    test_outputs = model(test_inputs, params=params)

    assert test_outputs.shape == (5, 1)
    assert test_outputs.device == device

    outer_loss = test_outputs.sum()  # Dummy loss
    outer_loss.backward()
Example #7
def meta_train(args, metaDataloader):
    model = RegressionNeuralNetwork(args['in_channels'],
                                    hidden1_size=args['hidden1_size'],
                                    hidden2_size=args['hidden2_size'])
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=args['beta'])
    loss_record = []
    # training loop
    for it_outer in range(args['num_it_outer']):
        model.zero_grad()

        train_dataloader = metaDataloader['train']

        test_dataloader = metaDataloader['test']

        outer_loss = torch.tensor(0., dtype=torch.float)
        for task in train_dataloader:
            iterator = iter(train_dataloader[task])
            train_sample = next(iterator)
            # get true h value
            # h_value = torch.tensor(train_sample[:,-1], dtype=torch.float)
            h_value = train_sample[:,
                                   -1].clone().detach().to(dtype=torch.float)
            # get input
            # input_value = torch.tensor(train_sample[:,:-1], dtype=torch.float)
            input_value = train_sample[:, :-1].clone().detach().to(
                dtype=torch.float)
            #
            train_h_value = model(input_value)
            inner_loss = F.mse_loss(train_h_value.view(-1, 1),
                                    h_value.view(-1, 1))

            model.zero_grad()
            # print('It {}, task {}, Start updating parameters'.format(it_outer, task))
            params = gradient_update_parameters(
                model,
                inner_loss,
                step_size=args['alpha'],
                first_order=args['first_order'])
            # adaptation
            # get test sample
            test_iterator = iter(test_dataloader[task])
            test_sample = next(test_iterator)
            # h_value2 = torch.tensor(test_sample[:,-1], dtype=torch.float)
            h_value2 = test_sample[:,
                                   -1].clone().detach().to(dtype=torch.float)
            # test_input_value = torch.tensor(test_sample[:,:-1], dtype=torch.float)
            test_input_value = test_sample[:, :-1].clone().detach().to(
                dtype=torch.float)
            test_h_value = model(test_input_value, params=params)

            outer_loss += F.mse_loss(test_h_value.view(-1, 1),
                                     h_value2.view(-1, 1))

        outer_loss.div_(args['num_tasks'])

        outer_loss.backward()
        meta_optimizer.step()

        loss_record.append(outer_loss.detach())
        if it_outer % 50 == 0:
            print('It {}, outer training loss: {}'.format(it_outer, outer_loss))
            # print the loss plot
    plt.plot(loss_record)
    plt.title('Outer Training Loss (MSE Loss) in MAML')
    plt.xlabel('Iteration number')
    plt.show()

    # save model
    if args['output_model'] is not None:
        with open(args['output_model'], 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
Example #8
def train(args):
    logger.warning(
        'This script is an example to showcase the MetaModule and '
        'data-loading features of Torchmeta, and as such has been '
        'very lightly tested. For a better tested implementation of '
        'Model-Agnostic Meta-Learning (MAML) using Torchmeta with '
        'more features (including multi-step adaptation and '
        'different datasets), please check `https://github.com/'
        'tristandeleu/pytorch-maml`.')

    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets, test_inputs,
                                   test_targets)):
                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(
                    model,
                    inner_loss,
                    step_size=args.step_size,
                    first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(
            args.output_folder, 'maml_omniglot_'
            '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
Example #9
def train():
    transform = transforms.Compose(
        [transforms.Resize(84), transforms.ToTensor()])
    dataset_transform = ClassSplitter(shuffle=True,
                                      num_train_per_class=5,
                                      num_test_per_class=5)
    dataset = MiniImagenet('',
                           transform=transform,
                           num_classes_per_task=5,
                           target_transform=Categorical(num_classes=5),
                           meta_split="train",
                           dataset_transform=dataset_transform)

    dataloader = BatchMetaDataLoader(dataset, batch_size=1, shuffle=True)

    model = ModelConvMiniImagenet(5)
    model.to(device='cuda')
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    accuracy_l = list()

    with tqdm(dataloader, total=1000) as pbar:
        for batch_idx, batch in enumerate(pbar):

            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device='cuda')
            train_targets = train_targets.to(device='cuda')

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device='cuda')
            test_targets = test_targets.to(device='cuda')

            outer_loss = torch.tensor(0., device='cuda')
            accuracy = torch.tensor(0., device='cuda')
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets, test_inputs,
                                   test_targets)):

                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(model, inner_loss)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)
            outer_loss.div_(1)
            accuracy.div_(1)

            outer_loss.backward()
            meta_optimizer.step()
            accuracy_l.append(accuracy.item())
            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if (batch_idx >= 1000):
                break

    plt.plot(accuracy_l)
    plt.show()
Example #10
def test_multiple_update_parameters(model):
    """
    The loss function (with respect to the weights of the model w) is defined as
        f(w) = 0.5 * (1 * w_1 + 2 * w_2 + 3 * w_3) ** 2
    with w = [2, 3, 5].

    The gradient of f with respect to w is:
        df / dw_1 = 1 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 23
        df / dw_2 = 2 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 46
        df / dw_3 = 3 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 69

    The updated parameters are given by:
        w'_1 = w_1 - 1. * df / dw_1 = 2 - 1. * 23 = -21
        w'_2 = w_2 - 1. * df / dw_2 = 3 - 1. * 46 = -43
        w'_3 = w_3 - 1. * df / dw_3 = 5 - 1. * 69 = -64
    """
    train_inputs = torch.tensor([[1., 2., 3.]])

    train_loss_1 = 0.5 * (model(train_inputs)**2)
    params_1 = gradient_update_parameters(model,
                                          train_loss_1,
                                          params=None,
                                          step_size=1.,
                                          first_order=False)

    assert train_loss_1.item() == 264.5
    assert list(params_1.keys()) == ['weight']
    assert torch.all(
        params_1['weight'].data == torch.tensor([[-21., -43., -64.]]))
    """
    The new loss function is defined as
        g(w') = 0.5 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) ** 2
    with w' = [-21, -43, -64].

    The gradient of g with respect to w' is:
        dg / dw'_1 = 1 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -299
        dg / dw'_2 = 2 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -598
        dg / dw'_3 = 3 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -897

    The updated parameters are given by:
        w''_1 = w'_1 - 1. * dg / dw'_1 = -21 - 1. * -299 = 278
        w''_2 = w'_2 - 1. * dg / dw'_2 = -43 - 1. * -598 = 555
        w''_3 = w'_3 - 1. * dg / dw'_3 = -64 - 1. * -897 = 833
    """
    train_loss_2 = 0.5 * (model(train_inputs, params=params_1)**2)
    params_2 = gradient_update_parameters(model,
                                          train_loss_2,
                                          params=params_1,
                                          step_size=1.,
                                          first_order=False)

    assert train_loss_2.item() == 44700.5
    assert list(params_2.keys()) == ['weight']
    assert torch.all(
        params_2['weight'].data == torch.tensor([[278., 555., 833.]]))
    """
    The new loss function is defined as
        h(w'') = 0.5 * (1 * w''_1 + 2 * w''_2 + 3 * w''_3) ** 2
    with w'' = [278, 555, 833].

    The gradient of h with respect to w'' is:
        dh / dw''_1 = 1 * (1 * w''_1 + 2 * w''_2 + 3 * w''_3) =  3887
        dh / dw''_2 = 2 * (1 * w''_1 + 2 * w''_2 + 3 * w''_3) =  7774
        dh / dw''_3 = 3 * (1 * w''_1 + 2 * w''_2 + 3 * w''_3) = 11661

    The updated parameters are given by:
        w'''_1 = w''_1 - 1. * dh / dw''_1 = 278 - 1. *  3887 =  -3609
        w'''_2 = w''_2 - 1. * dh / dw''_2 = 555 - 1. *  7774 =  -7219
        w'''_3 = w''_3 - 1. * dh / dw''_3 = 833 - 1. * 11661 = -10828
    """
    train_loss_3 = 0.5 * (model(train_inputs, params=params_2)**2)
    params_3 = gradient_update_parameters(model,
                                          train_loss_3,
                                          params=params_2,
                                          step_size=1.,
                                          first_order=False)

    assert train_loss_3.item() == 7554384.5
    assert list(params_3.keys()) == ['weight']
    assert torch.all(
        params_3['weight'].data == torch.tensor([[-3609., -7219., -10828.]]))
    """
    The new loss function is defined as
        l(w) = 4 * w'''_1 + 5 * w'''_2 + 6 * w'''_3
    with w = [2, 3, 5] and w''' = [-3609, -7219, -10828].

    The gradient of l with respect to w is:
        dl / dw_1 = 4 * dw'''_1 / dw_1 + 5 * dw'''_2 / dw_1 + 6 * dw'''_3 / dw_1
                  = ... =  -5020
        dl / dw_2 = 4 * dw'''_1 / dw_2 + 5 * dw'''_2 / dw_2 + 6 * dw'''_3 / dw_2
                  = ... = -10043
        dl / dw_3 = 4 * dw'''_1 / dw_3 + 5 * dw'''_2 / dw_3 + 6 * dw'''_3 / dw_3
                  = ... = -15066
    """
    test_inputs = torch.tensor([[4., 5., 6.]])
    test_loss = model(test_inputs, params=params_3)
    grads = torch.autograd.grad(test_loss, model.parameters())

    assert test_loss.item() == -115499.
    assert len(grads) == 1
    assert torch.all(
        grads[0].data == torch.tensor([[-5020., -10043., -15066.]]))
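The same chain of three updates can be reproduced with plain autograd by keeping create_graph=True at every step, so the final gradient flows back through all three inner updates. A minimal sketch, again assuming the [[2., 3., 5.]] weight fixture:

import torch

w = torch.tensor([[2., 3., 5.]], requires_grad=True)
x_train = torch.tensor([[1., 2., 3.]])

params = w
for _ in range(3):
    loss = 0.5 * (x_train @ params.t()) ** 2
    (grad,) = torch.autograd.grad(loss, params, create_graph=True)
    params = params - 1. * grad                   # w', then w'', then w'''

# params is now w''' = [[-3609., -7219., -10828.]]
x_test = torch.tensor([[4., 5., 6.]])
test_loss = x_test @ params.t()                   # l(w) = -115499.

(grad_w,) = torch.autograd.grad(test_loss, w)
print(grad_w)                                     # tensor([[ -5020., -10043., -15066.]])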
Example #11
def train(inputs=[],
          adapt_inputs={},
          exp_name="maml",
          batch_size=16,
          num_workers=1,
          use_cuda=False,
          num_batches=100,
          step_size=0.4,
          shots=1000,
          test_shots=200,
          save_per=-1,
          eval_per=1,
          learning_rate=0.01,
          first_order=False,
          save_dir=".",
          date="210101",
          seed=None,
          logger_kwargs={},
          exp_params=DotMap()):
    device = torch.device(
        'cuda' if use_cuda and torch.cuda.is_available() else 'cpu')

    def get_loss(output, targets):
        return -1 * output.log_prob(targets).sum(dim=1).mean()

    def get_adapted_params(model, test_batch):
        test_inputs, test_targets = test_batch['train']
        test_in, test_target = test_inputs[0], test_targets[0]
        test_out = model(test_in)
        inner_loss = get_loss(test_out, test_target)
        model.zero_grad()
        with torch.no_grad():
            params = gradient_update_parameters(model,
                                                inner_loss,
                                                step_size=step_size,
                                                first_order=first_order)
            test_out = model(test_in, params=params)
            outer_loss = get_loss(test_out, test_target)
        return params, outer_loss.item()

    from torch.utils.tensorboard import SummaryWriter
    import datetime

    env_name = "BedBathingBaxterHuman-v0217_0-v1"
    env = gym.make('assistive_gym:' + env_name)

    dataset = behaviour(inputs, shots=shots, test_shots=test_shots)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=num_workers)

    adapt_datasets, adapt_loaders = {}, {}
    for key, adapt_dir in adapt_inputs.items():
        adapt_datasets[key] = behaviour([adapt_dir],
                                        shots=shots,
                                        test_shots=test_shots)
        adapt_loaders[key] = BatchMetaDataLoader(adapt_datasets[key],
                                                 batch_size=1,
                                                 shuffle=True,
                                                 num_workers=num_workers)

    model = PolicyNetwork(env.observation_space_human.shape[0],
                          env.action_space_human.shape[0])
    model.to(device=device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    now = datetime.datetime.now()
    hour = '{:02d}'.format(now.hour)
    minute = '{:02d}'.format(now.minute)
    second = '{:02d}'.format(now.second)
    timestamp = '{}-{}-{}'.format(hour, minute, second)

    # new_data/date/MAML_assistive_gym_sz0-1_lr0-01_s50_ts200/MAML_assistive_gym_sz0-1_lr0-01_s50_ts200_s0
    output_dir = logger_kwargs['output_dir']
    log_folder = os.path.join(save_dir, date, "runs",
                              f"{exp_name}_{timestamp}")
    print(f"Saving logs to {log_folder}")
    os.makedirs(log_folder, exist_ok=True)
    writer = SummaryWriter(log_dir=log_folder, comment=f"{exp_name}")

    rllib_saver = RLLibSaver()
    adapt_losses = {key: [] for key in adapt_loaders.keys()}
    # Training loop
    with tqdm(dataloader, total=num_batches, disable=True) as pbar:
        for batch_idx, batches in enumerate(
                zip(pbar, *list(adapt_loaders.values()))):
            model.zero_grad()

            main_batch = batches[0]
            train_inputs, train_targets = main_batch['train']
            train_inputs = train_inputs.to(device=device).float()
            train_targets = train_targets.to(device=device).float()

            test_inputs, test_targets = main_batch['test']
            test_inputs = test_inputs.to(device=device).float()
            test_targets = test_targets.to(device=device).float()

            outer_loss = torch.tensor(0., device=device)
            loss = torch.tensor(0., device=device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(
                               zip(train_inputs, train_targets, test_inputs,
                                   test_targets)):
                train_output = model(train_input)
                inner_loss = get_loss(train_output, train_target)

                model.zero_grad()
                params = gradient_update_parameters(model,
                                                    inner_loss,
                                                    step_size=step_size,
                                                    first_order=first_order)

                test_output = model(test_input, params=params)
                outer_loss += get_loss(test_output, test_target)
                with torch.no_grad():
                    loss += get_loss(test_output, test_target)

            outer_loss.div_(batch_size)
            loss.div_(batch_size)

            outer_loss.backward()
            meta_optimizer.step()

            # Report progress
            pbar.set_postfix(loss='{0:.4f}'.format(loss.item()))
            print(f"Iter {batch_idx} train loss: {loss.item():.3f}")
            writer.add_scalar(f"train/maml_loss", loss.item(), batch_idx)
            writer.flush()

            # Eval & Save model
            do_eval = batch_idx % eval_per == 0
            do_save = output_dir is not None and save_per > 0 and (
                (batch_idx % save_per == 0) or (batch_idx == num_batches))
            if do_eval:
                all_pre_params = []
                all_post_params = []
                all_inputs = []

                for adapt_key, adapt_batch in zip(list(adapt_loaders.keys()),
                                                  batches[1:]):

                    pre_params = OrderedDict(model.meta_named_parameters())
                    post_params, outer_loss = get_adapted_params(
                        model, adapt_batch)

                    all_pre_params.append(pre_params)
                    all_post_params.append(post_params)
                    all_inputs.append(
                        adapt_batch['train'][0][0])  # inputs, idx=1
                    print(f"Save inner loss {adapt_key}: {outer_loss:.04f}")
                    if do_save:
                        rllib_saver.save(params=post_params,
                                         save_path=output_dir,
                                         key=adapt_key,
                                         iteration=batch_idx,
                                         exp_params=exp_params)
                    adapt_losses[adapt_key].append(outer_loss)
                adapt_keys = list(adapt_loaders.keys())
                if len(adapt_keys) > 1:
                    _, _, fig = cluster_activation(model, all_inputs,
                                                   all_pre_params, adapt_keys,
                                                   "fc3")
                    writer.add_figure(f"train/fc3_before", fig, batch_idx)
                    writer.flush()
                    _, _, fig = cluster_activation(model, all_inputs,
                                                   all_post_params, adapt_keys,
                                                   "fc3")
                    writer.add_figure(f"train/fc3_after", fig, batch_idx)
                    writer.flush()
                with open(os.path.join(output_dir, "adapt_losses.txt"),
                          "w+") as f:
                    yaml.dump(adapt_losses, f)

            if batch_idx >= num_batches:
                break
Example #12
def train(ml_custom, policies, args):
    # Prepare to log info
    writer = SummaryWriter()

    # Define model
    model = MIL()
    model.to(device=args.device)
    # load_model(model, "./models/mil_499.th")
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    print("Start training")
    for batch in range(args.num_batches):
        outer_loss = torch.tensor(0., device=args.device)
        accuracy = torch.tensor(0., device=args.device)

        for name in np.random.choice(list(ml_custom.keys()), 3, replace=False):
            env_cls = ml_custom[name]
            print("Task: %s" % name)

            policy = policies[name]
            all_tasks = [
                task for task in ml45.train_tasks if task.env_name == name
            ]

            # Adapt in support task
            env = env_cls()
            support_task = random.choice(all_tasks[:25])
            env.set_task(support_task)
            batches_imgs, batches_configs, batches_actions = get_data(
                env, policy, args)
            inner_loss = torch.tensor(0., device=args.device)
            number_batches = len(batches_imgs)
            while (len(batches_imgs) > 0):
                pred_actions = model(
                    batches_imgs.pop().to(device=args.device),
                    batches_configs.pop().to(device=args.device))
                inner_loss += F.mse_loss(
                    pred_actions,
                    batches_actions.pop().to(device=args.device))
            inner_loss.div_(number_batches)
            model.zero_grad()
            params = gradient_update_parameters(model,
                                                inner_loss,
                                                step_size=args.step_size,
                                                first_order=args.first_order)

            # Evaluate in query task
            env = env_cls()
            query_task = random.choice(all_tasks[25:])
            env.set_task(query_task)
            batches_imgs, batches_configs, batches_actions = get_data(
                env, policy, args)
            aux_loss = torch.tensor(0., device=args.device)
            aux_accuracy = torch.tensor(0., device=args.device)
            number_batches = len(batches_imgs)

            while (len(batches_imgs) > 0):
                pred_actions = model(
                    batches_imgs.pop().to(device=args.device),
                    batches_configs.pop().to(device=args.device))
                batch_actions = batches_actions.pop().to(device=args.device)
                aux_loss += F.mse_loss(pred_actions, batch_actions)
                with torch.no_grad():
                    aux_accuracy += get_accuracy(pred_actions, batch_actions)

            aux_loss.div_(number_batches)
            aux_accuracy.div_(number_batches)
            outer_loss += aux_loss
            accuracy += aux_accuracy

        outer_loss.div_(3)
        accuracy.div_(3)
        meta_optimizer.zero_grad()
        outer_loss.backward()
        meta_optimizer.step()

        #Log info
        writer.add_scalar('meta_train/loss', outer_loss.item(), batch)
        writer.add_scalar('meta_train/accuracy', accuracy.item(), batch)
        print("batch: %d loss: %.4f accuracy: %.4f" %
              (batch, outer_loss.item(), accuracy.item()))

        # Save model
        save_model(model, args.output_folder, 'mil_%d.th' % batch)
Example #13
def train(args):

    dataset = clinic(shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                    test_target) in enumerate(zip(train_inputs, train_targets,
                    test_inputs, test_targets)):
                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(model,
                                                    inner_loss,
                                                    step_size=args.step_size,
                                                    first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder, 'maml_omniglot_'
            '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
Example #14
            support_inputs, support_targets = (
                [_.cuda(non_blocking=True) for _ in train_batch['train']]
                if args.use_cuda else list(train_batch['train']))
            query_inputs, query_targets = (
                [_.cuda(non_blocking=True) for _ in train_batch['test']]
                if args.use_cuda else list(train_batch['test']))
           
            train_loss = torch.tensor(0., device=support_inputs.device)
            train_acc = torch.tensor(0., device=support_inputs.device)

            for _ , (support_input, support_target, query_input,
                    query_target) in enumerate(zip(support_inputs, support_targets,
                    query_inputs, query_targets)):
                #meta inner loop
                support_logit = model(support_input)           
                train_inner_loss = F.cross_entropy(support_logit, support_target)

                model.zero_grad()
                params = gradient_update_parameters(model, train_inner_loss,
                    step_size=args.step_size, first_order=args.first_order)

                #meta outer loop
                if train_batch_i==int(args.train_tasks/args.batch_tasks)-1:
                    teacher_model.eval()
                    teacher_query_logit = teacher_model(query_input)
                    query_logit = model(query_input, params=params)
                    train_loss += get_loss(args, query_logit, query_target, teacher_query_logit) 
                
                else:
                    query_logit = model(query_input, params=params)
                    train_loss += F.cross_entropy(query_logit, query_target)
                
                
                with torch.no_grad():
                    train_acc += count_acc(query_logit, query_target)
Example #15
def train(args):

    perturb_mock, sgRNA_list_mock = makedata.json_to_perturb_data(
        path="/home/member/xywang/WORKSPACE/MaryGUO/one-shot/MOCK_MON_crispr_combine/crispr_analysis")

    total = sc.read_h5ad("/home/member/xywang/WORKSPACE/MaryGUO/one-shot/mock_one_perturbed.h5ad")
    trainset, testset = preprocessing.make_total_data(total, sgRNA_list_mock)

    TrainSet = perturbdataloader(trainset, ways=args.num_ways,
                                 support_shots=args.num_shots, query_shots=15)
    TrainLoader = DataLoader(TrainSet, batch_size=args.batch_size_train,
                             shuffle=False, num_workers=args.num_workers)

    model = MLP(out_features = args.num_ways)

    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(TrainLoader, total=args.num_batches) as pbar:
        for batch_idx, (inputs_support, inputs_query, target_support, target_query) in enumerate(pbar):
            model.zero_grad()

            inputs_support = inputs_support.to(device=args.device)
            target_support = target_support.to(device=args.device)

            inputs_query = inputs_query.to(device=args.device)
            target_query = target_query.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(zip(inputs_support, target_support,inputs_query, target_query)):

                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(model,
                                                    inner_loss,
                                                    step_size=args.step_size,
                                                    first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size_train)
            accuracy.div_(args.batch_size_train)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches or accuracy.item() > 0.95:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder, 'maml_omniglot_'
                                                    '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)

    # start test
    test_support, test_query, test_target_support, test_target_query = \
        helpfuntions.sample_once(testset, support_shot=args.num_shots,
                                 shuffle=False, plus=len(trainset))
    test_query = torch.from_numpy(test_query).to(device=args.device)
    test_target_query = torch.from_numpy(test_target_query).to(device=args.device)

    TrainSet = perturbdataloader_test(test_support, test_target_support)
    TrainLoader = DataLoader(TrainSet, args.batch_size_test)

    meta_optimizer.zero_grad()
    inner_losses = []
    accuracy_test = []

    for epoch in range(args.num_epoch):
        model.to(device=args.device)
        model.train()

        for _, (inputs_support,target_support) in enumerate(TrainLoader):

            inputs_support = inputs_support.to(device=args.device)
            target_support = target_support.to(device=args.device)

            train_logit = model(inputs_support)
            loss = F.cross_entropy(train_logit, target_support)
            inner_losses.append(loss)
            loss.backward()
            meta_optimizer.step()
            meta_optimizer.zero_grad()

            test_logit = model(test_query)
            with torch.no_grad():
                accuracy = get_accuracy(test_logit, test_target_query)
                accuracy_test.append(accuracy)



        if (epoch + 1) % 3 == 0:
            print('Epoch [{}/{}], Loss: {:.4f}, Accuracy: {:.4f}'.format(epoch + 1, args.num_epoch, loss, accuracy))