def testcase_fused(
    B=3,
    lr=1.0,
    rho=0.9,
    eps=1e-6,
    weight_decay=0,
    device=torch.device('cpu'),
    dtype=torch.float,
):
    """Run the Adadelta optimizer test on a fused net and B unfused nets.

    Builds ``B`` separate ``_TestNet`` instances, each with its own
    ``optim.Adadelta``, plus one ``_TestNet(B=B)`` driven by the optimizer
    class returned from ``get_hfta_optim_for(optim.Adadelta, B=B)``. Both
    sides are then handed to ``_optim_testing_procedure``.

    Args:
        B: Number of unfused nets / size passed to the fused net.
        lr: Learning rate. A plain ``int``/``float`` is replaced by ``B``
            random values drawn from ``[0.5, 2.0]`` whenever ``B > 1``.
        rho: Adadelta ``rho``; indexed per net through
            ``index_array_or_return_scalar``.
        eps: Adadelta ``eps``; indexed the same way.
        weight_decay: Adadelta ``weight_decay``; indexed the same way.
        device: Forwarded to ``_TestNet`` as ``device``.
        dtype: Forwarded to ``_TestNet`` as ``dtype``.
    """
    lr_is_plain_number = isinstance(lr, (int, float))
    if B > 1 and lr_is_plain_number:
        # Give every net its own learning rate.
        lr = [random.uniform(0.5, 2.0) for _ in range(B)]

    factory_kwargs = dict(device=device, dtype=dtype)
    unfused_nets = [_TestNet(**factory_kwargs) for _ in range(B)]
    fused_net = _TestNet(B=B, **factory_kwargs)

    unfused_optimizers = []
    for b in range(B):
        unfused_optimizers.append(
            optim.Adadelta(
                unfused_nets[b].parameters(),
                lr=index_array_or_return_scalar(lr, b),
                rho=index_array_or_return_scalar(rho, b),
                eps=index_array_or_return_scalar(eps, b),
                weight_decay=index_array_or_return_scalar(weight_decay, b),
            ))

    fused_adadelta_cls = get_hfta_optim_for(optim.Adadelta, B=B)
    fused_optimizer = fused_adadelta_cls(
        fused_net.parameters(),
        lr=lr,
        rho=rho,
        eps=eps,
        weight_decay=weight_decay,
    )

    _optim_testing_procedure(
        fused_net,
        unfused_nets,
        fused_optimizer,
        unfused_optimizers,
    )
def testcase_StepLR_fused(B=3, step_size=2, gamma=0.1, last_epoch=-1):
    """Run the StepLR scheduler test on a fused net and B unfused nets.

    The learning rate is picked at random between a ``(B,)`` tensor from
    ``torch.rand`` and a single ``random.random()`` float. ``B`` unfused
    ``optim.Adadelta``/``lr_scheduler.StepLR`` pairs and one fused pair
    (from ``get_hfta_optim_for`` / ``get_hfta_lr_scheduler_for``) are built
    and handed to ``_lr_scheduler_testing_procedure``.

    Args:
        B: Number of unfused nets / size passed to the fused net.
        step_size: StepLR ``step_size``; indexed per net through
            ``index_array_or_return_scalar``.
        gamma: StepLR ``gamma``; indexed the same way.
        last_epoch: StepLR ``last_epoch``; indexed the same way. Anything
            other than the int ``-1`` triggers ``_init_initial_lr`` first.
    """
    # Draw both candidates before choosing, in this exact order.
    tensor_lr = torch.rand((B,))
    scalar_lr = random.random()
    lr = random.choice([tensor_lr, scalar_lr])

    unfused_nets = [_TestNet() for _ in range(B)]
    fused_net = _TestNet(B=B)

    unfused_optimizers = []
    for b in range(B):
        unfused_optimizers.append(
            optim.Adadelta(
                unfused_nets[b].parameters(),
                lr=index_array_or_return_scalar(lr, b),
            ))
    fused_optimizer = get_hfta_optim_for(optim.Adadelta, B=B)(
        fused_net.parameters(),
        lr=lr,
    )

    starts_fresh = isinstance(last_epoch, int) and last_epoch == -1
    if not starts_fresh:
        # NOTE(review): presumably seeds 'initial_lr' in the param groups,
        # which schedulers need when not starting from -1 - confirm in helper.
        _init_initial_lr(fused_optimizer, unfused_optimizers)

    unfused_schedulers = []
    for b in range(B):
        unfused_schedulers.append(
            lr_scheduler.StepLR(
                unfused_optimizers[b],
                index_array_or_return_scalar(step_size, b),
                gamma=index_array_or_return_scalar(gamma, b),
                last_epoch=index_array_or_return_scalar(last_epoch, b),
            ))
    fused_step_lr_cls = get_hfta_lr_scheduler_for(lr_scheduler.StepLR, B=B)
    fused_scheduler = fused_step_lr_cls(
        fused_optimizer,
        step_size,
        gamma=gamma,
        last_epoch=last_epoch,
    )

    _lr_scheduler_testing_procedure(
        fused_net,
        unfused_nets,
        fused_optimizer,
        unfused_optimizers,
        fused_scheduler,
        unfused_schedulers,
    )
def testcase_partially_fused(
    B=3,
    amsgrad=False,
    device=torch.device('cpu'),
    dtype=torch.float,
):
    """Run the Adam optimizer test on a partially fused net.

    Builds ``B`` unfused ``_TestNet`` instances with one ``optim.Adam`` each,
    and one ``_TestNet(B=B, partially_fused=True)`` whose fused and unfused
    parameters are given to the optimizer class returned by
    ``get_hfta_optim_for(optim.Adam, B=B, partially_fused=True)``. All
    hyperparameters except ``amsgrad`` are drawn at random per net.

    Args:
        B: Number of unfused nets / size passed to the fused net.
        amsgrad: Forwarded unchanged to every Adam optimizer.
        device: Forwarded to ``_TestNet`` as ``device``.
        dtype: Forwarded to ``_TestNet`` as ``dtype``.
    """
    factory_kwargs = dict(device=device, dtype=dtype)
    unfused_nets = [_TestNet(**factory_kwargs) for _ in range(B)]
    fused_net = _TestNet(B=B, partially_fused=True, **factory_kwargs)

    # Per-net hyperparameters; the draw order below is kept deliberately.
    lr = [random.uniform(1e-4, 1e-2) for _ in range(B)]
    beta1 = [random.uniform(0.8, 0.99) for _ in range(B)]
    beta2 = [random.uniform(0.998, 0.9999) for _ in range(B)]
    betas = (beta1, beta2)
    eps = [random.uniform(1e-9, 1e-7) for _ in range(B)]
    weight_decay = [random.uniform(0.0, 0.3) for _ in range(B)]

    unfused_optimizers = []
    for b in range(B):
        unfused_optimizers.append(
            optim.Adam(
                unfused_nets[b].parameters(),
                lr=index_array_or_return_scalar(lr, b),
                betas=(
                    index_array_or_return_scalar(betas[0], b),
                    index_array_or_return_scalar(betas[1], b),
                ),
                eps=index_array_or_return_scalar(eps, b),
                weight_decay=index_array_or_return_scalar(weight_decay, b),
                amsgrad=amsgrad,
            ))

    partially_fused_adam_cls = get_hfta_optim_for(
        optim.Adam,
        B=B,
        partially_fused=True,
    )
    partially_fused_optimizer = partially_fused_adam_cls(
        fused_net.parameters(),
        fused_net.unfused_parameters(),
        lr=lr,
        betas=betas,
        eps=eps,
        weight_decay=weight_decay,
        amsgrad=amsgrad,
        B=B,
    )

    _optim_testing_procedure(
        fused_net,
        unfused_nets,
        partially_fused_optimizer,
        unfused_optimizers,
    )
def testcase_StepLR_partially_fused(B=3):
    """Run the StepLR scheduler test on a partially fused net.

    Builds ``B`` unfused ``_TestNet`` instances, each with its own
    ``optim.Adadelta`` and ``lr_scheduler.StepLR``, plus one
    ``_TestNet(B=B, partially_fused=True)`` with the partially fused
    optimizer and scheduler obtained from ``get_hfta_optim_for`` and
    ``get_hfta_lr_scheduler_for``. Learning rate, step size, gamma and
    last epoch are drawn at random per net, then everything is handed to
    ``_lr_scheduler_testing_procedure``.

    Args:
        B: Number of unfused nets / size passed to the fused net.
    """
    net_array = [_TestNet() for _ in range(B)]
    net_fused = _TestNet(B=B, partially_fused=True)

    # Per-net hyperparameters (draw order kept as in the original).
    lr = [random.uniform(0.5, 2.0) for _ in range(B)]
    step_size = [random.randint(2, 8) for _ in range(B)]
    gamma = [random.uniform(0.1, 0.3) for _ in range(B)]
    last_epoch = [random.randint(5, 11) for _ in range(B)]

    optimizer_array = [
        optim.Adadelta(
            net_array[b].parameters(),
            lr=index_array_or_return_scalar(lr, b),
        ) for b in range(B)
    ]
    optimizer_partially_fused = get_hfta_optim_for(
        optim.Adadelta,
        B=B,
        partially_fused=True,
    )(
        net_fused.parameters(),
        net_fused.unfused_parameters(),
        lr=lr,
    )

    # last_epoch is always >= 5 here, so the schedulers never start fresh.
    # NOTE(review): presumably seeds 'initial_lr' in the param groups, which
    # schedulers need in that case - confirm against the helper.
    _init_initial_lr(optimizer_partially_fused, optimizer_array)

    lr_scheduler_array = [
        lr_scheduler.StepLR(
            optimizer_array[b],
            index_array_or_return_scalar(step_size, b),
            gamma=index_array_or_return_scalar(gamma, b),
            last_epoch=index_array_or_return_scalar(last_epoch, b),
        ) for b in range(B)
    ]
    lr_scheduler_partially_fused = get_hfta_lr_scheduler_for(
        lr_scheduler.StepLR,
        B=B,
        partially_fused=True,
    )(
        optimizer_partially_fused,
        step_size,
        gamma=gamma,
        last_epoch=last_epoch,
    )

    # Pass the partially fused optimizer and scheduler built above; this
    # function defines no `optimizer_fused` / `lr_scheduler_fused`.
    _lr_scheduler_testing_procedure(
        net_fused,
        net_array,
        optimizer_partially_fused,
        optimizer_array,
        lr_scheduler_partially_fused,
        lr_scheduler_array,
    )
def testcase_fused(
    B=3,
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
    amsgrad=False,
    device=torch.device('cpu'),
    dtype=torch.float,
):
    """Run the Adam optimizer test on a fused net and B unfused nets.

    Builds ``B`` separate ``_TestNet`` instances, each with its own
    ``optim.Adam``, plus one ``_TestNet(B=B)`` driven by the optimizer class
    returned from ``get_hfta_optim_for(optim.Adam, B=B)``. Both sides are
    then handed to ``_optim_testing_procedure``.

    Args:
        B: Number of unfused nets / size passed to the fused net.
        lr: Learning rate. A plain ``int``/``float`` is replaced by ``B``
            random values drawn from ``[1e-4, 1e-2]`` whenever ``B > 1``.
        betas: Pair of Adam betas; each element is indexed per net through
            ``index_array_or_return_scalar``.
        eps: Adam ``eps``; indexed the same way.
        weight_decay: Adam ``weight_decay``; indexed the same way.
        amsgrad: Forwarded unchanged to every Adam optimizer.
        device: Forwarded to ``_TestNet`` as ``device``.
        dtype: Forwarded to ``_TestNet`` as ``dtype``.
    """
    lr_is_plain_number = isinstance(lr, (int, float))
    if B > 1 and lr_is_plain_number:
        # Give every net its own learning rate.
        lr = [random.uniform(1e-4, 1e-2) for _ in range(B)]

    factory_kwargs = dict(device=device, dtype=dtype)
    unfused_nets = [_TestNet(**factory_kwargs) for _ in range(B)]
    fused_net = _TestNet(B=B, **factory_kwargs)

    unfused_optimizers = []
    for b in range(B):
        unfused_optimizers.append(
            optim.Adam(
                unfused_nets[b].parameters(),
                lr=index_array_or_return_scalar(lr, b),
                betas=(
                    index_array_or_return_scalar(betas[0], b),
                    index_array_or_return_scalar(betas[1], b),
                ),
                eps=index_array_or_return_scalar(eps, b),
                weight_decay=index_array_or_return_scalar(weight_decay, b),
                amsgrad=amsgrad,
            ))

    fused_adam_cls = get_hfta_optim_for(optim.Adam, B=B)
    fused_optimizer = fused_adam_cls(
        fused_net.parameters(),
        lr=lr,
        betas=betas,
        eps=eps,
        weight_decay=weight_decay,
        amsgrad=amsgrad,
    )

    _optim_testing_procedure(
        fused_net,
        unfused_nets,
        fused_optimizer,
        unfused_optimizers,
    )
def testcase_partially_fused(
    B=3,
    device=torch.device('cpu'),
    dtype=torch.float,
):
    """Run the Adadelta optimizer test on a partially fused net.

    Builds ``B`` unfused ``_TestNet`` instances with one ``optim.Adadelta``
    each, and one ``_TestNet(B=B, partially_fused=True)`` whose fused and
    unfused parameters are given to the optimizer class returned by
    ``get_hfta_optim_for(optim.Adadelta, B=B, partially_fused=True)``. All
    hyperparameters are drawn at random per net.

    Args:
        B: Number of unfused nets / size passed to the fused net.
        device: Forwarded to ``_TestNet`` as ``device``.
        dtype: Forwarded to ``_TestNet`` as ``dtype``.
    """
    factory_kwargs = dict(device=device, dtype=dtype)
    unfused_nets = [_TestNet(**factory_kwargs) for _ in range(B)]
    fused_net = _TestNet(B=B, partially_fused=True, **factory_kwargs)

    # Per-net hyperparameters; the draw order below is kept deliberately.
    lr = [random.uniform(0.5, 2.0) for _ in range(B)]
    rho = [random.uniform(0.7, 0.99) for _ in range(B)]
    eps = [random.uniform(1e-7, 1e-5) for _ in range(B)]
    weight_decay = [random.uniform(0.0, 0.3) for _ in range(B)]

    unfused_optimizers = []
    for b in range(B):
        unfused_optimizers.append(
            optim.Adadelta(
                unfused_nets[b].parameters(),
                lr=index_array_or_return_scalar(lr, b),
                rho=index_array_or_return_scalar(rho, b),
                eps=index_array_or_return_scalar(eps, b),
                weight_decay=index_array_or_return_scalar(weight_decay, b),
            ))

    partially_fused_adadelta_cls = get_hfta_optim_for(
        optim.Adadelta,
        B=B,
        partially_fused=True,
    )
    partially_fused_optimizer = partially_fused_adadelta_cls(
        fused_net.parameters(),
        fused_net.unfused_parameters(),
        lr=lr,
        rho=rho,
        eps=eps,
        weight_decay=weight_decay,
        B=B,
    )

    _optim_testing_procedure(
        fused_net,
        unfused_nets,
        partially_fused_optimizer,
        unfused_optimizers,
    )