bs_outter = 0
if args.cross_meta_learning and cross_flag:
    # zero out the embeddings as we come to a new language (no self-embeddings)
    model.encoder.out.weight.data[SP:, :].zero_()

for j in range(args.inter_size):
    if not args.cross_meta_learning:
        meta_train_batch = next(iter(aux_reals[selected]))
        lang_U = Us[selected + 1]
    else:
        meta_train_batch = next(iter(aux_reals[selected2]))
        lang_U = Us[selected2 + 1]

    inputs, input_masks, targets, target_masks, sources, source_masks, encoding, batch_size = model.quick_prepare(
        meta_train_batch, U=lang_U)
    loss = model.cost(targets, target_masks,
                      out=model(encoding, source_masks, inputs, input_masks)) / args.inter_size
    loss.backward()

    loss_outer = loss_outer + loss
    bs_outter = bs_outter + batch_size * max(inputs.size(1), targets.size(1))

# update the meta-parameters: restore the slow weights before the meta-step
if not args.no_meta_training:
    model.load_fast_weights(weights)

if args.meta_approx_2nd:
    meta_grad = model.save_fast_gradients('meta')
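
# The fragment above is a first-order meta-update: gradients of the meta-train
# loss are computed while the model still holds the *fast* (task-adapted)
# weights, and load_fast_weights(weights) restores the *slow* weights before
# the optimizer step, so those gradients end up applied to the slow weights
# directly (FOMAML-style). Below is a minimal, self-contained sketch of that
# pattern; the helper and all of its arguments are hypothetical and are not
# part of this codebase.
def _fomaml_step_sketch(model, slow_weights, meta_batch, meta_opt, loss_fn):
    """Illustration only, never called: one first-order MAML meta-step.

    Assumes `model` currently holds the fast (task-adapted) weights and
    `slow_weights` is a state_dict snapshot taken before adaptation.
    """
    meta_opt.zero_grad()
    loss_fn(model, meta_batch).backward()  # gradients taken at the fast weights
    model.load_state_dict(slow_weights)    # restore the slow weights (.grad buffers survive)
    meta_opt.step()                        # apply fast-weight gradients to the slow weights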
fast_weights = inner_loop(args, (aux_reals[j], args.aux[j]), model, weights,
                          iters=iters, save_diff=True, progressbar=progressbar)
all_fast_weights.append(fast_weights)  # save the incremental (fast - slow) weight differences

fast_gradients = model.combine_fast_weights(all_fast_weights, type='meta')

# ------ outer update for Reptile (parallel mode or sequential mode) ------
meta_opt.param_groups[0]['lr'] = get_learning_rate(
    iters + 1, disable=args.disable_lr_schedule)
meta_opt.zero_grad()

# -- virtual batch, only used to build the backward pass
inputs, input_masks, targets, target_masks, sources, source_masks, encoding, batch_size = model.quick_prepare(
    next(iter(aux_reals[selected])))
loss_outer = model.cost(targets, target_masks,
                        out=model(encoding, source_masks, inputs, input_masks))
loss_outer.backward()

# -- load the fast gradients computed in the inner loops
model.load_fast_gradients(fast_gradients, type='meta')
meta_opt.step()

info = 'Outer-loop (all): lr={:.8f}, loss (fake) ={}'.format(
    meta_opt.param_groups[0]['lr'], export(loss_outer))
# progressbar.set_description(info)

if args.tensorboard and (not args.debug):
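
# The "virtual batch" above exists only to materialize a .grad buffer on every
# parameter: loss_outer.backward() populates the buffers, load_fast_gradients
# then overwrites them with the combined Reptile directions collected from the
# inner loops, and meta_opt.step() lets the optimizer apply its update rule to
# that direction instead of to the virtual-batch gradient. A minimal sketch of
# the same trick follows; the helper and the `reptile_direction` mapping
# (parameter name -> tensor, typically slow minus fast weights averaged over
# tasks) are hypothetical, not part of this codebase.
def _reptile_step_sketch(model, meta_opt, dummy_loss, reptile_direction):
    """Illustration only, never called: apply a precomputed Reptile direction
    through an existing optimizer (e.g. Adam) by overwriting .grad buffers."""
    meta_opt.zero_grad()
    dummy_loss.backward()  # only used to allocate .grad buffers on the parameters
    for name, p in model.named_parameters():
        if p.grad is not None and name in reptile_direction:
            p.grad.data.copy_(reptile_direction[name])  # overwrite with the Reptile direction
    meta_opt.step()  # optimizer state (momentum, Adam moments) acts on the Reptile direction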