print(model)

#################################################
# Setup Optimizers, only Adam works for now.
#################################################

if args.num_train_enc != 0:
    enc_optimizer = optim.Adam(model.enc.parameters(), lr=args.enc_lr)

if args.num_train_dec != 0:
    dec_optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.dec.parameters()), lr=args.dec_lr)

general_optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.dec_lr)

#################################################
# Training Processes
#################################################

report_loss, report_ber = [], []

for epoch in range(1, args.num_epoch + 1):

    if args.joint_train == 1:
        for idx in range(args.num_train_enc + args.num_train_dec):
            train(epoch, model, general_optimizer, args, use_cuda=use_cuda, mode='encoder')  # trailing use_cuda/mode arguments assumed
##################################################################

if args.optimizer == 'lookahead':
    print('Using Lookahead Optimizers')
    from optimizers import Lookahead
    lookahead_k = 5
    lookahead_alpha = 0.5

    if args.num_train_enc != 0 and args.encoder not in ['Turbo_rate3_lte', 'Turbo_rate3_757']:  # no optimizer for encoder
        enc_base_opt = optim.Adam(model.enc.parameters(), lr=args.enc_lr)
        enc_optimizer = Lookahead(enc_base_opt, k=lookahead_k, alpha=lookahead_alpha)

    if args.num_train_dec != 0:
        dec_base_opt = optim.Adam(filter(lambda p: p.requires_grad, model.dec.parameters()), lr=args.dec_lr)
        dec_optimizer = Lookahead(dec_base_opt, k=lookahead_k, alpha=lookahead_alpha)

    general_base_opt = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.dec_lr)
    general_optimizer = Lookahead(general_base_opt, k=lookahead_k, alpha=lookahead_alpha)

else:  # Adam, SGD, etc.
    if args.optimizer == 'adam':
        OPT = optim.Adam
    elif args.optimizer == 'sgd':
        OPT = optim.SGD
    else:
        OPT = optim.Adam

    if args.num_train_enc != 0 and args.encoder not in ['Turbo_rate3_lte', 'Turbo_rate3_757']:  # no optimizer for encoder
        enc_optimizer = OPT(model.enc.parameters(), lr=args.enc_lr)

    if args.num_train_dec != 0:
        dec_optimizer = OPT(filter(lambda p: p.requires_grad, model.dec.parameters()), lr=args.dec_lr)
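
# ----------------------------------------------------------------
# Illustrative sketch only: the code above imports Lookahead from a
# local `optimizers` module whose implementation is not shown here.
# The sketch below shows the standard Lookahead rule (Zhang et al.,
# 2019) that such a wrapper applies: keep a copy of "slow" weights,
# let the inner (fast) optimizer run for k steps, then pull the slow
# weights toward the fast ones by a factor alpha and restart the fast
# weights from the slow ones. The class and names are assumptions,
# not the repo's actual Lookahead API.
# ----------------------------------------------------------------
import torch
import torch.optim as optim


class LookaheadSketch:
    """Minimal Lookahead wrapper around any torch.optim optimizer (sketch)."""

    def __init__(self, base_optimizer, k=5, alpha=0.5):
        self.base_optimizer = base_optimizer
        self.k = k
        self.alpha = alpha
        self.step_count = 0
        # Snapshot of the slow weights, one tensor per parameter.
        self.slow_weights = [
            [p.detach().clone() for p in group['params']]
            for group in base_optimizer.param_groups
        ]

    def zero_grad(self):
        self.base_optimizer.zero_grad()

    def step(self):
        # Fast update: delegate to the wrapped optimizer (e.g. Adam).
        self.base_optimizer.step()
        self.step_count += 1
        if self.step_count % self.k == 0:
            # Slow update: slow <- slow + alpha * (fast - slow),
            # then reset the fast weights to the new slow weights.
            for group, slow_group in zip(self.base_optimizer.param_groups, self.slow_weights):
                for p, slow in zip(group['params'], slow_group):
                    slow.add_(p.detach() - slow, alpha=self.alpha)
                    p.data.copy_(slow)


# Example wiring, mirroring the decoder setup above (illustrative only):
# dec_base_opt = optim.Adam(filter(lambda p: p.requires_grad, model.dec.parameters()), lr=args.dec_lr)
# dec_optimizer = LookaheadSketch(dec_base_opt, k=5, alpha=0.5)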