Code example #1
File: test_utils.py Project: xjtushujun/pytorch-maml
def test_update_parameters(model):
    """
    The loss function (with respect to the weights of the model w) is defined as
        f(w) = 0.5 * (1 * w_1 + 2 * w_2 + 3 * w_3) ** 2
    with w = [2, 3, 5].

    The gradient of the function f with respect to w, evaluated
    at w = [2, 3, 5], is:
        df / dw_1 = 1 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 23
        df / dw_2 = 2 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 46
        df / dw_3 = 3 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 69

    The updated parameter w' is then given by one step of gradient descent,
    with step size 0.5:
        w'_1 = w_1 - 0.5 * df / dw_1 = 2 - 0.5 * 23 = -9.5
        w'_2 = w_2 - 0.5 * df / dw_2 = 3 - 0.5 * 46 = -20
        w'_3 = w_3 - 0.5 * df / dw_3 = 5 - 0.5 * 69 = -29.5
    """
    train_inputs = torch.tensor([[1., 2., 3.]])
    train_loss = 0.5 * (model(train_inputs)**2)

    params = update_parameters(model,
                               train_loss,
                               params=None,
                               step_size=0.5,
                               first_order=False)

    assert train_loss.item() == 264.5
    assert list(params.keys()) == ['weight']
    assert torch.all(
        params['weight'].data == torch.tensor([[-9.5, -20., -29.5]]))
    """
    The new loss function (still with respect to the weights of the model w) is
    defined as:
        g(w) = 0.5 * (4 * w'_1 + 5 * w'_2 + 6 * w'_3) ** 2
             = 0.5 * (4 * (w_1 - 0.5 * df / dw_1)
                    + 5 * (w_2 - 0.5 * df / dw_2)
                    + 6 * (w_3 - 0.5 * df / dw_3)) ** 2
             = 0.5 * (4 * (w_1 - 0.5 * 1 * (1 * w_1 + 2 * w_2 + 3 * w_3))
                    + 5 * (w_2 - 0.5 * 2 * (1 * w_1 + 2 * w_2 + 3 * w_3))
                    + 6 * (w_3 - 0.5 * 3 * (1 * w_1 + 2 * w_2 + 3 * w_3))) ** 2
             = 0.5 * ((4 - 4 * 0.5 - 5 * 1.0 - 6 * 1.5) * w_1
                    + (5 - 4 * 1.0 - 5 * 2.0 - 6 * 3.0) * w_2
                    + (6 - 4 * 1.5 - 5 * 3.0 - 6 * 4.5) * w_3) ** 2
             = 0.5 * (-12 * w_1 - 27 * w_2 - 42 * w_3) ** 2

    Therefore the gradient of the function g with respect to w (and evaluated
    at w = [2, 3, 5]) is:
        dg / dw_1 = -12 * (-12 * w_1 - 27 * w_2 - 42 * w_3) =  3780
        dg / dw_2 = -27 * (-12 * w_1 - 27 * w_2 - 42 * w_3) =  8505
        dg / dw_3 = -42 * (-12 * w_1 - 27 * w_2 - 42 * w_3) = 13230
    """
    test_inputs = torch.tensor([[4., 5., 6.]])
    test_loss = 0.5 * (model(test_inputs, params=params)**2)

    grads = torch.autograd.grad(test_loss, model.parameters())

    assert test_loss.item() == 49612.5
    assert len(grads) == 1
    assert torch.all(grads[0].data == torch.tensor([[3780., 8505., 13230.]]))
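What this test exercises is a differentiable inner update: the gradient of the training loss is taken with create_graph=True, so the updated weights w' stay connected to w and the test loss can be differentiated through the update. Below is a minimal self-contained sketch with plain autograd (not the library's update_parameters) that reproduces the numbers from the docstring:

import torch

w = torch.tensor([[2., 3., 5.]], requires_grad=True)
x_train = torch.tensor([[1., 2., 3.]])
x_test = torch.tensor([[4., 5., 6.]])

# Inner step: keep the graph so w_prime remains a function of w.
train_loss = 0.5 * (x_train @ w.t()).pow(2).sum()           # 264.5
(grad,) = torch.autograd.grad(train_loss, w, create_graph=True)
w_prime = w - 0.5 * grad                                    # [[-9.5, -20., -29.5]]

# Outer step: differentiating through the inner update yields the
# second-order gradient from the docstring.
test_loss = 0.5 * (x_test @ w_prime.t()).pow(2).sum()       # 49612.5
(g,) = torch.autograd.grad(test_loss, w)
print(g)                                                    # [[3780., 8505., 13230.]]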
Code example #2
File: test_utils.py Project: xjtushujun/pytorch-maml
def test_update_parameters_first_order(model):
    """
    The loss function (with respect to the weights of the model w) is defined as
        f(w) = 0.5 * (4 * w_1 + 5 * w_2 + 6 * w_3) ** 2
    with w = [2, 3, 5].

    The gradient of the function f with respect to w, evaluated
    at w = [2, 3, 5], is:
        df / dw_1 = 4 * (4 * w_1 + 5 * w_2 + 6 * w_3) = 212
        df / dw_2 = 5 * (4 * w_1 + 5 * w_2 + 6 * w_3) = 265
        df / dw_3 = 6 * (4 * w_1 + 5 * w_2 + 6 * w_3) = 318

    The updated parameter w' is then given by one step of gradient descent,
    with step size 0.5:
        w'_1 = w_1 - 0.5 * df / dw_1 = 2 - 0.5 * 212 = -104
        w'_2 = w_2 - 0.5 * df / dw_2 = 3 - 0.5 * 265 = -129.5
        w'_3 = w_3 - 0.5 * df / dw_3 = 5 - 0.5 * 318 = -154
    """
    train_inputs = torch.tensor([[4., 5., 6.]])
    train_loss = 0.5 * (model(train_inputs)**2)

    params = update_parameters(model,
                               train_loss,
                               params=None,
                               step_size=0.5,
                               first_order=True)

    assert train_loss.item() == 1404.5
    assert list(params.keys()) == ['weight']
    assert torch.all(
        params['weight'].data == torch.tensor([[-104., -129.5, -154.]]))
    """
    The new loss function (still with respect to the weights of the model w) is
    defined as:
        g(w) = 0.5 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) ** 2

    Since w' was computed with the first-order approximation, the gradient of
    the function g with respect to w, evaluated at w = [2, 3, 5], is:
        dg / dw_1 = 1 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) =  -825
        dg / dw_2 = 2 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -1650
        dg / dw_3 = 3 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -2475
    """
    test_inputs = torch.tensor([[1., 2., 3.]])
    test_loss = 0.5 * (model(test_inputs, params=params)**2)

    grads = torch.autograd.grad(test_loss, model.parameters())

    assert test_loss.item() == 340312.5
    assert len(grads) == 1
    assert torch.all(grads[0].data == torch.tensor([[-825., -1650., -2475.]]))
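With the first-order approximation, the inner gradient is computed without create_graph (or equivalently detached), so w' contributes to the outer gradient only through the direct w term. A minimal sketch reproducing the numbers above (plain autograd, not the library's update_parameters):

import torch

w = torch.tensor([[2., 3., 5.]], requires_grad=True)
x_train = torch.tensor([[4., 5., 6.]])
x_test = torch.tensor([[1., 2., 3.]])

train_loss = 0.5 * (x_train @ w.t()).pow(2).sum()           # 1404.5
(grad,) = torch.autograd.grad(train_loss, w)                # no create_graph: first order
w_prime = w - 0.5 * grad                                    # grad acts as a constant here

test_loss = 0.5 * (x_test @ w_prime.t()).pow(2).sum()       # 340312.5
(g,) = torch.autograd.grad(test_loss, w)
print(g)                                                    # [[-825., -1650., -2475.]]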
Code example #3
File: maml.py Project: sidney1994/pytorch-maml
    def adapt(self, inputs, targets, is_classification_task=None,
              num_adaptation_steps=1, step_size=0.1, first_order=False):
        if is_classification_task is None:
            is_classification_task = (not targets.dtype.is_floating_point)
        params = None

        results = {'inner_losses': np.zeros(
            (num_adaptation_steps,), dtype=np.float32)}

        for step in range(num_adaptation_steps):
            logits = self.model(inputs, params=params)
            inner_loss = self.loss_function(logits, targets)
            results['inner_losses'][step] = inner_loss.item()

            if (step == 0) and is_classification_task:
                results['accuracy_before'] = compute_accuracy(logits, targets)

            self.model.zero_grad()
            # At evaluation time the higher-order gradients are never used,
            # so fall back to the cheaper first-order update.
            params = update_parameters(self.model, inner_loss,
                step_size=step_size, params=params,
                first_order=(not self.model.training) or first_order)

        return params, results
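A hypothetical call site for adapt() on a single task's support set; `learner` and the data tensors below are placeholders, not names from this repository:

# Adapt on the support set, then evaluate the adapted parameters
# on the query set.
params, results = learner.adapt(support_inputs, support_targets,
                                num_adaptation_steps=5, step_size=0.1)
print(results['inner_losses'])                   # one loss per inner step
query_logits = learner.model(query_inputs, params=params)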
Code example #4
def main(args, mode, iteration=None):
    dataset = load_dataset(args, mode)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    # NOTE: `model` is presumably defined at module scope in the original file.
    model.to(device=args.device)
    model.train()

    # These parameter groups control the outer-loop (meta) update.
    # To control the inner-loop update, see the update_parameters function
    # in ./maml/utils.py.
    freeze_params = [
        p for name, p in model.named_parameters() if 'classifier' in name
    ]
    learnable_params = [
        p for name, p in model.named_parameters() if 'classifier' not in name
    ]
    if args.outer_fix:
        meta_optimizer = torch.optim.Adam([{
            'params': freeze_params,
            'lr': 0
        }, {
            'params': learnable_params,
            'lr': args.meta_lr
        }])
    else:
        meta_optimizer = torch.optim.Adam([{
            'params': freeze_params,
            'lr': args.meta_lr
        }, {
            'params': learnable_params,
            'lr': args.meta_lr
        }])

    if args.meta_train:
        total = args.train_batches
    elif args.meta_val:
        total = args.valid_batches
    elif args.meta_test:
        total = args.test_batches

    loss_logs, accuracy_logs = [], []

    # Main loop (meta-training, or evaluation when mode is val/test)
    with tqdm(dataloader, total=total, leave=False) as pbar:
        for batch_idx, batch in enumerate(pbar):
            if args.centering:
                fc_weight_mean = torch.mean(model.classifier.weight.data,
                                            dim=0)
                model.classifier.weight.data -= fc_weight_mean

            model.zero_grad()

            support_inputs, support_targets = batch['train']
            support_inputs = support_inputs.to(device=args.device)
            support_targets = support_targets.to(device=args.device)

            query_inputs, query_targets = batch['test']
            query_inputs = query_inputs.to(device=args.device)
            query_targets = query_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)

            for task_idx, (support_input, support_target, query_input,
                           query_target) in enumerate(
                               zip(support_inputs, support_targets,
                                   query_inputs, query_targets)):
                support_features, support_logit = model(support_input)
                inner_loss = F.cross_entropy(support_logit, support_target)

                model.zero_grad()

                params = update_parameters(
                    model,
                    inner_loss,
                    extractor_step_size=args.extractor_step_size,
                    classifier_step_size=args.classifier_step_size,
                    first_order=args.first_order)

                query_features, query_logit = model(query_input, params=params)
                outer_loss += F.cross_entropy(query_logit, query_target)

                with torch.no_grad():
                    accuracy += get_accuracy(query_logit, query_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)
            loss_logs.append(outer_loss.item())
            accuracy_logs.append(accuracy.item())

            if args.meta_train:
                outer_loss.backward()
                meta_optimizer.step()

            postfix = {
                'mode': mode,
                'iter': iteration,
                'acc': round(accuracy.item(), 5)
            }
            pbar.set_postfix(postfix)
            if batch_idx + 1 == total:
                break

    # Save model
    if args.meta_train:
        filename = os.path.join(args.output_folder,
                                args.dataset + '_' + args.save_name, 'models',
                                'epochs_{}.pt'.format((iteration + 1) * total))
        if (iteration + 1) * total % 5000 == 0:
            with open(filename, 'wb') as f:
                state_dict = model.state_dict()
                torch.save(state_dict, f)

    return loss_logs, accuracy_logs
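The args.outer_fix branch freezes the classifier during the outer loop by giving its parameter group a learning rate of 0: since Adam scales every update by the group's lr, those weights never move. The trick in isolation (a minimal sketch; calling requires_grad_(False) on the frozen parameters would be an alternative):

import torch

layer = torch.nn.Linear(3, 2)
optimizer = torch.optim.Adam([
    {'params': [layer.weight], 'lr': 0.0},    # frozen: updates scaled by 0
    {'params': [layer.bias], 'lr': 1e-3},     # learnable
])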
Code example #5
File: test_utils.py Project: xjtushujun/pytorch-maml
def test_multiple_update_parameters(model):
    """
    The loss function (with respect to the weights of the model w) is defined as
        f(w) = 0.5 * (1 * w_1 + 2 * w_2 + 3 * w_3) ** 2
    with w = [2, 3, 5].

    The gradient of f with respect to w is:
        df / dw_1 = 1 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 23
        df / dw_2 = 2 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 46
        df / dw_3 = 3 * (1 * w_1 + 2 * w_2 + 3 * w_3) = 69

    The updated parameters are given by:
        w'_1 = w_1 - 1. * df / dw_1 = 2 - 1. * 23 = -21
        w'_2 = w_2 - 1. * df / dw_2 = 3 - 1. * 46 = -43
        w'_3 = w_3 - 1. * df / dw_3 = 5 - 1. * 69 = -64
    """
    train_inputs = torch.tensor([[1., 2., 3.]])

    train_loss_1 = 0.5 * (model(train_inputs)**2)
    params_1 = update_parameters(model,
                                 train_loss_1,
                                 params=None,
                                 step_size=1.,
                                 first_order=False)

    assert train_loss_1.item() == 264.5
    assert list(params_1.keys()) == ['weight']
    assert torch.all(
        params_1['weight'].data == torch.tensor([[-21., -43., -64.]]))
    """
    The new loss function is defined as
        g(w') = 0.5 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) ** 2
    with w' = [-21, -43, -64].

    The gradient of g with respect to w' is:
        dg / dw'_1 = 1 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -299
        dg / dw'_2 = 2 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -598
        dg / dw'_3 = 3 * (1 * w'_1 + 2 * w'_2 + 3 * w'_3) = -897

    The updated parameters are given by:
        w''_1 = w'_1 - 1. * dg / dw'_1 = -21 - 1. * -299 = 278
        w''_2 = w'_2 - 1. * dg / dw'_2 = -43 - 1. * -598 = 555
        w''_3 = w'_3 - 1. * dg / dw'_3 = -64 - 1. * -897 = 833
    """
    train_loss_2 = 0.5 * (model(train_inputs, params=params_1)**2)
    params_2 = update_parameters(model,
                                 train_loss_2,
                                 params=params_1,
                                 step_size=1.,
                                 first_order=False)

    assert train_loss_2.item() == 44700.5
    assert list(params_2.keys()) == ['weight']
    assert torch.all(
        params_2['weight'].data == torch.tensor([[278., 555., 833.]]))
    """
    The new loss function is defined as
        h(w'') = 0.5 * (1 * w''_1 + 2 * w''_2 + 3 * w''_3) ** 2
    with w'' = [278, 555, 833].

    The gradient of h with respect to w'' is:
        dh / dw''_1 = 1 * (1 * w''_1 + 2 * w''_2 + 3 * w''_3) =  3887
        dh / dw''_2 = 2 * (1 * w''_1 + 2 * w''_2 + 3 * w''_3) =  7774
        dh / dw''_3 = 3 * (1 * w''_1 + 2 * w''_2 + 3 * w''_3) = 11661

    The updated parameters are given by:
        w'''_1 = w''_1 - 1. * dh / dw''_1 = 278 - 1. *  3887 =  -3609
        w'''_2 = w''_2 - 1. * dh / dw''_2 = 555 - 1. *  7774 =  -7219
        w'''_3 = w''_3 - 1. * dh / dw''_3 = 833 - 1. * 11661 = -10828
    """
    train_loss_3 = 0.5 * (model(train_inputs, params=params_2)**2)
    params_3 = update_parameters(model,
                                 train_loss_3,
                                 params=params_2,
                                 step_size=1.,
                                 first_order=False)

    assert train_loss_3.item() == 7554384.5
    assert list(params_3.keys()) == ['weight']
    assert torch.all(
        params_3['weight'].data == torch.tensor([[-3609., -7219., -10828.]]))
    """
    The new loss function is defined as
        l(w) = 4 * w'''_1 + 5 * w'''_2 + 6 * w'''_3
    with w = [2, 3, 5] and w''' = [-3609, -7219, -10828].

    The gradient of l with respect to w is:
        dl / dw_1 = 4 * dw'''_1 / dw_1 + 5 * dw'''_2 / dw_1 + 6 * dw'''_3 / dw_1
                  = ... =  -5020
        dl / dw_2 = 4 * dw'''_1 / dw_2 + 5 * dw'''_2 / dw_2 + 6 * dw'''_3 / dw_2
                  = ... = -10043
        dl / dw_3 = 4 * dw'''_1 / dw_3 + 5 * dw'''_2 / dw_3 + 6 * dw'''_3 / dw_3
                  = ... = -15066
    """
    test_inputs = torch.tensor([[4., 5., 6.]])
    test_loss = model(test_inputs, params=params_3)
    grads = torch.autograd.grad(test_loss, model.parameters())

    assert test_loss.item() == -115499.
    assert len(grads) == 1
    assert torch.all(
        grads[0].data == torch.tensor([[-5020., -10043., -15066.]]))
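The gradient elided in the last docstring can be checked numerically by unrolling the three inner steps with plain autograd (a sketch, not the library's update_parameters):

import torch

a = torch.tensor([[1., 2., 3.]])   # train inputs
b = torch.tensor([[4., 5., 6.]])   # test inputs
w = torch.tensor([[2., 3., 5.]], requires_grad=True)

params = w
for _ in range(3):                 # three inner steps, step size 1.0
    loss = 0.5 * (a @ params.t()).pow(2).sum()
    (grad,) = torch.autograd.grad(loss, params, create_graph=True)
    params = params - grad

test_loss = (b @ params.t()).sum()
(g,) = torch.autograd.grad(test_loss, w)
print(test_loss.item())            # -115499.0
print(g)                           # [[-5020., -10043., -15066.]]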