Example #1
    loss_outer = 0
    bs_outer = 0

    if args.cross_meta_learning and cross_flag:
        # zero out the output embeddings when switching to a new language (no self-embeddings)
        model.encoder.out.weight.data[SP:, :].zero_()

    for j in range(args.inter_size):

        if not args.cross_meta_learning:
            meta_train_batch = next(iter(aux_reals[selected]))
            lang_U = Us[selected + 1]
        else:
            meta_train_batch = next(iter(aux_reals[selected2]))
            lang_U = Us[selected2 + 1]

        inputs, input_masks, targets, target_masks, sources, source_masks, encoding, batch_size = model.quick_prepare(
            meta_train_batch, U=lang_U)
        loss = model.cost(targets,
                          target_masks,
                          out=model(encoding, source_masks, inputs,
                                    input_masks)) / args.inter_size
        loss.backward()

        loss_outer = loss_outer + loss
        bs_outer = bs_outer + batch_size * max(inputs.size(1),
                                               targets.size(1))

    # update the meta-parameters
    if not args.no_meta_training:
        model.load_fast_weights(weights)
        if args.meta_approx_2nd:
            meta_grad = model.save_fast_gradients('meta')
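
Example #1 accumulates the meta-training loss over args.inter_size micro-batches, backpropagating each piece immediately, and then restores the slow (meta) weights before the accumulated gradients are saved. Below is a minimal sketch of that accumulate-then-restore pattern in plain PyTorch; the linear model and random batches are stand-ins for the translation model and data, and it mirrors, rather than reproduces, model.load_fast_weights and model.save_fast_gradients:

import copy

import torch
import torch.nn as nn

model = nn.Linear(8, 1)                            # stand-in for the translation model
slow_weights = copy.deepcopy(model.state_dict())   # saved meta-parameters

inter_size = 4
for _ in range(inter_size):
    x, y = torch.randn(16, 8), torch.randn(16, 1)  # stand-in meta-train batch
    loss = nn.functional.mse_loss(model(x), y) / inter_size
    loss.backward()                                # gradients accumulate in .grad

# restore the slow weights; the accumulated .grad buffers survive the load
model.load_state_dict(slow_weights)
meta_grads = {n: p.grad.clone() for n, p in model.named_parameters()}

Because backward() runs per micro-batch, memory stays bounded by a single batch's graph while the gradients sum in .grad.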
Example #2
        fast_weights = inner_loop(args, (aux_reals[j], args.aux[j]),
                                  model,
                                  weights,
                                  iters=iters,
                                  save_diff=True,
                                  progressbar=progressbar)
        all_fast_weights.append(fast_weights)  # save the incremental weight differences
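    # for Reptile, the combined meta-"gradient" is typically the average of
    # (slow weights - fast weights) across the auxiliary tasks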
    fast_gradients = model.combine_fast_weights(all_fast_weights, type='meta')

    # ------ outer-update for Reptile (parallel or sequential mode)
    meta_opt.param_groups[0]['lr'] = get_learning_rate(
        iters + 1, disable=args.disable_lr_schedule)
    meta_opt.zero_grad()

    # -- virtual batch, used only to build the backward graph (its gradients are overwritten below)
    inputs, input_masks, targets, target_masks, sources, source_masks, encoding, batch_size = model.quick_prepare(
        next(iter(aux_reals[selected])))
    loss_outer = model.cost(targets,
                            target_masks,
                            out=model(encoding, source_masks, inputs,
                                      input_masks))
    loss_outer.backward()

    # -- overwrite the virtual gradients with the combined fast gradients, then step
    model.load_fast_gradients(fast_gradients, type='meta')
    meta_opt.step()

    info = 'Outer-loop (all): lr={:.8f}, loss (fake)={}'.format(
        meta_opt.param_groups[0]['lr'], export(loss_outer))
    # progressbar.set_description(info)

    if args.tensorboard and (not args.debug):
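
The virtual batch in Example #2 exists only to populate every parameter's .grad buffer with something of the right shape; model.load_fast_gradients then overwrites those buffers with the combined fast gradients, so meta_opt.step() applies the injected update rather than the virtual one. A minimal sketch of that overwrite-then-step trick, assuming plain PyTorch and using random tensors as stand-ins for the combined fast gradients:

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
meta_opt = torch.optim.SGD(model.parameters(), lr=1e-3)

meta_opt.zero_grad()
# virtual forward/backward: allocates a .grad buffer for every parameter;
# the values themselves are thrown away below
model(torch.randn(4, 8)).sum().backward()

# overwrite .grad with externally computed meta-gradients
# (random stand-ins here, in place of the combined fast gradients)
fast_gradients = {n: torch.randn_like(p) for n, p in model.named_parameters()}
for name, p in model.named_parameters():
    p.grad.copy_(fast_gradients[name])

meta_opt.step()  # the update uses the injected gradients, not the virtual ones

Routing the update through the optimizer this way keeps the learning-rate schedule and any optimizer state (e.g. Adam moments) in effect for the meta-update.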