def get_transformer(device: torch.device) -> GetterReturnType:
    # For most SOTA research you would want embed=720, nhead=12, bsz=64 and tgt_len/src_len=128.
    N = 64
    seq_length = 128
    ntoken = 50
    model = models.TransformerModel(ntoken=ntoken, ninp=720, nhead=12, nhid=2048, nlayers=2)
    model.to(device)

    if has_functorch:
        # disable dropout for consistency checking
        model.eval()

    criterion = nn.NLLLoss()
    params, names = extract_weights(model)

    data = torch.rand(N, seq_length + 1, device=device).mul(ntoken).long()
    inputs = data.narrow(1, 0, seq_length)
    targets = data.narrow(1, 1, seq_length)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out.reshape(N * seq_length, ntoken), targets.reshape(N * seq_length))
        return loss

    return forward, params
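
# Hedged usage sketch (added for illustration, not part of the original file; the helper
# name is made up): each getter returns a functional `forward` plus the parameter tensors
# produced by extract_weights. Assuming extract_weights returns leaf tensors with
# requires_grad=True (as in PyTorch's functional_autograd_benchmark utilities), the scalar
# loss can be differentiated directly with plain autograd.
def _example_transformer_grad(device: torch.device):
    forward, params = get_transformer(device)
    loss = forward(*params)                    # scalar NLL loss for one random batch
    grads = torch.autograd.grad(loss, params)  # one gradient tensor per parameter
    return loss, grads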

def get_fcn_resnet(device: torch.device) -> GetterReturnType:
    N = 8
    criterion = torch.nn.MSELoss()
    model = models.fcn_resnet50(pretrained=False, pretrained_backbone=False)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)
        # disable dropout for consistency checking
        model.eval()

    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([N, 3, 480, 480], device=device)
    # Given model has 21 classes
    labels = torch.rand([N, 21, 480, 480], device=device)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)['out']

        loss = criterion(out, labels)
        return loss

    return forward, params

def get_deepspeech(device: torch.device) -> GetterReturnType:
    sample_rate = 16000
    window_size = 0.02
    window = "hamming"
    audio_conf = dict(sample_rate=sample_rate,
                      window_size=window_size,
                      window=window,
                      noise_dir=None)

    N = 10
    num_classes = 10
    spectrogram_size = 161
    # Commented are the original sizes in the code
    seq_length = 500  # 1343
    target_length = 10  # 50
    labels = torch.rand(num_classes, device=device)
    inputs = torch.rand(N, 1, spectrogram_size, seq_length, device=device)
    # Sequence length for each input
    inputs_sizes = torch.rand(N, device=device).mul(seq_length * 0.1).add(seq_length * 0.8)
    targets = torch.rand(N, target_length, device=device)
    targets_sizes = torch.full((N,), target_length, dtype=torch.int, device=device)

    model = models.DeepSpeech(rnn_type=nn.LSTM, labels=labels, rnn_hidden_size=1024,
                              nb_layers=5, audio_conf=audio_conf, bidirectional=True)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    model = model.to(device)
    criterion = nn.CTCLoss()
    params, names = extract_weights(model)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out, out_sizes = model(inputs, inputs_sizes)
        out = out.transpose(0, 1)  # For ctc loss

        loss = criterion(out, targets, out_sizes, targets_sizes)
        return loss

    return forward, params
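
# Hedged usage sketch (illustrative only; the helper name is made up): because `forward`
# returns a scalar loss, it also composes with the transforms in torch.autograd.functional,
# e.g. a vector-Hessian product taken over all parameters at once.
def _example_deepspeech_vhp(device: torch.device):
    forward, params = get_deepspeech(device)
    v = tuple(torch.rand_like(p) for p in params)  # direction vector for the vhp
    loss, vhp_out = torch.autograd.functional.vhp(forward, params, v)
    return loss, vhp_out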

def get_resnet18(device: torch.device) -> GetterReturnType:
    N = 32
    model = models.resnet18(pretrained=False)
    criterion = torch.nn.CrossEntropyLoss()
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([N, 3, 224, 224], device=device)
    labels = torch.rand(N, device=device).mul(10).long()

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        return loss

    return forward, params

def get_fcn_resnet(device: torch.device) -> GetterReturnType:
    N = 8
    criterion = torch.nn.MSELoss()
    model = models.fcn_resnet50(pretrained=False, pretrained_backbone=False)
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([N, 3, 480, 480], device=device)
    # Given model has 21 classes
    labels = torch.rand([N, 21, 480, 480], device=device)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)['out']

        loss = criterion(out, labels)
        return loss

    return forward, params

def get_wav2letter(device: torch.device) -> GetterReturnType:
    N = 10
    input_frames = 700
    vocab_size = 28
    model = models.Wav2Letter(num_classes=vocab_size)
    criterion = torch.nn.NLLLoss()
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([N, 1, input_frames], device=device)
    labels = torch.rand(N, 3, device=device).mul(vocab_size).long()

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        return loss

    return forward, params

def get_multiheadattn(device: torch.device) -> GetterReturnType:
    # From https://github.com/pytorch/text/blob/master/test/data/test_modules.py#L10
    embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64

    # Build torchtext MultiheadAttention module
    in_proj = models.InProjContainer(torch.nn.Linear(embed_dim, embed_dim, bias=False),
                                     torch.nn.Linear(embed_dim, embed_dim, bias=False),
                                     torch.nn.Linear(embed_dim, embed_dim, bias=False))

    model = models.MultiheadAttentionContainer(nhead, in_proj,
                                               models.ScaledDotProduct(),
                                               torch.nn.Linear(embed_dim, embed_dim, bias=False))
    model.to(device)
    params, names = extract_weights(model)

    query = torch.rand((tgt_len, bsz, embed_dim), device=device)
    key = value = torch.rand((src_len, bsz, embed_dim), device=device)
    attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len), device=device).to(torch.bool)
    bias_k = bias_v = torch.rand((1, 1, embed_dim), device=device)

    attn_mask = torch.stack([attn_mask_2D] * (bsz * nhead))
    bias_k = bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)
    bias_v = bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        mha_output, attn_weights = model(query, key, value,
                                         attn_mask=attn_mask, bias_k=bias_k, bias_v=bias_v)

        # Don't test any specific loss, just backprop ones for both outputs
        loss = mha_output.sum() + attn_weights.sum()

        return loss

    return forward, params

def get_resnet18(device: torch.device) -> GetterReturnType:
    N = 32
    model = models.resnet18(pretrained=False)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    criterion = torch.nn.CrossEntropyLoss()
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([N, 3, 224, 224], device=device)
    labels = torch.rand(N, device=device).mul(10).long()

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        return loss

    return forward, params
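
# Hedged sketch (assumes functorch is importable, mirroring the `has_functorch` branches
# above; the helper name is made up): the same (forward, params) pair can also be fed to
# functorch's composable transforms instead of plain autograd.
def _example_resnet18_functorch_grad(device: torch.device):
    import functorch

    forward, params = get_resnet18(device)
    argnums = tuple(range(len(params)))  # differentiate with respect to every parameter
    grads = functorch.grad(forward, argnums=argnums)(*params)
    return grads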

def get_detr(device: torch.device) -> GetterReturnType:
    # All values below are from CLI defaults in https://github.com/facebookresearch/detr
    N = 2
    num_classes = 91
    hidden_dim = 256
    nheads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6

    model = models.DETR(num_classes=num_classes, hidden_dim=hidden_dim, nheads=nheads,
                        num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)

    losses = ['labels', 'boxes', 'cardinality']
    eos_coef = 0.1
    bbox_loss_coef = 5
    giou_loss_coef = 2
    weight_dict = {'loss_ce': 1, 'loss_bbox': bbox_loss_coef, 'loss_giou': giou_loss_coef}
    matcher = models.HungarianMatcher(1, 5, 2)
    criterion = models.SetCriterion(num_classes=num_classes, matcher=matcher,
                                    weight_dict=weight_dict, eos_coef=eos_coef, losses=losses)

    model = model.to(device)
    criterion = criterion.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand(N, 3, 800, 1200, device=device)
    labels = []
    for idx in range(N):
        targets = {}
        n_targets: int = int(torch.randint(5, 10, size=tuple()).item())
        label = torch.randint(5, 10, size=(n_targets,))
        targets["labels"] = label
        boxes = torch.randint(100, 800, size=(n_targets, 4))
        for t in range(n_targets):
            if boxes[t, 0] > boxes[t, 2]:
                boxes[t, 0], boxes[t, 2] = boxes[t, 2], boxes[t, 0]
            if boxes[t, 1] > boxes[t, 3]:
                boxes[t, 1], boxes[t, 3] = boxes[t, 3], boxes[t, 1]
        targets["boxes"] = boxes.float()
        labels.append(targets)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        weight_dict = criterion.weight_dict
        final_loss = cast(Tensor, sum(loss[k] * weight_dict[k] for k in loss.keys() if k in weight_dict))
        return final_loss

    return forward, params
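
# Hedged sketch of the helpers every getter relies on (an assumption: this follows the
# pattern of PyTorch's functional_autograd_benchmark utilities rather than the exact
# definitions used here). extract_weights detaches the parameters into free tensors;
# load_weights re-attaches arbitrary tensors under the same names so the model can be
# re-parametrized functionally.
def _extract_weights_sketch(mod: nn.Module):
    orig_params = tuple(mod.parameters())
    names = [name for name, _ in mod.named_parameters()]
    for name in names:
        # drop the nn.Parameter so a plain tensor can later take its place
        obj = mod
        *path, attr = name.split(".")
        for part in path:
            obj = getattr(obj, part)
        delattr(obj, attr)
    params = tuple(p.detach().requires_grad_() for p in orig_params)
    return params, names


def _load_weights_sketch(mod: nn.Module, names, params) -> None:
    for name, p in zip(names, params):
        obj = mod
        *path, attr = name.split(".")
        for part in path:
            obj = getattr(obj, part)
        setattr(obj, attr, p)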

def update_one_step(data, learning_rate, epsilon, n):
    state, count, logphi0 = data['state'], data['count'], data['logphi0']
    op_states, op_coeffs = data['update_states'], data['update_coeffs']

    psi = logphi_model(state)
    logphi = psi[:, 0].reshape(len(state), -1)
    theta = psi[:, 1].reshape(len(state), -1)

    # calculate the weights of the energy from importance sampling
    delta_logphi = logphi - logphi0[..., None]
    # delta_logphi = delta_logphi - delta_logphi.mean()*torch.ones(delta_logphi.shape)
    delta_logphi = delta_logphi - delta_logphi.mean()
    weights = count[..., None]*torch.exp(delta_logphi * 2)
    weights_norm = weights.sum()
    weights = (weights/weights_norm).detach()

    if weights_norm/count.sum() > target_wn:
        return weights_norm/count.sum(), 0
    else:
        n_sample = op_states.shape[0]
        n_updates = op_states.shape[1]
        op_states = op_states.reshape([-1, Dp] + single_state_shape)

        psi_ops = logphi_model(op_states)
        logphi_ops = psi_ops[:, 0].reshape(n_sample, n_updates)
        theta_ops = psi_ops[:, 1].reshape(n_sample, n_updates)

        delta_logphi_os = logphi_ops - logphi*torch.ones_like(logphi_ops)
        # delta_logphi_os = torch.clamp(delta_logphi_os, max=5)
        delta_theta_os = theta_ops - theta*torch.ones_like(theta_ops)
        ops_real = torch.sum(op_coeffs*torch.exp(delta_logphi_os)*torch.cos(delta_theta_os), 1)
        ops_imag = torch.sum(op_coeffs*torch.exp(delta_logphi_os)*torch.sin(delta_theta_os), 1)
        with torch.no_grad():
            ops = ops_real + 1j*ops_imag  # batch_size
            mean_e = (ops_real[..., None]*weights).sum(0)

        # update parameters with gradient descent
        # copy the model parameters
        op_logphi_model = copy.deepcopy(logphi_model)
        params, names = extract_weights(op_logphi_model)

        def forward(*new_param):
            load_weights(op_logphi_model, names, new_param)
            out = op_logphi_model(state)
            return out

        dydws = jacobian(forward, params, vectorize=True)  # a tuple containing all grads

        cnt = 0
        tic = time.time()
        for param in logphi_model.parameters():
            param_len = len(param.data.reshape(-1))
            dydws_layer = dydws[cnt].reshape(n_sample, 2, -1)
            with torch.no_grad():
                grads_real = dydws_layer[:, 0, :]  # Jacobian of logphi w.r.t. w
                grads_imag = dydws_layer[:, 1, :]  # Jacobian of theta w.r.t. w
                Oks = grads_real + 1j*grads_imag
                Oks_conj = grads_real - 1j*grads_imag
                OO_matrix = Oks_conj.reshape(n_sample, 1, param_len)*Oks.reshape(n_sample, param_len, 1)
                Oks = Oks*weights
                Oks_conj = Oks_conj*weights

                Skk_matrix = (OO_matrix*weights[..., None]).sum(0) - Oks_conj.sum(0)[..., None]*Oks.sum(0)
                Skk_matrix = 0.5*(Skk_matrix + Skk_matrix.t().conj()) + epsilon*torch.eye(Skk_matrix.shape[0], device=gpu)
                # calculate Fk
                Fk = (ops[..., None]*Oks_conj).sum(0) - mean_e*(Oks_conj).sum(0)

                # update_k = torch.linalg.solve(Skk_matrix, Fk)
                update_k, _ = torch.solve(Fk[..., None], Skk_matrix)
                param.data -= learning_rate*update_k.real.reshape(param.data.shape)
            cnt += 1

        t = time.time() - tic
    return weights_norm/count.sum(), t
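
# Hedged compatibility note (added, not part of the original): `torch.solve(B, A)` used
# above was deprecated and later removed from PyTorch. On recent releases the same system
# Skk_matrix @ update_k = Fk is solved by the commented-out torch.linalg.solve call; a
# drop-in helper could look like this.
def _solve_sr_system(Skk_matrix, Fk):
    # torch.linalg.solve(A, B) solves A @ X = B, matching torch.solve(B, A)[0]
    return torch.linalg.solve(Skk_matrix, Fk[..., None])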

def update_one_step(data, learning_rate, epsilon):
    state, count, logphi0 = data['state'], data['count'], data['logphi0']
    op_states, op_coeffs = data['update_states'], data['update_coeffs']

    # psi = logphi_model(state)
    psi = psi_model(state)
    logphi = psi[:, 0].reshape(len(state), -1)
    theta = psi[:, 1].reshape(len(state), -1)

    # calculate the weights of the energy from importance sampling
    delta_logphi = logphi - logphi0[..., None]
    # print(delta_logphi.cpu().detach().numpy())
    # delta_logphi = delta_logphi - delta_logphi.mean()*torch.ones(delta_logphi.shape)
    delta_logphi = delta_logphi - delta_logphi.mean()
    weights = count[..., None] * torch.exp(delta_logphi * 2)
    weights_norm = weights.sum()
    weights = (weights / weights_norm).detach()

    if weights_norm / count.sum() > target_wn:
        return weights_norm / count.sum(), 0
    else:
        n_sample = op_states.shape[0]
        n_updates = op_states.shape[1]
        op_states = op_states.reshape([-1, Dp] + single_state_shape)

        # psi_ops = logphi_model(op_states)
        psi_ops = psi_model(op_states)
        logphi_ops = psi_ops[:, 0].reshape(n_sample, n_updates)
        theta_ops = psi_ops[:, 1].reshape(n_sample, n_updates)

        delta_logphi_os = logphi_ops - logphi * torch.ones_like(logphi_ops)
        # delta_logphi_os = torch.clamp(delta_logphi_os, max=5)
        delta_theta_os = theta_ops - theta * torch.ones_like(theta_ops)
        ops_real = torch.sum(op_coeffs * torch.exp(delta_logphi_os) * torch.cos(delta_theta_os), 1)
        ops_imag = torch.sum(op_coeffs * torch.exp(delta_logphi_os) * torch.sin(delta_theta_os), 1)
        with torch.no_grad():
            ops = ops_real + 1j * ops_imag  # batch_size
            mean_e = (ops_real[..., None] * weights).sum(0)

        # update parameters with gradient descent
        # copy the model parameters
        op_psi_model = copy.deepcopy(psi_model)
        params, names = extract_weights(op_psi_model)

        name_num = []
        for name in names:
            name_num.append(int(name.split(".")[0]))
        _, cnts = np.unique(name_num, return_counts=True, axis=0)
        # net_layers = len(cnts)

        ws = psi_model.state_dict()

        def forward(*new_param):
            load_weights(op_psi_model, names, new_param)
            out = op_psi_model(state)
            return out

        dpsidws = jacobian(forward, params, vectorize=True)  # a tuple containing all grads

        cnt = 0
        tic = time.time()
        for net_layer in cnts:
            step = net_layer // 2
            for i in range(step):
                index = cnt + i
                param_len = len(ws[names[index]].reshape(-1))
                dres_layer = dpsidws[index].reshape(n_sample, 2, -1)
                dims_layer = dpsidws[index + step].reshape(n_sample, 2, -1)
                with torch.no_grad():
                    Oks = 0.5 * (dres_layer[:, 0, :] + dims_layer[:, 1, :]) \
                        - 0.5 * 1j * (dims_layer[:, 0, :] - dres_layer[:, 1, :])
                    Oks_conj = Oks.conj()
                    OO_matrix = Oks_conj.reshape(n_sample, 1, param_len) * Oks.reshape(n_sample, param_len, 1)
                    Oks = Oks * weights
                    Oks_conj = Oks_conj * weights

                    Skk_matrix = (OO_matrix * weights[..., None]).sum(0) - Oks_conj.sum(0)[..., None] * Oks.sum(0)
                    Skk_matrix = 0.5 * (Skk_matrix + Skk_matrix.t().conj()) + epsilon * torch.eye(Skk_matrix.shape[0], device=gpu)
                    # calculate Fk
                    Fk = (ops[..., None] * Oks_conj).sum(0) - mean_e * (Oks_conj).sum(0)

                    update_k, _ = torch.solve(Fk[..., None], Skk_matrix)
                    # real part
                    ws[names[index]] -= learning_rate * update_k.real.reshape(ws[names[index]].shape)
                    # imag part
                    ws[names[index + step]] -= learning_rate * update_k.imag.reshape(ws[names[index]].shape)
            cnt += net_layer

        psi_model.load_state_dict(ws)
        t = time.time() - tic
    return weights_norm / count.sum(), t
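
# Hedged toy illustration (made-up shapes and random data, not from the original code):
# both update_one_step variants end in the same stochastic-reconfiguration step, i.e.
# building a regularized covariance matrix S_kk and force vector F_k from per-sample
# log-derivatives O_k and then solving S_kk @ dw = F_k before applying the update.
def _sr_step_toy(n_sample: int = 16, n_params: int = 4,
                 epsilon: float = 1e-3, learning_rate: float = 1e-2):
    Oks = torch.randn(n_sample, n_params, dtype=torch.cfloat)  # per-sample O_k
    weights = torch.full((n_sample, 1), 1.0 / n_sample)        # normalized sample weights
    ops = torch.randn(n_sample, dtype=torch.cfloat)            # local energies
    mean_e = (ops.real[..., None] * weights).sum(0)

    Oks_w = Oks * weights
    OO = (Oks.conj()[:, :, None] * Oks[:, None, :] * weights[..., None]).sum(0)
    Skk = OO - Oks_w.conj().sum(0)[..., None] * Oks_w.sum(0)
    Skk = 0.5 * (Skk + Skk.t().conj()) + epsilon * torch.eye(n_params)
    Fk = (ops[..., None] * Oks_w.conj()).sum(0) - mean_e * Oks_w.conj().sum(0)

    dw = torch.linalg.solve(Skk, Fk[..., None])                # S_kk @ dw = F_k
    return -learning_rate * dw.real                            # gradient-descent-style step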