Example no. 1
from torch import optim
# Imports assumed from the source repo: `Leap` is the Leap meta-learner,
# and `BaseWrapper` is the project-local training wrapper this class extends.
from leap import Leap


class LeapWrapper(BaseWrapper):
    """Wrapper around the Leap meta-learner.

    Arguments:
        model (nn.Module): classifier.
        optimizer_cls: optimizer class.
        meta_optimizer_cls: meta optimizer class.
        optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
        meta_optimizer_kwargs (dict): kwargs to pass to meta optimizer upon
            construction.
        meta_kwargs (dict): kwargs to pass to meta-learner upon construction.
        criterion (func): loss criterion to use.
    """
    def __init__(self, model, optimizer_cls, meta_optimizer_cls,
                 optimizer_kwargs, meta_optimizer_kwargs, meta_kwargs,
                 criterion):
        super(LeapWrapper, self).__init__(criterion, model, optimizer_cls,
                                          optimizer_kwargs)
        self.meta = Leap(model, **meta_kwargs)

        self.meta_optimizer_cls = \
            optim.SGD if meta_optimizer_cls.lower() == 'sgd' else optim.Adam
        self.meta_optimizer = self.meta_optimizer_cls(self.meta.parameters(),
                                                      **meta_optimizer_kwargs)

    def _partial_meta_update(self, l, final):
        # Record the current task loss `l` and model state so Leap can
        # accumulate its meta-gradient along the learning trajectory.
        # `final` flags the last inner step and is unused here.
        self.meta.update(l, self.model)

    def _final_meta_update(self):
        # Normalize the accumulated meta-gradient, then take one meta step.
        self.meta.normalize()
        self.meta_optimizer.step()
        self.meta_optimizer.zero_grad()

    def run_task(self, task, train, meta_train):
        if meta_train:
            # Start accumulating a fresh task trajectory for the meta-gradient.
            self.meta.init_task()

        if train:
            # Load the current shared initialization into the model.
            self.meta.to(self.model)

        return super(LeapWrapper, self).run_task(task, train, meta_train)
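A minimal construction sketch for the wrapper above. Everything here is a placeholder rather than code from the original project: the toy model, the hyperparameters, and the criterion are assumptions, and `BaseWrapper` is assumed to drive the inner training loop and call the `_partial_meta_update` / `_final_meta_update` hooks.

import torch.nn as nn
import torch.optim as optim

# Hypothetical setup: a toy classifier with illustrative hyperparameters.
model = nn.Linear(16, 4)
wrapper = LeapWrapper(
    model=model,
    optimizer_cls=optim.SGD,            # inner-loop optimizer class
    meta_optimizer_cls='adam',          # resolved to optim.Adam in __init__
    optimizer_kwargs={'lr': 0.1},
    meta_optimizer_kwargs={'lr': 1e-3},
    meta_kwargs={},                     # forwarded to Leap(model, ...)
    criterion=nn.functional.cross_entropy,
)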
Example no. 2
import os

import numpy as np
import torch
import torch.nn as nn
from leap import Leap  # import assumed from the source repo


def leap_adapt(model,
               source_corpus,
               target_corpus,
               char2idx,
               args,
               device,
               lang_model_n_words=0):
    model = model.to(device)
    leap = Leap(model)
    meta_optimizer = torch.optim.Adam(leap.parameters(),
                                      lr=args.leap_meta_lr_init)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        meta_optimizer,
        factor=args.lr_decay,
        patience=args.patience,
        threshold=args.threshold)
    # Negative mean cosine similarity lies in [-1, 1], so 3 exceeds any
    # possible score and guarantees the first checkpoint is saved.
    best_score = 3

    for meta_epoch in range(args.n_meta_epochs):
        source_valid_cosine = []
        target_valid_cosine = []

        model.train()
        for meta_batch in range(args.n_meta_batch):
            meta_optimizer.zero_grad()

            # Source task: start a fresh task trajectory, then load the
            # current shared initialization into the model.
            leap.init_task()
            leap.to(model)
            inner_optimizer = torch.optim.Adam(model.parameters(),
                                               lr=args.leap_inner_lr_init)
            for inner_batch in range(args.n_task_steps):
                inner_optimizer.zero_grad()
                source_train_contexts, source_train_targets, source_train_vocabs = source_corpus.get_batch(
                    args.meta_batch_size,
                    args.n_shot,
                    char2idx,
                    device,
                    fixed=args.fixed_shot)
                pred_emb = model(source_train_contexts, source_train_vocabs)
                loss = -nn.functional.cosine_similarity(
                    pred_emb, source_train_targets).mean()
                loss.backward()
                leap.update(loss, model)
                inner_optimizer.step()

            # Target task: same inner-loop procedure on the target corpus.
            leap.init_task()
            leap.to(model)
            inner_optimizer = torch.optim.Adam(model.parameters(),
                                               lr=args.leap_inner_lr_init)
            for inner_batch in range(args.n_task_steps):
                inner_optimizer.zero_grad()
                target_train_contexts, target_train_targets, target_train_vocabs = target_corpus.get_batch(
                    args.meta_batch_size,
                    args.n_shot,
                    char2idx,
                    device,
                    fixed=args.fixed_shot,
                    repeat_ctxs=args.meta_repeat_ctxs)
                pred_emb = model(target_train_contexts, target_train_vocabs)
                loss = -nn.functional.cosine_similarity(
                    pred_emb, target_train_targets).mean()
                loss.backward()
                leap.update(loss, model)
                inner_optimizer.step()

            # Fold both task trajectories into the meta-gradient and update
            # the shared initialization.
            leap.normalize()
            meta_optimizer.step()

        # Validate the current shared initialization without task adaptation.
        leap.to(model)
        model.eval()
        with torch.no_grad():
            for batch in range(args.n_batch):
                source_valid_contexts, source_valid_targets, source_valid_vocabs = source_corpus.get_batch(
                    args.meta_batch_size,
                    args.n_shot,
                    char2idx,
                    device,
                    use_valid=True,
                    fixed=args.fixed_shot)
                pred_emb = model(source_valid_contexts, source_valid_vocabs)
                loss = -nn.functional.cosine_similarity(
                    pred_emb, source_valid_targets).mean()
                source_valid_cosine += [loss.cpu().numpy()]

                target_valid_contexts, target_valid_targets, target_valid_vocabs = target_corpus.get_batch(
                    args.meta_batch_size,
                    args.n_shot,
                    char2idx,
                    device,
                    use_valid=True,
                    fixed=args.fixed_shot,
                    repeat_ctxs=args.meta_repeat_ctxs)
                pred_emb = model(target_valid_contexts, target_valid_vocabs)
                loss = -nn.functional.cosine_similarity(
                    pred_emb, target_valid_targets).mean()
                target_valid_cosine += [loss.cpu().numpy()]

        avg_source_valid = np.average(source_valid_cosine)
        avg_target_valid = np.average(target_valid_cosine)
        # Model selection and LR scheduling track the target-domain loss.
        score = avg_target_valid
        lr_scheduler.step(score)
        print(
            f"Average source cosine loss: {avg_source_valid}; Average target cosine loss: {avg_target_valid}"
        )

        if score < best_score:
            best_score = score
            torch.save(model.state_dict(),
                       os.path.join(args.save_dir, 'leap_model.pt'))

        if meta_optimizer.param_groups[0]['lr'] < args.leap_lr_early_stop:
            print('LR early stop')
            break
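For reference, a hypothetical invocation of `leap_adapt`. Every field of `args` below is inferred from how it is used in the function and the values are illustrative guesses only; `model`, `source_corpus`, `target_corpus`, `char2idx`, and `device` are placeholders for the project's own objects.

from types import SimpleNamespace

args = SimpleNamespace(
    leap_meta_lr_init=1e-3,                    # meta optimizer LR
    leap_inner_lr_init=1e-3,                   # task (inner) optimizer LR
    lr_decay=0.5, patience=2, threshold=1e-3,  # ReduceLROnPlateau settings
    n_meta_epochs=50, n_meta_batch=10, n_task_steps=5,
    meta_batch_size=32, n_shot=4, fixed_shot=True,
    meta_repeat_ctxs=False, n_batch=10,
    leap_lr_early_stop=1e-5, save_dir='checkpoints',
)
leap_adapt(model, source_corpus, target_corpus, char2idx, args, device)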
Example no. 3
import mxnet as mx
from mxnet import autograd, gluon
# `Leap`, `MetaLogger`, `net`, `model_ctx`, `num_examples`, `batch_size`,
# `num_tasks`, and `train_data_all` are defined earlier in the source notebook.

# Inspect the network's parameter collection (notebook REPL output).
type(net.collect_params())

net.collect_params().initialize(mx.init.Normal(sigma=1.), ctx=model_ctx)

square_loss = gluon.loss.L2Loss()

epochs = 3
loss_sequence = []
num_batches = num_examples / batch_size

verbose = 1

##########
meta_steps = 10
leap = Leap(net)
meta_trainer = gluon.Trainer(list(leap.parameters()), 'sgd',
                             {'learning_rate': 0.0001})
meta_logger = MetaLogger(num_tasks)
log_params = True
##########

for ms in range(meta_steps):
    for task in range(num_tasks):
        train_data = train_data_all[task]

        leap.to(net)      # load the current shared initialization into net
        leap.init_task()  # start recording a new task trajectory

        trainer = gluon.Trainer(net.collect_params(), 'sgd',
                                {'learning_rate': 0.0001})
        if log_params:
            ...  # body truncated in the source listing
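The listing cuts off inside the task loop. Going by the pattern in Examples 1 and 2 and standard Gluon training, the inner loop presumably continues along these lines. This is a hedged sketch, not the original code: the Gluon `leap.update` signature is assumed to mirror the PyTorch one, and the `meta_logger` call is omitted because `MetaLogger`'s API is not shown.

        # Sketch (assumed continuation): adapt `net` on the task while Leap
        # records each step of the learning trajectory.
        for data, label in train_data:
            with autograd.record():
                loss = square_loss(net(data), label).mean()
            loss.backward()
            leap.update(loss, net)  # record this step for the meta-gradient
            trainer.step(1)         # inner SGD step (loss already averaged)

    # After looping over tasks: fold the recorded trajectories into the
    # meta-gradient and update the shared initialization held by Leap.
    leap.normalize()
    meta_trainer.step(1)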