def test_load_state_dict(self):
    # define simple FP16 model
    model = torch.nn.Linear(5, 5).cuda().half()
    params = list(model.parameters())

    # initialize memory efficient FP16 optimizer
    # with pseudo DictConfigs
    optimizer = FairseqAdam(
        cfg=OmegaConf.create(
            vars(
                argparse.Namespace(
                    adam_betas="(0.9, 0.999)",
                    adam_eps=1e-8,
                    weight_decay=0.0,
                    lr=[0.00001],
                )
            )
        ),
        params=params,
    )
    me_optimizer = MemoryEfficientFP16Optimizer(
        cfg=OmegaConf.create(
            {
                "common": vars(
                    argparse.Namespace(
                        fp16_init_scale=1,
                        fp16_scale_window=1,
                        fp16_scale_tolerance=1,
                        threshold_loss_scale=1,
                        min_loss_scale=1e-4,
                    )
                )
            }
        ),
        params=params,
        optimizer=optimizer,
    )

    # optimizer state is created in the first step
    loss = model(torch.rand(5).cuda().half()).sum()
    me_optimizer.backward(loss)
    me_optimizer.step()

    # reload state
    state = me_optimizer.state_dict()
    me_optimizer.load_state_dict(state)

    # param keys stay FP16 while the wrapped optimizer's state tensors are FP32
    for k, v in me_optimizer.optimizer.state.items():
        self.assertTrue(k.dtype == torch.float16)
        for v_i in v.values():
            if torch.is_tensor(v_i):
                self.assertTrue(v_i.dtype == torch.float32)
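# A minimal sketch of how this optimizer is typically driven in a training
# step (it mirrors the fairseq FP16 overflow protocol exercised by the test
# above; `model` and `sample` here are placeholders, not project code):
loss = model(sample).sum()
me_optimizer.backward(loss)           # multiplies the loss by the current scale
try:
    me_optimizer.clip_grad_norm(1.0)  # unscales grads; raises OverflowError on inf/nan
    me_optimizer.step()
except OverflowError:
    # overflow detected: the loss scale has been reduced; drop this update's
    # gradients and let the caller retry with the smaller scale
    me_optimizer.zero_grad()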
def get_fairseq_adamw_optimizer(model: nn.Module, args):
    cfg = FairseqOptCfg(
        args.train.learning_rate,
        args.train.adam_betas,
        args.train.adam_eps,
        args.train.weight_decay,
    )
    return FairseqAdam(cfg, model.parameters()).optimizer
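# `FairseqOptCfg` is referenced above but not defined in this section. A
# minimal sketch of what it plausibly looks like, assuming FairseqAdam only
# reads cfg.lr (a list), cfg.adam_betas, cfg.adam_eps, and cfg.weight_decay.
# This class is an assumption for illustration, not the project's definition:
class FairseqOptCfg:
    def __init__(self, learning_rate, adam_betas, adam_eps, weight_decay):
        # FairseqAdam expects `lr` to be a list of per-epoch learning rates
        self.lr = [learning_rate]
        self.adam_betas = adam_betas
        self.adam_eps = adam_eps
        self.weight_decay = weight_decay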
def get_fairseq_adamw_optimizer(model: nn.Module, args):
    # variant for a flat namespace: FairseqAdam reads args.lr as a list
    args.lr = [args.learning_rate]
    return FairseqAdam(args, model.parameters()).optimizer
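# Hypothetical usage of the flat-namespace variant above. The field names
# mirror what FairseqAdam reads from its config, but this args object is a
# stand-in for illustration, not the project's real configuration:
import argparse
import torch.nn as nn

model = nn.Linear(5, 5)
args = argparse.Namespace(
    learning_rate=5e-5,
    adam_betas='(0.9, 0.999)',
    adam_eps=1e-8,
    weight_decay=0.01,
)
adamw = get_fairseq_adamw_optimizer(model, args)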
def parse_train_arg():
    parser = ArgumentParser()
    parser.add_argument('--task', type=str, default='vanilla', choices=['vanilla', 'vertical_attention'])
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    parser.add_argument("--cpu", action='store_true', help="Avoid using CUDA even when it is available")
    parser.add_argument('--data-dir', type=Path, required=True)
    parser.add_argument('--output-dir', type=Path, required=True)
    parser.add_argument("--base-model-name", type=str, required=False,
                        help="Pre-trained BERT model to use as the base, selected from: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
                        default='bert-base-uncased')
    parser.add_argument("--table-bert-extra-config", type=json.loads, default='{}')
    parser.add_argument('--no-init', action='store_true', default=False)
    # parser.add_argument('--config-file', type=Path, help='table_bert config file if not using a pre-trained BERT model.')

    # distributed training
    parser.add_argument("--ddp-backend", type=str, default='pytorch', choices=['pytorch', 'apex'])
    parser.add_argument("--local_rank", "--local-rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--master-port", type=int, default=-1,
                        help="Master port (for multi-node SLURM jobs)")
    parser.add_argument("--debug-slurm", action='store_true',
                        help="Debug multi-GPU / multi-node within a SLURM job")

    # training details
    parser.add_argument("--train-batch-size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--max-epoch", default=-1, type=int)
    # parser.add_argument("--total-num-update", type=int, default=1000000, help="Number of steps to train for")
    parser.add_argument('--gradient-accumulation-steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--lr-scheduler", type=str, default='polynomial_decay', help='Learning rate scheduler')
    parser.add_argument("--optimizer", type=str, default='adam', help='Optimizer to use')
    parser.add_argument('--lr', '--learning-rate', default='0.00005', type=eval_str_list,
                        metavar='LR_1,LR_2,...,LR_N',
                        help='learning rate for the first N epochs; all epochs >N using LR_N'
                             ' (note: this may be interpreted differently depending on --lr-scheduler)')
    parser.add_argument('--clip-norm', default=0., type=float, help='clip gradient')
    parser.add_argument('--empty-cache-freq', default=0, type=int,
                        help='how often to clear the PyTorch CUDA cache (0 to disable)')
    parser.add_argument('--save-checkpoint-every-niter', default=10000, type=int)

    FairseqAdam.add_args(parser)
    PolynomialDecaySchedule.add_args(parser)

    # FP16 training
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--memory-efficient-fp16', action='store_true', help='Use memory efficient fp16')
    parser.add_argument('--threshold-loss-scale', type=float, default=None)
    parser.add_argument('--fp16-init-scale', type=float, default=128)
    # parser.add_argument('--fp16-scale-window', type=int, default=0)
    parser.add_argument('--fp16-scale-tolerance', type=float, default=0.0)
    parser.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
                        help='minimum FP16 loss scale, after which training is stopped')

    parser.add_argument('--debug-dataset', default=False, action='store_true')

    # first pass: a lenient parse so we can look up the task's model class
    # before its model-specific flags are registered (a strict parse_args()
    # here would reject any such flags as unrecognized)
    args, _ = parser.parse_known_args()

    model_cls = task_dict[args.task]['model']
    if hasattr(model_cls, 'add_args'):
        model_cls.add_args(parser)

    # second pass: strict parse with all model-specific arguments registered
    args = parser.parse_args()

    return args
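# A sketch of wiring the parsed arguments into the optimizer/scheduler stack
# used elsewhere in this file. The Linear model is a placeholder (the real
# model comes from task_dict), and the flat-namespace constructor calls
# assume the older args-based fairseq API rather than a structured DictConfig:
args = parse_train_arg()
model = torch.nn.Linear(5, 5)  # placeholder for task_dict[args.task]['model']

if args.fp16:
    model = model.half()
params = list(model.parameters())
optimizer = FairseqAdam(args, params)
if args.fp16 and args.memory_efficient_fp16:
    # --fp16-scale-window is commented out in the parser above, but the FP16
    # optimizer reads it; default to None so fairseq derives a window itself
    args.fp16_scale_window = getattr(args, 'fp16_scale_window', None)
    optimizer = MemoryEfficientFP16Optimizer(args, params, optimizer)
lr_scheduler = PolynomialDecaySchedule(args, optimizer)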