Example #1
def inner_loop(args,
               data,
               model,
               weights,
               iters=0,
               save_diff=True,
               self_opt=None):
    model.train()
    data_loader, data_name = data
    progressbar = tqdm(total=args.inner_steps,
                       desc='start training for {}'.format(data_name))

    model.load_fast_weights(weights)
    if self_opt is None:
        self_opt = torch.optim.Adam(
            [p for p in model.get_parameters(type='fast') if p.requires_grad],
            betas=(0.9, 0.98),
            eps=1e-9)  # reset the optimizer

    for i in range(args.inner_steps):
        self_opt.param_groups[0]['lr'] = get_learning_rate(
            iters + i + 1, disable=args.disable_lr_schedule)
        self_opt.zero_grad()
        loss_inner = 0
        bs_inner = 0
        for j in range(args.inter_size):  # accumulate gradients over inter_size micro-batches
            train_batch = next(iter(data_loader))
            inputs, input_masks, targets, target_masks, sources, source_masks, encoding, batch_size = model.quick_prepare(
                train_batch)
            loss = model.cost(targets,
                              target_masks,
                              out=model(encoding, source_masks, inputs,
                                        input_masks)) / args.inter_size
            loss.backward()

            loss_inner = loss_inner + loss
            bs_inner = bs_inner + batch_size * max(inputs.size(1),
                                                   targets.size(1))

        # update the fast-weights
        self_opt.step()
        info = '  Inner-loop[{}]: loss={:.3f}, lr={:.8f}, batch_size={}'.format(
            data_name, export(loss_inner), self_opt.param_groups[0]['lr'],
            bs_inner)
        progressbar.update(1)
        progressbar.set_description(info)

    progressbar.close()

    if save_diff:
        return model.save_fast_weights(weights=weights)  # fast-weights
    return model.save_fast_weights()
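
The inner loop above follows a standard gradient-accumulation pattern: reset an Adam optimizer over the fast parameters, scale each of the args.inter_size micro-batch losses by 1/inter_size, set the learning rate by hand before every step, and apply a single optimizer update per inner step. Below is a minimal self-contained sketch of that pattern; the toy model, toy_inner_loop, and inverse_sqrt_lr are hypothetical stand-ins (the original relies on model.get_parameters, get_learning_rate, and quick_prepare, which are not shown here).

import torch
import torch.nn as nn


def inverse_sqrt_lr(step, d_model=512, warmup=4000):
    # hypothetical inverse-square-root schedule standing in for get_learning_rate
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


def toy_inner_loop(model, batches, inner_steps=4, inter_size=2):
    model.train()
    # reset the optimizer over the trainable ("fast") parameters
    opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad],
                           betas=(0.9, 0.98), eps=1e-9)
    it = iter(batches)
    for i in range(inner_steps):
        opt.param_groups[0]['lr'] = inverse_sqrt_lr(i + 1)
        opt.zero_grad()
        loss_inner = 0.0
        for _ in range(inter_size):
            x, y = next(it)
            # divide by inter_size so the accumulated gradient matches
            # one large batch made of inter_size micro-batches
            loss = nn.functional.mse_loss(model(x), y) / inter_size
            loss.backward()
            loss_inner += loss.item()
        opt.step()  # one parameter update per inner step
    # return a detached copy of the adapted ("fast") weights
    return {k: v.detach().clone() for k, v in model.state_dict().items()}


model = nn.Linear(8, 1)
batches = [(torch.randn(16, 8), torch.randn(16, 1)) for _ in range(8)]
fast_weights = toy_inner_loop(model, batches)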
Example #2
def inner_loop(args,
               data,
               model,
               weights=None,
               iters=0,
               inner_steps=None,
               self_opt=None,
               use_prog_bar=True,
               inner_loop_data=None):

    set_random(seeds[iters])
    model.train()

    data_loader, data_name = data
    lang_id = data_id[data_name]
    lang_U = Us[lang_id]  # per-language matrix passed to quick_prepare as U

    flag = isinstance(data_loader, (list, ))  # a list means pre-collected batches are indexed by step

    if inner_steps is None:
        inner_steps = args.inner_steps

    if use_prog_bar:
        progressbar = tqdm(total=inner_steps,
                           desc='start training for {}'.format(data_name))

    if weights is not None:
        model.load_fast_weights(weights)

    if self_opt is None:
        self_opt = torch.optim.Adam([
            p for p in model.get_parameters(type=args.finetune_params)
            if p.requires_grad
        ],
                                    betas=(0.9, 0.98),
                                    eps=1e-9)  # reset the optimizer

    step = 0
    for i in range(inner_steps):
        self_opt.param_groups[0]['lr'] = get_learning_rate(
            iters + i + 1, disable=args.disable_lr_schedule)
        self_opt.zero_grad()
        loss_inner = 0
        bs_inner = 0
        for j in range(args.inter_size):
            if not flag:
                train_batch = next(iter(data_loader))
            else:
                train_batch = data_loader[step]
            step += 1

            if inner_loop_data is not None:  # optionally record the batches used in this inner loop
                inner_loop_data.append(train_batch)

            inputs, input_masks, targets, target_masks, sources, source_masks, encoding, batch_size = model.quick_prepare(
                train_batch, U=lang_U)
            loss = model.cost(targets,
                              target_masks,
                              out=model(encoding, source_masks, inputs,
                                        input_masks)) / args.inter_size
            loss.backward()

            loss_inner = loss_inner + loss
            bs_inner = bs_inner + batch_size * max(inputs.size(1),
                                                   targets.size(1))

        # update the fast-weights
        self_opt.step()
        info = '  Inner-loop[{}]: loss={:.3f}, lr={:.8f}, batch_size={}'.format(
            data_name, export(loss_inner), self_opt.param_groups[0]['lr'],
            bs_inner)

        if use_prog_bar:
            progressbar.update(1)
            progressbar.set_description(info)

    if use_prog_bar:
        progressbar.close()
    return model.save_fast_weights()
Example #3
            # fast_weights3 = inner_loop(args, (inner_loop_data, args.aux[selected]), model, iters = iters, use_prog_bar=False)

            # approximate the meta-gradient with a finite difference between the two fast-weight sets:
            for name in meta_grad:
                if name in fast_weights:
                    meta_grad[name].add_(
                        0.01 * (fast_weights[name] - fast_weights2[name]) /
                        args.approx_lr)

            # print(model.grad_sum(meta_grad))
            model.load_fast_gradients(meta_grad, 'meta')
            model.load_fast_weights(weights)

    meta_opt.step()
    info = 'Lang {}: loss={:.3f}, lr={:.8f}, batch_size={}, episodes={}'.format(
        args.aux[selected], export(loss_outer), meta_opt.param_groups[0]['lr'],
        bs_outter, iters)
    progressbar.update(1)
    progressbar.set_description(info)
    tokens = tokens + bs_outter

    if args.tensorboard and (not args.debug):
        writer.add_scalar('train/Loss', export(loss_outer), iters + 1)

    # ---- zero the self-embedding matrix
    if not args.no_meta_training:
        model.encoder.out.weight.data[
            SP:, :].zero_()  # zero every row except the first SP special tokens.

    iters = iters + 1
    eposides = eposides + 1
Example #4
            # fast_weights3 = inner_loop(args, (inner_loop_data, args.aux[selected]), model, iters = iters, use_prog_bar=False)

            # approximate the meta-gradient with a finite difference between the two fast-weight sets:
            for name in meta_grad:
                if name in fast_weights:
                    meta_grad[name].add_(
                        0.01 * (fast_weights[name] - fast_weights2[name]) /
                        args.approx_lr)

            # print(model.grad_sum(meta_grad))
            model.load_fast_gradients(meta_grad, 'meta')
            model.load_fast_weights(weights)

    meta_opt.step()
    info = 'Outer: loss={:.3f}, lr={:.8f}, batch_size={}, episodes={}'.format(
        export(loss_outer), meta_opt.param_groups[0]['lr'], bs_outter, iters)
    progressbar.update(1)
    progressbar.set_description(info)
    tokens = tokens + bs_outter

    if args.tensorboard and (not args.debug):
        writer.add_scalar('train/Loss', export(loss_outer), iters + 1)

    # ---- zero the self-embedding matrix
    if not args.no_meta_training:
        model.encoder.out.weight.data[
            SP:, :].zero_()  # zero every row except the first SP special tokens.

    iters = iters + 1
    eposides = eposides + 1
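
Both fragments above estimate the meta-gradient with a finite difference: the elementwise gap between the fast weights produced by two slightly different inner loops is divided by args.approx_lr (and damped by 0.01), accumulated into meta_grad, and only then loaded back into the model. A self-contained sketch of that accumulation on plain tensor dictionaries follows; approx_lr and the two weight dictionaries are toy stand-ins for the values produced by inner_loop.

import torch

approx_lr = 0.1  # stands in for args.approx_lr

# toy dictionaries playing the role of the two inner-loop results
fast_weights = {'encoder.w': torch.randn(4, 4), 'decoder.w': torch.randn(4, 4)}
fast_weights2 = {k: v + approx_lr * torch.randn_like(v)
                 for k, v in fast_weights.items()}

# meta_grad starts at zero and accumulates the damped finite-difference estimate
meta_grad = {k: torch.zeros_like(v) for k, v in fast_weights.items()}
for name in meta_grad:
    if name in fast_weights:
        meta_grad[name].add_(
            0.01 * (fast_weights[name] - fast_weights2[name]) / approx_lr)

# in the fragments above, meta_grad would then be written back with
# model.load_fast_gradients(meta_grad, 'meta') before meta_opt.step()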
Example #5
def inner_loop(args,
               data,
               model,
               weights,
               iters=0,
               save_diff=True,
               self_opt=None,
               progressbar=None):
    model.train()
    data_loader, data_name = data
    flag = True  # True: the progressbar was supplied by the caller

    if progressbar is None:
        flag = False  # created locally, so it is also closed locally
        progressbar = tqdm(total=args.inner_steps,
                           desc='start training for {}'.format(data_name))

    model.load_fast_weights(weights)
    with torch.cuda.device(args.gpu):
        slow_weights = copy.deepcopy(
            model.save_fast_weights(type='slow')
        )  # --- universal embeddings are not updated, but accumulated.
    diff_slow_weights = None

    if self_opt is None:
        self_opt = torch.optim.Adam(
            [p for p in model.get_parameters(type='full') if p.requires_grad],
            betas=(0.9, 0.98),
            eps=1e-9)  # reset the optimizer

    for i in range(args.inner_steps):
        self_opt.param_groups[0]['lr'] = get_learning_rate(
            iters + i + 1, disable=args.disable_lr_schedule)
        self_opt.zero_grad()
        loss_inner = 0
        bs_inner = 0
        for j in range(args.inter_size):
            train_batch = next(iter(data_loader))
            inputs, input_masks, targets, target_masks, sources, source_masks, encoding, batch_size = model.quick_prepare(
                train_batch)
            loss = model.cost(targets,
                              target_masks,
                              out=model(encoding, source_masks, inputs,
                                        input_masks)) / args.inter_size
            loss.backward()

            loss_inner = loss_inner + loss
            bs_inner = bs_inner + batch_size * max(inputs.size(1),
                                                   targets.size(1))

        # update the fast-weights
        self_opt.step()
        info = '  Inner-loop[{}]: loss={:.3f}, lr={:.8f}, batch_size={}'.format(
            data_name, export(loss_inner), self_opt.param_groups[0]['lr'],
            bs_inner)
        progressbar.update(1)
        progressbar.set_description(info)

        # accumulate the per-step change of the slow weights
        if diff_slow_weights is None:
            diff_slow_weights = model.save_fast_weights(weights=slow_weights,
                                                        type='slow')
        else:
            diff_slow_weights = model.combine_fast_weights([
                diff_slow_weights,
                model.save_fast_weights(weights=slow_weights, type='slow')
            ],
                                                           type='slow',
                                                           average=False)

        # restore the slow weights so they remain unchanged across inner steps
        model.load_fast_weights(slow_weights, type='slow')

    if not flag:
        progressbar.close()

    if save_diff:
        diff_weights = model.save_fast_weights(weights=weights)
        diff_weights.update(diff_slow_weights)
        return diff_weights
    return model.save_fast_weights()
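
Example #5 keeps the "slow" weights (the universal embeddings) fixed inside the inner loop: after each optimizer step it records how much the slow weights moved, sums those per-step differences with combine_fast_weights(..., average=False), and then reloads the saved slow weights. Below is a toy sketch of that accumulate-and-restore pattern on plain tensors; param_delta and the 0.1 "optimizer step" are hypothetical stand-ins for save_fast_weights(weights=slow_weights, type='slow') and the real update, assuming that call returns the difference from the reference weights.

import torch


def param_delta(current, reference):
    # hypothetical stand-in for save_fast_weights(weights=reference, type='slow')
    return {k: current[k] - reference[k] for k in reference}


slow_weights = {'emb.weight': torch.zeros(6, 4)}           # reference copy kept outside the loop
current = {k: v.clone() for k, v in slow_weights.items()}  # weights actually being optimized
diff_slow_weights = None

for step in range(3):
    current['emb.weight'] += 0.1  # pretend an optimizer step moved the slow weights
    delta = param_delta(current, slow_weights)
    if diff_slow_weights is None:
        diff_slow_weights = delta
    else:
        # combine_fast_weights(..., average=False) is assumed to sum per-step deltas
        diff_slow_weights = {k: diff_slow_weights[k] + delta[k] for k in delta}
    # restore the slow weights so the next step starts from the reference again
    current = {k: v.clone() for k, v in slow_weights.items()}

# after three steps of size 0.1 each, every accumulated entry equals 0.3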