class Q_ReLU6(Q_ReLU):
    """``Q_ReLU`` subclass whose initialisers cap the range bound at 6.

    Both initialisers derive their fill value from
    ``log(exp(u) - 1)`` where ``u`` is ``offset + diff`` limited to at
    most 6.
    """

    def __init__(self, act_func=True, inplace=False):
        super(Q_ReLU6, self).__init__(act_func, inplace)

    def initialize(self, bits, offset, diff):
        """Create ``bits``/``nlvs``/``a``/``c`` and fill ``a`` and ``c``.

        Args:
            bits: sequence passed to ``Tensor`` to build ``self.bits``.
            offset: added to ``diff`` to form the (uncapped) upper bound.
            diff: see ``offset``.

        ``a`` and ``c`` are both filled with
        ``log(exp(u) - 1) / (nlvs - 1)``.  ``self.nlvs.item()`` is used,
        so this only works when ``bits`` has a single entry.
        """
        self.bits = Parameter(Tensor(bits), requires_grad=False)
        self.nlvs = Parameter(torch.pow(2, self.bits), requires_grad=True)
        n_entries = len(self.bits)
        self.a = Parameter(Tensor(n_entries))
        self.c = Parameter(Tensor(n_entries))

        # Same bound selection as a two-way branch on ``offset + diff > 6``.
        upper = 6 if offset + diff > 6 else offset + diff
        qmax = np.log(np.exp(upper) - 1)
        stepsize = qmax / (self.nlvs.item() - 1)
        for scale in (self.a, self.c):
            scale.data.fill_(stepsize)

    def initialize_qonly(self, offset, diff):
        """Fill the existing ``a`` and ``c`` with ``log(exp(u) - 1)``.

        ``u`` is ``offset + diff`` capped at 6; no parameter is re-created.
        """
        upper = 6 if offset + diff > 6 else offset + diff
        for scale in (self.a, self.c):
            scale.data.fill_(np.log(np.exp(upper) - 1))
class Q_Sym(nn.Module):
    """Symmetric quantizer with an input scale ``a`` and output scale ``c``.

    A freshly constructed module has ``bits == [32]`` and ``forward``
    returns its input untouched.  After ``initialize`` the input is
    divided by ``softplus(a)``, clipped to ``[-nsteps, nsteps]`` with
    ``nsteps = nlvs / 2 - 1``, rounded through ``Round_fn``, rescaled to
    ``[-1, 1]`` and multiplied by ``softplus(c)``.
    """

    def __init__(self):
        super(Q_Sym, self).__init__()
        self.bits = Parameter(Tensor([32]))
        self.nlvs = Parameter(Tensor([2**32]))
        # ``Tensor(1)`` allocates one uninitialised element; the
        # initialisers below overwrite it.
        self.a = Parameter(Tensor(1))
        self.c = Parameter(Tensor(1))

    def initialize(self, bits, offset, diff):
        """Rebuild all four parameters for the given bit-width(s).

        ``a`` and ``c`` are filled with
        ``log(exp(offset + diff) - 1) / (nlvs / 2 - 1)``.  Because
        ``self.nlvs.item()`` is called, ``bits`` must hold one entry.
        """
        self.bits = Parameter(Tensor(bits), requires_grad=False)
        self.nlvs = Parameter(torch.pow(2, self.bits), requires_grad=True)
        width = len(self.bits)
        self.a = Parameter(Tensor(width))
        self.c = Parameter(Tensor(width))

        qmax = np.log(np.exp(offset + diff) - 1)
        stepsize = qmax / (self.nlvs.item() / 2 - 1)
        for scale in (self.a, self.c):
            scale.data.fill_(stepsize)

    def initialize_qonly(self, offset, diff):
        """Fill the existing ``a`` and ``c`` with ``log(exp(offset + diff) - 1)``."""
        for scale in (self.a, self.c):
            scale.data.fill_(np.log(np.exp(offset + diff) - 1))

    def forward(self, x):
        """Quantize ``x``; pass it through unchanged in the 32-bit case."""
        if len(self.bits) == 1 and self.bits[0] == 32:
            return x

        in_scale = F.softplus(self.a)
        out_scale = F.softplus(self.c)
        nsteps = self.nlvs / 2 - 1
        clipped = F.hardtanh(x / in_scale, -nsteps, nsteps)
        return Round_fn.apply(clipped).div_(nsteps) * out_scale
class AddSine(AddCoords):
    """Concatenate one learnable sinusoidal channel onto a 4-D input.

    For an input of size ``(batch, channel, x_dim, y_dim)`` the extra
    channel holds ``sin(alpha * u + beta * v + phase)``, where ``u`` and
    ``v`` are ``linspace(0, 1)`` ramps along ``x_dim`` and ``y_dim``.
    ``alpha``, ``beta`` and ``phase`` are parameters.
    """

    def __init__(self, alpha=0.5, beta=None, phase_shift=0.):
        super(AddSine, self).__init__(False)
        # ``beta`` follows ``alpha`` unless given explicitly.
        beta = alpha if beta is None else beta
        self.alpha = Parameter(torch.FloatTensor([alpha]))
        self.beta = Parameter(torch.FloatTensor([beta]))
        self.phase = Parameter(torch.FloatTensor([phase_shift]))

    def generate_xy(self, input_tensor):
        """Build the sine channel, shaped ``(batch, 1, x_dim, y_dim)``.

        The ramps are created on the default device and then moved to the
        device that holds ``self.phase``.
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()
        device = self.phase.device

        ramp_x = torch.linspace(0., 1., x_dim).repeat(1, y_dim, 1).to(device)
        ramp_y = torch.linspace(0., 1., y_dim).repeat(1, x_dim, 1).transpose(
            1, 2).to(device)

        scaled_x = ramp_x.float() * self.alpha
        scaled_y = ramp_y.float() * self.beta

        tiled_x = scaled_x.repeat(batch_size, 1, 1, 1).transpose(2, 3)
        tiled_y = scaled_y.repeat(batch_size, 1, 1, 1).transpose(2, 3)
        return torch.sin(tiled_x + tiled_y + self.phase)

    def __repr__(self):
        pieces = (
            self.__class__.__name__, ' (',
            "alpha=", str(self.alpha.item()),
            " beta=", str(self.beta.item()),
            " phase=", str(self.phase.item()), ")",
        )
        return "".join(pieces)

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)

        Returns:
            ``input_tensor`` with the sine channel (cast to the input's
            type) appended along ``dim=1``.
        """
        sine_channel = self.generate_xy(input_tensor).type_as(input_tensor)
        return torch.cat([input_tensor, sine_channel], dim=1)
class ActQuantBuffers(ActQuant):  # This class exists to allow multi-GPU runs
    """``ActQuant`` variant holding its running statistics as buffers.

    ``running_mean`` and ``running_std`` are registered buffers and
    ``clamp_val`` is a trainable parameter; everything else (flags,
    ``momentum``, ``act_clamp``, ``plot_statistic`` ...) is read from the
    parent class.
    """

    def __init__(self, quatize_during_training=False, noise_during_training=False,
                 quant=False, bitwidth=32):
        super(ActQuantBuffers, self).__init__(
            quatize_during_training=quatize_during_training,
            noise_during_training=noise_during_training,
            quant=quant,
            bitwidth=bitwidth,
        )
        for stat_name in ('running_mean', 'running_std'):
            self.register_buffer(stat_name, torch.zeros(1))
        self.clamp_val = Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, input):
        """Apply statistics gathering, clamping and/or quantization.

        * ``pre_training_statistics``: update the running mean/std with
          momentum and return ``relu(input)``.
        * ``quant`` enabled: clamp with ``clamp_val``; additionally
          quantize when evaluating or when ``quatize_during_training``.
        * otherwise: plain ``relu(input)``.
        """
        if self.pre_training_statistics:
            mean_buf = self.running_mean.to(input.device).detach()
            mean_buf.mul_(self.momentum).add_(input.mean() * (1 - self.momentum))
            std_buf = self.running_std.to(input.device).detach()
            std_buf.mul_(self.momentum).add_(input.std() * (1 - self.momentum))
            x = F.relu(input)
        elif not self.quant:
            x = F.relu(input)
        else:
            clamped = self.act_clamp(input, self.clamp_val)
            if not self.training or self.quatize_during_training:
                x = act_quant(clamped, self.clamp_val, self.bitwidth)
            else:
                x = clamped

        # Plot the activation statistics once, the first time they are requested.
        if not self.saved_stats and self.gather_stats:
            self.plot_statistic(x)
            self.saved_stats = True

        # Debug hook, deliberately disabled.
        if False:
            self.print_clamp()

        return x

    def print_clamp(self):
        """Print this layer's number and current clamp value."""
        message = 'Activation layer {} has clamp value {}'.format(
            self.layer_num, self.clamp_val.item())
        print(message)
class RecallAtPrecision(nn.Module):
    """Two-layer classifier trained for recall at a fixed precision ``alpha``.

    The loss is the Lagrangian

        L = Lp + lam * (alpha_term * Ln + Lp - |Y+|)

    where ``Lp`` / ``Ln`` are the hinge losses summed over positive /
    negative targets, ``alpha_term = alpha / (1 - alpha)`` and ``lam`` is
    a learnable multiplier.  Diagnostics of the last labelled forward
    pass are kept in ``result_dict``.
    """

    def __init__(self, input_size, hidden_size, alpha, dropout=0.0):
        """
        Args:
            input_size: number of input features.
            hidden_size: width of the hidden layer.
            alpha: target precision; must not equal 1 (division by ``1 - alpha``).
            dropout: dropout probability applied after the hidden layer.
        """
        super(RecallAtPrecision, self).__init__()
        # Fall back to the CPU so the module can be built on CUDA-less hosts.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.h1_weights = nn.Linear(input_size, hidden_size)
        self.h2_weights = nn.Linear(hidden_size, 2)
        self.dropout = dropout
        self.alpha = alpha
        self.alpha_term = alpha / (1 - alpha)
        self.lam = Parameter(
            torch.tensor([2.0], device=self.device, requires_grad=True))
        self.result_dict = {}
        log.info('Optimize recall @ fixed precision=%.2f' % self.alpha)
        weights_init(self)

    def print_result_dict(self):
        """Log the diagnostics stored by the last labelled ``forward`` call."""
        TP = self.result_dict['true_pos']
        FP = self.result_dict['false_pos']
        NYP = self.result_dict['num_Y_pos']
        TPL = self.result_dict['tp_lower']
        FPU = self.result_dict['fp_upper']
        # The small epsilon avoids division by zero on empty batches.
        precision, recall = TP / (TP + FP + 1e-10), TP / (NYP + 1e-10)
        if self.training:
            log.info('lambda = %.5f' % self.lam.item())
        log.info('TP = %.1f(>=%.1f), FP = %.1f(<=%.1f), |Y+| = %.1f' %
                 (TP, TPL, FP, FPU, NYP))
        log.info('precision = %.5f, recall = %.5f' % (precision, recall))
        log.info('inequality = %.5f(<=0)' % self.result_dict['inequality'])

    def forward(self, X, target=None):
        """Classify ``X`` and, when ``target`` is given, compute the loss.

        Args:
            X: input features, one row per sample.
            target: optional labels in {0, 1}, one per sample.

        Returns:
            ``pred_cls`` (int32 class per sample) when ``target`` is None,
            otherwise ``(L, accu, pred_cls)`` where ``L`` is the Lagrangian
            (negated while training with ``lam.requires_grad``) and
            ``accu`` is the fraction of correct predictions.
        """
        h1 = torch.sigmoid(self.h1_weights(X))
        h1 = F.dropout(h1, p=self.dropout, training=self.training)
        # NOTE: despite the name these are softmax probabilities.
        logits = F.softmax(self.h2_weights(h1), dim=1)
        pred_cls = (logits[:, 1] > logits[:, 0]).to(torch.int32)
        if target is None:
            return pred_cls

        target = target.to(torch.float32)
        y = 2 * target - 1  # y must be in {-1, 1}
        pred = (logits[:, 1] - logits[:, 0]) * 2 - 1
        hinge_loss = torch.max(1 - y * pred, torch.tensor(0.0).to(y.device))
        Lp = (hinge_loss * target).sum()
        Ln = (hinge_loss * (1 - target)).sum()
        # L = (1 + lam) * Lp + lam * alpha_term * Ln - lam * target.sum()
        L = Lp + self.lam * (self.alpha_term * Ln + Lp - target.sum())

        # Diagnostics; pred_cls_float is in {0.0, 1.0}.
        pred_cls_float = pred_cls.to(torch.float32)
        true_pos = (target * pred_cls_float).sum().item()
        false_pos = ((1 - target) * pred_cls_float).sum().item()
        num_Y_pos = target.sum().item()  # number of positive targets, NOT predictions
        tp_lower = (num_Y_pos - Lp).item()
        fp_upper = Ln.item()
        # Constraint term of L, without the multiplier: it must be <= 0 for
        # the precision bound to hold.  Detached so the stored diagnostic
        # does not keep the autograd graph alive.
        inequality = (self.alpha_term * Ln + Lp - num_Y_pos).detach()

        self.result_dict.update({
            'true_pos': true_pos,
            'false_pos': false_pos,
            'num_Y_pos': num_Y_pos,
            'tp_lower': tp_lower,
            'fp_upper': fp_upper,
            'inequality': inequality,
        })

        correct = pred_cls.eq(target.to(torch.int32).data.view_as(pred_cls))
        accu = (correct.sum().item()) / float(correct.size(0))
        # While the multiplier is being trained the sign of the objective is flipped.
        if self.lam.requires_grad is True and self.training is True:
            L *= -1
        return L, accu, pred_cls
class CompProbModel(torch.nn.Module):
    """Pass-completion model over a discretised field and time-of-flight grid.

    The field is a 120 x 55 grid of locations (``field_locs``) and passes
    are considered for 40 flight times (``T``, 0.1 .. 4).  ``forward``
    computes, per player, the time needed to reach every field location
    and turns it into an interception probability; what is returned
    depends on ``tuning`` (see ``forward``).

    Shape letters used in the comments: B = batch, J = players,
    F = field locations, T = flight times.  Several shape notes were
    written before a batch dimension existed and omit the leading B.
    """

    def __init__(self,
                 a_max=7.25,
                 s_max=9.25,
                 avg_ball_speed=20.0,
                 tti_sigma=0.5,
                 tti_lambda_off=1.0,
                 tti_lambda_def=1.0,
                 ppc_alpha=1.0,
                 tuning=None,
                 use_ppc=False):
        """Create the model parameters and the field / time grids.

        Args:
            a_max: player acceleration used in the time-to-reach model.
            s_max: player top speed used in the time-to-reach model.
            avg_ball_speed: stored as a non-trainable parameter.
            tti_sigma: spread of the logistic interception curve.
            tti_lambda_off: weight applied to players whose team flag is 1.
            tti_lambda_def: weight applied to players whose team flag is 0.
            ppc_alpha: exponent applied to the offensive probability when
                ``tuning`` is ``TuningParam.alpha``.
            tuning: which ``TuningParam`` is being optimised; selects which
                parameters require gradients and what ``forward`` returns.
            use_ppc: whether ``get_ppc_off`` may be used.

        Reads ``in/T_given_L.pkl`` (relative to the working directory) with
        ``pd.read_pickle`` — only load pickles from a trusted source.
        """
        super().__init__()
        # define self.tuning
        self.tuning = tuning

        # define parameters and whether or not to optimize
        # NOTE(review): the trailing ``.float()`` is presumably a no-op for
        # float32 inputs so the attribute stays a registered Parameter — confirm.
        self.tti_sigma = Parameter(
            torch.tensor([tti_sigma]),
            requires_grad=(self.tuning == TuningParam.sigma)).float()
        self.tti_lambda_off = Parameter(
            torch.tensor([tti_lambda_off]),
            requires_grad=(self.tuning == TuningParam.lamb)).float()
        self.tti_lambda_def = Parameter(
            torch.tensor([tti_lambda_def]),
            requires_grad=(self.tuning == TuningParam.lamb)).float()
        self.ppc_alpha = Parameter(
            torch.tensor([ppc_alpha]),
            requires_grad=(self.tuning == TuningParam.alpha)).float()
        self.a_max = Parameter(
            torch.tensor([a_max]),
            requires_grad=(self.tuning == TuningParam.av)).float()
        self.s_max = Parameter(
            torch.tensor([s_max]),
            requires_grad=(self.tuning == TuningParam.av)).float()
        # reaction time added before a player starts moving
        self.reax_t = Parameter(torch.tensor([0.2])).float()
        self.avg_ball_speed = Parameter(torch.tensor([avg_ball_speed]),
                                        requires_grad=False).float()
        self.g = Parameter(torch.tensor([10.72468]), requires_grad=False)  # y/s/s
        # height window used by get_ppc_off to decide where the ball is reachable
        self.z_max = Parameter(torch.tensor([3.]), requires_grad=False)
        self.z_min = Parameter(torch.tensor([0.]), requires_grad=False)
        self.use_ppc = use_ppc
        # constant zero kept as a parameter so it moves with the module
        self.zero_cuda = Parameter(torch.tensor([0.0], dtype=torch.float32),
                                   requires_grad=False)

        # define field grid
        self.x = torch.linspace(0.5, 119.5, 120).float()
        self.y = torch.linspace(-0.5, 53.5, 55).float()
        self.y[0] = -0.2
        self.yy, self.xx = torch.meshgrid(self.y, self.x)
        self.field_locs = Parameter(torch.flatten(torch.stack(
            (self.xx, self.yy), dim=-1), end_dim=-2), requires_grad=False)  # (F, 2)
        self.T = Parameter(torch.linspace(0.1, 4, 40), requires_grad=False)  # (T,)

        # for hist trans prob: window around the ball start, in grid cells
        self.hist_x_min, self.hist_x_max = -9, 70
        self.hist_y_min, self.hist_y_max = -39, 40
        self.hist_t_min, self.hist_t_max = 10, 63
        self.T_given_Ls_df = pd.read_pickle('in/T_given_L.pkl')

    def get_hist_trans_prob(self, frame):
        """Return the prior pass distribution P(L, T | t), shape (B, F, T).

        It is the product of a uniform distribution over field locations
        inside a window around the ball start and the tabulated P(T | L)
        from ``T_given_Ls_df``, normalised over (F, T) for every batch row.

        Args:
            frame: tensor indexed as ``frame[b, player, column]``; columns
                8:10 of player 0 are read as the ball start (x, y).
        """
        B = len(frame)
        """ P(L|t) """
        ball_start = frame[:, 0, 8:10]  # (B, 2)
        ball_start_ind = torch.round(ball_start).long()
        reach_vecs = self.field_locs.unsqueeze(0) - ball_start.unsqueeze(
            1)  # (B, F, 2)
        # mask for zeroing out parts of the field that are too far to be thrown to per the L_given_t model
        L_t_mask = torch.zeros(B, *self.xx.shape)  # (B, Y, X)
        # NOTE(review): b_zeros / b_ones are not used below.
        b_zeros = torch.zeros(ball_start_ind.shape[0])
        b_ones = torch.ones(ball_start_ind.shape[0])
        for bb in range(B):
            # rows (y) then columns (x) of the window, clipped to the grid
            L_t_mask[bb, max(0, ball_start_ind[bb,1]+self.hist_y_min):\
                         min(len(self.y)-1, ball_start_ind[bb,1]+self.hist_y_max),\
                         max(0, ball_start_ind[bb,0]+self.hist_x_min):\
                         min(len(self.x)-1, ball_start_ind[bb,0]+self.hist_x_max)] = 1.
        L_t_mask = L_t_mask.flatten(1)  # (B, F)
        L_given_t = L_t_mask  # changed L_given_t to uniform after discussion
        # renormalize since part of L|t may have been off field
        L_given_t /= L_given_t.sum(1, keepdim=True)  # (B, F)
        """ P(T|L) """
        # we find T|L for sufficiently close spots (1 < L <= 60)
        reach_dist_int = torch.round(torch.linalg.norm(
            reach_vecs, dim=-1)).long()  # (B, F)
        reach_dist_in_bounds_idx = (reach_dist_int > 1) & (reach_dist_int <= 60)
        reach_dist_in_bounds = reach_dist_int[
            reach_dist_in_bounds_idx]  # 1d tensor
        # presumably the table holds len(T) rows per 'pass_dist' value, which is
        # what makes the reshape below line up — verify against the pickle
        T_given_L_subset = torch.from_numpy(self.T_given_Ls_df.set_index('pass_dist').loc[reach_dist_in_bounds, 'p'].to_numpy()).float()\
            .reshape(-1, len(self.T))  # (BF~, T) ; BF~ is the subset of B*F within (1, 60] yds of the ball
        T_given_L = torch.zeros(B * len(self.field_locs), len(self.T))  # (B*F, T)
        # fill in the subset of values computed above
        T_given_L[reach_dist_in_bounds_idx.flatten()] = T_given_L_subset
        T_given_L = T_given_L.reshape(B, len(self.field_locs), -1)  # (B, F, T)
        L_T_given_t = L_given_t[..., None] * T_given_L  # (B, F, T)
        L_T_given_t /= L_T_given_t.sum(
            (1, 2), keepdim=True
        )  # normalize all passes after some have been chopped off
        return L_T_given_t  # (B, F, T)

    def get_ppc_off(self, frame, p_int):
        """Integrate interception probabilities along every ball trajectory.

        For each (field location, flight time) pair the ball path is
        sampled at the times in ``T``; the per-player probability of
        taking the ball at each sample is accumulated, discounted by the
        probability that nobody took it earlier.

        Args:
            frame: tensor indexed as ``frame[b, player, column]``; column 7
                is the team flag and columns 8:10 of player 0 the ball start.
            p_int: per-player interception probabilities, indexed below as
                ``p_int[b, field_index, time_index, player]``.

        Returns:
            ``(ppc_off, ppc_def, ppc_ind)``: the probability summed over
            players with team flag 1, over players with team flag 0
            (both (B, F, T) per the original notes), and per player
            ((B, F, T, J) per the original notes).

        Raises:
            AssertionError: if ``use_ppc`` is False.
        """
        assert self.use_ppc, 'Call made to get_ppc_off while use_ppc setting is False'
        B = frame.shape[0]
        J = p_int.shape[-1]
        ball_start = frame[:, 0, 8:10]  # (B, 2)
        player_teams = frame[:, :, 7]  # (B, J)
        reach_vecs = self.field_locs.unsqueeze(0) - ball_start.unsqueeze(
            1)  # B, F, 2
        # trajectory integration
        dx = reach_vecs[:, :, 0]  # B, F
        dy = reach_vecs[:, :, 1]  # B, F
        vx = dx[:, :, None] / self.T[None, None, :]  # B, F, T
        vy = dy[:, :, None] / self.T[None, None, :]  # B, F, T
        vz_0 = (self.T * self.g) / 2  # T
        # note that idx (i, j, k) into below arrays is invalid when j < k
        traj_ts = self.T.repeat(len(self.field_locs), len(self.T), 1)  # (F, T, T)
        traj_locs_x_idx = torch.round(
            torch.clip((ball_start[:, 0, None, None, None] +
                        vx.unsqueeze(-1) * self.T), 0,
                       len(self.x) - 1)).int()  # B, F, T, T
        traj_locs_y_idx = torch.round(
            torch.clip((ball_start[:, 1, None, None, None] +
                        vy.unsqueeze(-1) * self.T), 0,
                       len(self.y) - 1)).int()  # B, F, T, T
        # ball height along the path, starting from 2.0
        traj_locs_z = 2.0 + vz_0.view(
            1, -1, 1) * traj_ts - 0.5 * self.g * traj_ts * traj_ts  # F, T, T
        # 1 where the ball height is strictly between z_min and z_max, else 0
        lambda_z = torch.where(
            (traj_locs_z < self.z_max) & (traj_locs_z > self.z_min), 1,
            0)  # F, T, T
        # flat field index of every trajectory sample (row-major: y * n_x + x)
        path_idxs = (traj_locs_y_idx * self.x.shape[0] +
                     traj_locs_x_idx).long().reshape(B, -1)  # (B, F*T*T)
        # 10*traj_ts - 1 converts the times into indices - hacky
        traj_t_idxs = (10 * traj_ts - 1).long().repeat(B, 1, 1, 1).reshape(
            B, -1)  # (B, F*T*T)
        p_int_traj = torch.stack([p_int[bb, path_idxs[bb], traj_t_idxs[bb], :] for bb in range(B)])\
            .reshape(*traj_locs_x_idx.shape, -1) * lambda_z.unsqueeze(-1)  # B, F, T, T, J
        p_int_traj_sum = p_int_traj.sum(dim=-1, keepdim=True)  # B, F, T, T, 1
        # only rescale where the players' probabilities sum to more than 1
        norm_factor = torch.maximum(torch.ones_like(p_int_traj_sum),
                                    p_int_traj_sum)  # B, F, T, T, 1
        p_int_traj_norm = p_int_traj / norm_factor  # B, F, T, T, J

        # independent int probs at each point on trajectory
        all_p_int_traj = torch.sum(p_int_traj_norm, dim=-1)  # B, F, T, T
        # off_p_int_traj = torch.sum((player_teams == 1)[:,None,None,None] * p_int_traj_norm, dim=-1) # B, F, T, T
        # def_p_int_traj = torch.sum((player_teams == 0)[:,None,None,None] * p_int_traj_norm, dim=-1) # B, F, T, T
        ind_p_int_traj = p_int_traj_norm  # use for analyzing specific players; B, F, T, T, J

        # calc decaying residual probs after you take away p_int on earlier times in the traj
        compl_all_p_int_traj = 1 - all_p_int_traj  # B, F, T, T
        remaining_compl_p_int_traj = torch.cumprod(compl_all_p_int_traj,
                                                   dim=-1)  # B, F, T, T
        # maximum 0 because if it goes negative the pass has been caught by then and theres no residual probability
        # NOTE(review): the comment above looks stale — no maximum is taken here.
        # Shift by one step so each sample sees the residual *before* it.
        shift_compl_cumsum = torch.roll(remaining_compl_p_int_traj, 1,
                                        dims=-1)  # B, F, T, T
        shift_compl_cumsum[:, :, :, 0] = 1

        # multiply residual prob by p_int at that location and lambda
        lambda_all = self.tti_lambda_off * player_teams + self.tti_lambda_def * (
            1 - player_teams)  # B, J
        # off_completion_prob_dt = shift_compl_cumsum * off_p_int_traj # B, F, T, T
        # def_completion_prob_dt = shift_compl_cumsum * def_p_int_traj # B, F, T, T
        # all_completion_prob_dt = off_completion_prob_dt + def_completion_prob_dt # B, F, T, T
        ind_completion_prob_dt = shift_compl_cumsum.unsqueeze(
            -1) * ind_p_int_traj  # B, F, T, T, J

        # now accumulate values over total traj for each team and take at T=t
        # all_completion_prob = torch.cumsum(all_completion_prob_dt, dim=-1) # B, F, T, T
        # off_completion_prob = torch.cumsum(off_completion_prob_dt, dim=-1) # B, F, T, T
        # def_completion_prob = torch.cumsum(def_completion_prob_dt, dim=-1) # B, F, T, T
        ind_completion_prob = torch.cumsum(ind_completion_prob_dt,
                                           dim=-2)  # B, F, T, T, J

        # this einsum takes the diagonal values over the last two axes where T = t
        # this takes care of the t > T issue.
        # ppc_all = torch.einsum('...ii->...i', all_completion_prob) # B, F, T
        # ppc_off = torch.einsum('...ii->...i', off_completion_prob) # B, F, T
        # ppc_def = torch.einsum('...ii->...i', def_completion_prob) # B, F, T
        ppc_ind = torch.einsum('...iij->...ij',
                               ind_completion_prob)  # B, F, T, J
        ppc_ind *= lambda_all[:, None, None, :]
        # no_p_int_pass = 1-ppc_all # B, F, T

        ppc_off = torch.sum(ppc_ind * player_teams[:, None, None, :],
                            dim=-1)  # B, F, T
        ppc_def = torch.sum(ppc_ind * (1 - player_teams)[:, None, None, :],
                            dim=-1)  # B, F, T

        # assert torch.allclose(all_p_int_pass, off_p_int_pass + def_p_int_pass, atol=0.01)
        # assert torch.allclose(all_p_int_pass, ind_p_int_pass.sum(-1), atol=0.01)
        # return off_p_int_pass, def_p_int_pass, ind_p_int_pass
        return ppc_off, ppc_def, ppc_ind

    def forward(self, frame):
        """Run the model on a batch of frames.

        Args:
            frame: tensor indexed as ``frame[b, player, column]``.  As used
                below: columns 1-2 act as position, 3-4 as velocity and
                5-6 as acceleration in the kinematic update; 7 is the team
                flag; -4 a player mask; -3 / -2 the ball end x / y; -1 the
                time of flight (the last three are read from player 0).

        Returns:
            Depends on ``self.tuning``:

            * ``TuningParam.av``: masked projected (x, y) of every player
              at the true time of flight, moving toward the true ball end.
            * ``TuningParam.sigma``: per-player interception probability
              at the true pass.
            * ``TuningParam.alpha``: normalised pass probability at the
              true pass.
            * ``TuningParam.lamb``: per-player completion probability at
              the true pass (requires ``use_ppc``).
            * anything else: ``None`` (no branch returns).
        """
        # state after the reaction time, assuming constant acceleration
        v_x_r = frame[:, :, 5] * self.reax_t + frame[:, :, 3]
        v_y_r = frame[:, :, 6] * self.reax_t + frame[:, :, 4]
        # NOTE(review): v_r_mag and v_r_theta are not used below.
        v_r_mag = torch.norm(torch.stack([v_x_r, v_y_r], dim=-1), dim=-1)
        v_r_theta = torch.atan2(v_y_r, v_x_r)

        x_r = frame[:, :, 1] + frame[:, :, 3] * self.reax_t + 0.5 * frame[:, :, 5] * self.reax_t**2
        y_r = frame[:, :, 2] + frame[:, :, 4] * self.reax_t + 0.5 * frame[:, :, 6] * self.reax_t**2

        # get each player's team, location, and velocity
        player_teams = frame[:, :, 7]  # B, J
        reaction_player_locs = torch.stack([x_r, y_r], dim=-1)  # (B, J, 2)
        reaction_player_vels = torch.stack([v_x_r, v_y_r], dim=-1)  # (B, J, 2)

        # calculate each player's distance from each field location
        int_d_vec = self.field_locs.unsqueeze(1).unsqueeze(
            0) - reaction_player_locs.unsqueeze(1)  # B, F, J, 2
        int_d_mag = torch.norm(int_d_vec, dim=-1)  # B, F, J
        int_d_theta = torch.atan2(int_d_vec[..., 1], int_d_vec[..., 0])  # B, F, J

        # take dot product of velocity and direction
        # (speed component toward the target, clamped to [-s_max, s_max])
        int_s0 = torch.clamp(
            torch.sum(int_d_vec * reaction_player_vels.unsqueeze(1), dim=-1) /
            int_d_mag, -1 * self.s_max.item(), self.s_max.item())  # B, F, J

        # calculate time it takes for each player to reach each field position accounting for their current velocity and acceleration
        t_lt_smax = (self.s_max - int_s0) / self.a_max  # B, F, J
        d_lt_smax = t_lt_smax * ((int_s0 + self.s_max) / 2)  # B, F, J
        # if accelerating would overshoot, then t = -v0/a + sqrt(v0^2/a^2 + 2x/a) (from kinematics)
        t_lt_smax = torch.where(d_lt_smax > int_d_mag, -int_s0 / self.a_max + \
                                torch.sqrt((int_s0 / self.a_max) ** 2 + 2 * int_d_mag / self.a_max), t_lt_smax)  # B, F, J
        # distance covered while accelerating, limited to [0, total distance]
        d_lt_smax = torch.max(torch.min(d_lt_smax, int_d_mag),
                              torch.zeros_like(d_lt_smax))  # B, F, J
        d_at_smax = int_d_mag - d_lt_smax  # B, F, J
        t_at_smax = d_at_smax / self.s_max  # B, F, J
        t_tot = self.reax_t + t_lt_smax + t_at_smax  # B, F, J

        # get true pass (tof and ball_end) to tune on (subtract 1 from tof, add 1 to y for correct indexing)
        tof = torch.round(frame[:, 0, -1]).long().view(-1, 1, 1, 1).repeat(
            1, t_tot.size(1), 1, t_tot.size(-1)) - 1
        # ball ind: flat field index of the true ball end (row-major: y * n_x + x)
        ball_end_x = frame[:, 0, -3].int()
        ball_end_y = frame[:, 0, -2].int() + 1
        ball_field_ind = (ball_end_y * self.x.shape[0] +
                          ball_end_x).long().view(-1, 1, 1).repeat(
                              1, 1, t_tot.size(-1))

        if self.tuning == TuningParam.av:
            # collapse extra dims
            tof = self.T[tof[:, 0, 0, 0]].float()
            # select field in for all the position and velocity values calculated previously
            t_lt_smax = torch.gather(t_lt_smax, 1, ball_field_ind).squeeze()  # J,
            d_lt_smax = torch.gather(d_lt_smax, 1, ball_field_ind).squeeze()  # J,
            d_at_smax = torch.gather(d_at_smax, 1, ball_field_ind).squeeze()
            t_at_smax = torch.gather(t_at_smax, 1, ball_field_ind).squeeze()
            t_tot = torch.gather(t_tot, 1, ball_field_ind).squeeze()
            int_s0 = torch.gather(int_s0, 1, ball_field_ind).squeeze()
            int_d_theta = torch.gather(int_d_theta, 1, ball_field_ind).squeeze()
            int_d_mag = torch.gather(int_d_mag, 1, ball_field_ind).squeeze()

            # projected locations at t = tof, f = ball_field_ind
            # Piecewise: still reacting -> 0; accelerating -> s0*t + a*t^2/2;
            # at top speed -> see the nested expression; otherwise the full distance.
            # NOTE(review): the top-speed piece multiplies a distance
            # (d_at_smax) by a time; a speed may have been intended — confirm.
            d_proj = torch.where(tof.unsqueeze(-1) <= self.reax_t, self.zero_cuda,
                                 torch.where(tof.unsqueeze(-1) <= (t_lt_smax + self.reax_t),
                                             (int_s0 * (tof.unsqueeze(-1) - self.reax_t)) + 0.5 * self.a_max \
                                             * (tof.unsqueeze(-1) - self.reax_t) ** 2,
                                             torch.where(tof.unsqueeze(-1) <= (t_lt_smax + t_at_smax + self.reax_t),
                                                         (d_lt_smax + (d_at_smax * (tof.unsqueeze(-1) - t_lt_smax - self.reax_t))),
                                                         int_d_mag)))  # J,
            # never project past the target itself
            d_proj = torch.minimum(d_proj, int_d_mag)

            x_proj = reaction_player_locs[..., 0] + d_proj * torch.cos(
                int_d_theta)  # J
            y_proj = reaction_player_locs[..., 1] + d_proj * torch.sin(
                int_d_theta)  # J

            # mask x_proj and y_proj (only want loss on closest off and def players)
            player_mask = frame[:, :, -4]
            masked_x = player_mask * x_proj
            masked_y = player_mask * y_proj

            return torch.stack([masked_x, masked_y], dim=-1)  # J, 2

        # subtract the arrival time (t_tot) from time of flight of ball
        int_dT = self.T.view(1, 1, -1, 1) - t_tot.unsqueeze(2)  # B, F, T, J
        # calculate interception probability for each player, field loc, time of flight (logistic function)
        # 3.14 / 1.732 approximates pi / sqrt(3)
        p_int = torch.sigmoid(
            (3.14 / (1.732 * self.tti_sigma)) * int_dT)  # (B, F, T, J)

        if self.tuning == TuningParam.sigma:
            # pick the true flight time, then the true ball-end location
            p_int = torch.gather(p_int, 2, tof).squeeze()
            p_int = torch.gather(p_int, 1, ball_field_ind).squeeze()
            return p_int
        elif self.tuning == TuningParam.alpha:
            h_trans_prob = self.get_hist_trans_prob(frame)  # (B, F, T)
            if self.use_ppc:
                ppc_off, *_ = self.get_ppc_off(frame, p_int)
                trans_prob = h_trans_prob * torch.pow(
                    ppc_off, self.ppc_alpha)  # (B, F, T)
            else:
                # p_int summed over all offensive players
                p_int_off = torch.sum(p_int * (player_teams == 1),
                                      dim=-1)  # (B, F, T)
                trans_prob = h_trans_prob * torch.pow(p_int_off,
                                                      self.ppc_alpha)  # (B, F, T)
            trans_prob /= trans_prob.sum(dim=(1, 2), keepdim=True)  # (B, F, T)
            # index into true pass. [...,0] necessary on indices because no J dimension
            trans_prob_throw = torch.gather(trans_prob, 2, tof[..., 0]).squeeze()
            trans_prob_throw = torch.gather(
                trans_prob_throw, 1, ball_field_ind[..., 0]).squeeze()  # (B,)
            return trans_prob_throw
        elif self.tuning == TuningParam.lamb:
            assert self.use_ppc, 'need to use ppc to tune lambda'
            *_, ppc_ind = self.get_ppc_off(frame, p_int)  # ppc_ind: (B, F, T, J)
            ppc_ind_throw = torch.gather(ppc_ind, 2, tof).squeeze()  # B, F, J
            ppc_ind_throw = torch.gather(ppc_ind_throw, 1,
                                         ball_field_ind).squeeze()  # B, J
            return ppc_ind_throw
class StochasticCNN(nn.Module):
    """CNN with stochastic conv/linear layers trained with a PAC-Bayes bound.

    Two ``StochasticConv2D`` blocks feed one ``StochasticLinear`` layer.
    The class also carries the learnable prior log-std and the constants
    of Catoni's bound, and computes the KL and union-bound terms.
    """

    def __init__(
            self,
            num_training_samples: int,
            rnd: np.random.RandomState,
            num_last_units=100,
            trained_deterministic_model=None,
            prior_log_std=-3.,
            catoni_lambda=1.,
            delta=0.05,
            b=100,
            c=0.1,
            init_weights=True
    ):
        """
        :param num_training_samples: the number of training data.
        :param rnd: `np.random.RandomState` instance for reproducibility,
        :param num_last_units: The number of units of the last linear layer.
        :param trained_deterministic_model: optional model whose weights
            initialise both the posterior means and the prior means.
        :param prior_log_std: initial value of prior's log std value.
        :param catoni_lambda: Catoni's Lambda Parameter. It must be positive.
        :param delta: Confidence parameter.
        :param b: Prior variance's precision parameter.
        :param c: Prior variance's upper bound.
        :param init_weights: If true, weights are initialized (from the
            deterministic model when given, otherwise by truncated Gaussian).
            Note this must be true to calculate KL or chi-square divergence,
            because it is what sets `num_weights`.
        """
        upper_log_std = 0.5 * np.log(np.float32(c))
        assert upper_log_std > prior_log_std, 'c is the upper bound of the prior\'s variance.'

        super(StochasticCNN, self).__init__()

        self.features = self.create_features()
        self.f_last = StochasticLinear(1600, num_last_units)

        self.num_weights = 0
        self.prior_log_std = Parameter(torch.Tensor([prior_log_std]))
        # Last value that satisfied the constraint; see `constraint`.
        self.previous_prior_log_std = torch.Tensor([prior_log_std])

        if init_weights:
            if trained_deterministic_model is not None:
                self._initialize_weights_from_deterministic_model(trained_deterministic_model)
            else:
                self._initialize_weights(rnd)

        # Constants stored as frozen parameters so they follow the module.
        self.num_training_samples = Parameter(torch.Tensor([num_training_samples]), requires_grad=False)
        self.delta = Parameter(torch.Tensor([delta]), requires_grad=False)
        self.b = Parameter(torch.Tensor([b]), requires_grad=False)
        self.c = Parameter(torch.Tensor([c]), requires_grad=False)
        self.catoni_lambda = Parameter(torch.Tensor([catoni_lambda]), requires_grad=False)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the conv features, flatten, then the last linear layer."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.f_last(x)
        return x

    def sample_noise(self) -> None:
        """
        Sample weights and biases from the posterior.
        :return: None
        """
        for m in self.modules():
            if isinstance(m, (StochasticLinear, StochasticConv2D)):
                m.sample_noise()

    @staticmethod
    def create_features() -> torch.nn.modules.container.Sequential:
        """Build the two conv / ReLU / max-pool stages."""
        return nn.Sequential(
            StochasticConv2D(3, 64, kernel_size=5, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            StochasticConv2D(64, 64, kernel_size=5, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

    def _initialize_weights_from_deterministic_model(self, deterministic_model) -> None:
        """Copy weights/biases of ``deterministic_model`` into means and priors.

        Modules are paired positionally via ``zip`` over ``modules()``, so
        both models must enumerate their modules in the same order.
        """
        for m, premodel_module in zip(self.modules(), deterministic_model.modules()):
            if isinstance(m, (StochasticLinear, StochasticConv2D)):
                m.weight.data.copy_(premodel_module.weight.data)
                m.weight_prior.data.copy_(premodel_module.weight.data)
                m.bias.data.copy_(premodel_module.bias.data)
                m.bias_prior.data.copy_(premodel_module.bias.data)
                nn.init.constant_(m.weight_log_std, self.prior_log_std.item())
                nn.init.constant_(m.bias_log_std, self.prior_log_std.item())

                self.num_weights += np.prod(m.weight.size())
                self.num_weights += np.prod(m.bias.size())

    def _initialize_weights(self, rnd: np.random.RandomState) -> None:
        """Draw prior means from truncated Gaussians and copy them to the means.

        Biases start at zero and every log-std starts at ``prior_log_std``.
        """
        conv_upper = 2 * 5e-2
        for m in self.modules():
            if isinstance(m, StochasticLinear):
                # Sized from the layer itself so that `num_last_units`
                # other than 100 works.
                m.weight_prior.data = torch.from_numpy(
                    truncnorm.rvs(
                        a=-1. / 800., b=1. / 800., size=tuple(m.weight.size()), random_state=rnd
                    ).astype(np.float32)
                )
            elif isinstance(m, StochasticConv2D):
                m.weight_prior.data = torch.from_numpy(
                    truncnorm.rvs(
                        a=-conv_upper, b=conv_upper, size=tuple(m.weight.size()), random_state=rnd
                    ).astype(np.float32)
                )

            if isinstance(m, (StochasticLinear, StochasticConv2D)):
                m.weight.data.copy_(m.weight_prior.data)
                nn.init.constant_(m.bias_prior, 0.)
                nn.init.constant_(m.bias, 0.)
                nn.init.constant_(m.weight_log_std, self.prior_log_std.item())
                nn.init.constant_(m.bias_log_std, self.prior_log_std.item())

                self.num_weights += np.prod(m.weight.size())
                self.num_weights += np.prod(m.bias.size())

    def kl(self, prior_log_std) -> torch.FloatTensor:
        """
        Calculate KL divergence between posterior and prior.
        :param prior_log_std: log std of the (isotropic) prior.
        :return: KL divergence value.
        """
        num_weights = self.num_weights
        assert num_weights > 0

        mean_norm_list = []
        log_std_sum_list = []
        variance_l1_norm_list = []

        prior_variance = torch.exp(2. * prior_log_std)
        prior_log_variance = 2. * prior_log_std

        for m in self.modules():
            if isinstance(m, (StochasticLinear, StochasticConv2D)):
                mean_norm_list.append(torch.sum((m.weight - m.weight_prior) ** 2))
                mean_norm_list.append(torch.sum((m.bias - m.bias_prior) ** 2))

                log_std_sum_list.append(torch.sum(m.weight_log_std))
                log_std_sum_list.append(torch.sum(m.bias_log_std))

                # Subtracting the prior's log-variance inside the exponent is
                # more accurate than computing `torch.exp(2. * log_std)` and
                # dividing by `prior_variance` afterwards.
                variance_l1_norm_list.append(torch.sum(torch.exp(2. * m.weight_log_std - prior_log_variance)))
                variance_l1_norm_list.append(torch.sum(torch.exp(2. * m.bias_log_std - prior_log_variance)))

        norm_weights = torch.sum(torch.stack(mean_norm_list))
        mean_part = norm_weights / prior_variance

        norm_log_std = 2. * torch.sum(torch.stack(log_std_sum_list))
        sum_variance_l1_norm = torch.sum(torch.stack(variance_l1_norm_list))
        std_part = sum_variance_l1_norm - norm_log_std + 2. * num_weights * prior_log_std

        kl = 0.5 * (mean_part + std_part - num_weights)

        return kl

    def union_bound(self, prior_log_std) -> torch.FloatTensor:
        """
        Calculate union bound value related to prior.
        :param prior_log_std: Float parameter contains prior's log std.
        :return: FloatTensor
        """
        return 2. * torch.log(self.b) + 2. * torch.log(torch.log(self.c) - 2. * prior_log_std) \
            + torch.log(np.pi ** 2 / (6. * self.delta))

    def pac_bayes_objective(self, contrastive_loss: torch.FloatTensor) -> tuple:
        """
        Catoni's PAC-Bayes bound with union bound
        :param contrastive_loss: empirical risk: FloatTensor
        :return: PAC-Bayes upper bound, KL, and complexity term; FloatTensors
        """
        kl = self.kl(prior_log_std=self.prior_log_std)

        # union bound term
        union_bound = self.union_bound(prior_log_std=self.prior_log_std)

        # KL term easily becomes large, so it is divided by Catoni's lambda
        objective = contrastive_loss + (kl + union_bound) / self.catoni_lambda

        return objective, kl, union_bound

    def compute_complexity_terms_with_discretized_prior_variance(self) -> tuple:
        r"""
        Compute kl divergence and union bound terms by using discretized_prior_variance.
        Note `log (2 \sqrt{m})` is added to the union bound term in
        `contrastive.eval.common.pb_parameter_selection`.
        :return: tuple of kl divergence and union bound without sqrt{m}.
        """
        # discretize prior's variance parameter
        # https://github.com/gkdziugaite/pacbayes-opt/blob/master/snn/core/network.py#L398
        discretized_j = (self.b * (torch.log(self.c) - 2. * self.prior_log_std))
        discretized_j_up = torch.ceil(discretized_j)
        discretized_j_down = torch.floor(discretized_j)

        constant_in_log_delta = torch.log(np.pi ** 2 / (6 * self.delta))
        union_up = (constant_in_log_delta + 2 * torch.log(discretized_j_up)).item()
        union_down = (constant_in_log_delta + 2 * torch.log(discretized_j_down)).item()

        prior_log_std_up = (torch.log(self.c) - discretized_j_up / self.b) / 2.
        prior_log_std_down = (torch.log(self.c) - discretized_j_down / self.b) / 2.

        kl_up = self.kl(prior_log_std_up).item()
        kl_down = self.kl(prior_log_std_down).item()

        up_complexity = kl_up + union_up
        down_complexity = kl_down + union_down

        # Rounding down to j = 0 gives log(0) = -inf; use the rounded-up
        # candidate in that case.
        if up_complexity < down_complexity or np.isinf(down_complexity):
            return kl_up, union_up
        else:
            return kl_down, union_down

    def deterministic(self) -> None:
        """
        Set mean values to weights for feed forwarding.
        :return: None
        """
        for m in self.modules():
            if isinstance(m, (StochasticLinear, StochasticConv2D)):
                m.realised_weight = m.weight.detach()
                m.realised_bias = m.bias.detach()

    def constraint(self) -> None:
        """
        Constraint for prior's variance: while `log(c) - 2 * prior_log_std`
        is positive the current value is remembered, otherwise the last
        remembered value is restored.
        :return: None
        """
        if (torch.log(self.c) - 2. * self.prior_log_std).item() > 0.:
            self.previous_prior_log_std.data.copy_(self.prior_log_std.data)
        else:
            self.prior_log_std.data.copy_(self.previous_prior_log_std.data)
class _LearnableFakeQuantize(nn.Module): r""" This is an extension of the FakeQuantize module in fake_quantize.py, which supports more generalized lower-bit quantization and support learning of the scale and zero point parameters through backpropagation. For literature references, please see the class _LearnableFakeQuantizePerTensorOp. In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize module also includes the following attributes to support quantization parameter learning. * :attr: `channel_len` defines the length of the channel when initializing scale and zero point for the per channel case. * :attr: `use_grad_scaling` defines the flag for whether the gradients for scale and zero point are normalized by the constant, which is proportional to the square root of the number of elements in the tensor. The related literature justifying the use of this particular constant can be found here: https://openreview.net/pdf?id=rkgO66VKDS. * :attr: `fake_quant_enabled` defines the flag for enabling fake quantization on the output. * :attr: `static_enabled` defines the flag for using observer's static estimation for scale and zero point. * attr: `learning_enabled` defines the flag for enabling backpropagation for scale and zero point. """ def __init__(self, observer, quant_min=0, quant_max=255, scale=1., zero_point=0., channel_len=-1, use_grad_scaling=False, **observer_kwargs): super(_LearnableFakeQuantize, self).__init__() assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.' self.quant_min = quant_min self.quant_max = quant_max self.use_grad_scaling = use_grad_scaling if channel_len == -1: self.scale = Parameter(torch.tensor([scale])) self.zero_point = Parameter(torch.tensor([zero_point])) else: assert isinstance(channel_len, int) and channel_len > 0, "Channel size must be a positive integer." 
self.scale = Parameter(torch.tensor([scale] * channel_len)) self.zero_point = Parameter(torch.tensor([zero_point] * channel_len)) self.activation_post_process = observer(**observer_kwargs) assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \ 'quant_min out of bound' assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \ 'quant_max out of bound' self.dtype = self.activation_post_process.dtype self.qscheme = self.activation_post_process.qscheme self.ch_axis = self.activation_post_process.ch_axis \ if hasattr(self.activation_post_process, 'ch_axis') else -1 self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8)) self.register_buffer('static_enabled', torch.tensor([1], dtype=torch.uint8)) self.register_buffer('learning_enabled', torch.tensor([0], dtype=torch.uint8)) bitrange = torch.tensor(quant_max - quant_min + 1).double() self.bitwidth = int(torch.log2(bitrange).item()) @torch.jit.export def enable_param_learning(self): r"""Enables learning of quantization parameters and disables static observer estimates. Forward path returns fake quantized X. """ self.toggle_qparam_learning(enabled=True) \ .toggle_fake_quant(enabled=True) \ .toggle_observer_update(enabled=False) return self @torch.jit.export def enable_static_estimate(self): r"""Enables static observer estimates and disbales learning of quantization parameters. Forward path returns fake quantized X. """ self.toggle_qparam_learning(enabled=False) \ .toggle_fake_quant(enabled=True) \ .toggle_observer_update(enabled=True) @torch.jit.export def enable_static_observation(self): r"""Enables static observer accumulating data from input but doesn't update the quantization parameters. Forward path returns the original X. 
""" self.toggle_qparam_learning(enabled=False) \ .toggle_fake_quant(enabled=False) \ .toggle_observer_update(enabled=True) @torch.jit.export def toggle_observer_update(self, enabled=True): self.static_enabled[0] = int(enabled) return self @torch.jit.export def toggle_qparam_learning(self, enabled=True): self.learning_enabled[0] = int(enabled) self.scale.requires_grad = enabled self.zero_point.requires_grad = enabled return self @torch.jit.export def toggle_fake_quant(self, enabled=True): self.fake_quant_enabled[0] = int(enabled) return self @torch.jit.export def observe_quant_params(self): print('_LearnableFakeQuantize Scale: {}'.format(self.scale.detach())) print('_LearnableFakeQuantize Zero Point: {}'.format(self.zero_point.detach())) @torch.jit.export def calculate_qparams(self): return self.activation_post_process.calculate_qparams() def forward(self, X): self.activation_post_process(X.detach()) _scale, _zero_point = self.calculate_qparams() _scale = _scale.to(self.scale.device) _zero_point = _zero_point.to(self.zero_point.device) if self.static_enabled[0] == 1: self.scale.data.copy_(_scale) self.zero_point.data.copy_(_zero_point) if self.fake_quant_enabled[0] == 1: if self.learning_enabled[0] == 1: if self.use_grad_scaling: grad_factor = 1.0 / (self.weight.numel() * self.quant_max) ** 0.5 else: grad_factor = 1.0 if self.qscheme in ( torch.per_channel_symmetric, torch.per_channel_affine): X = _LearnableFakeQuantizePerChannelOp.apply( X, self.scale, self.zero_point, self.ch_axis, self.quant_min, self.quant_max, grad_factor) else: X = _LearnableFakeQuantizePerTensorOp.apply( X, self.scale, self.zero_point, self.quant_min, self.quant_max, grad_factor) else: if self.qscheme == torch.per_channel_symmetric or \ self.qscheme == torch.per_channel_affine: X = torch.fake_quantize_per_channel_affine( X, self.scale, self.zero_point, self.ch_axis, self.quant_min, self.quant_max) else: X = torch.fake_quantize_per_tensor_affine( X, float(self.scale.item()), 
int(self.zero_point.item()), self.quant_min, self.quant_max) return X def _save_to_state_dict(self, destination, prefix, keep_vars): # We will be saving the static state of scale (instead of as a dynamic param). super(_LearnableFakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars) destination[prefix + 'scale'] = self.scale.data destination[prefix + 'zero_point'] = self.zero_point def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): local_state = ['scale', 'zero_point'] for name in local_state: key = prefix + name if key in state_dict: val = state_dict[key] if name == 'scale': self.scale.data.copy_(val) else: setattr(self, name, val) elif strict: missing_keys.append(key) super(_LearnableFakeQuantize, self)._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) with_args = classmethod(_with_args)
class GPRegressor(nn.Module):
    """
    Gaussian process regressor with a learnable noise term `sn`.

    `kernel` is called as `kernel(A, B)`. `fit` must be called before
    `forward`, because `forward` reads `self.X`, `self.L` and `self.alpha`,
    which are only set by `fit`.
    """

    def __init__(self, kernel, sn=0.1, lr=1e-1, scheduler=False, prior=True):
        super(GPRegressor, self).__init__()
        self.sn = Parameter(torch.Tensor([sn]))
        self.kernel = kernel
        self.loss_func = NLMLLoss()
        # Optimise every trainable parameter registered so far (`sn`, plus the
        # kernel's parameters when the kernel is itself an nn.Module).
        opt = [p for p in self.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(opt, lr=lr)
        if prior:
            self.prior = torch.distributions.Beta(2, 2).log_prob
        else:
            self.prior = None
        if scheduler:
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, patience=2, verbose=True, mode='max')
        else:
            self.scheduler = None

    def _factorize(self, X, y, jitter):
        """Return `(L, alpha)`: the lower Cholesky factor of
        `kernel(X, X) + (sn + jitter) * I` and the solution of that system for `y`.
        """
        K = self.kernel(X, X)
        noise = (self.sn + jitter) * torch.eye(
            K.shape[0], dtype=K.dtype, device=K.device)
        L = torch.linalg.cholesky(K + noise)
        # Performs both triangular solves (L, then L^T) in one call.
        alpha = torch.cholesky_solve(y, L, upper=False)
        return L, alpha

    def loss(self, X, y, jitter, val=None):
        """
        Training loss, optionally with a validation MSE.

        Parameters:
            X, y: training inputs and targets
            jitter: constant added to the kernel diagonal together with `sn`
            val: optional `(X_val, y_val)` pair
        Returns:
            `loss` when `val` is None, otherwise `(loss, mse)`.
        """
        L, alpha = self._factorize(X, y, jitter)
        loss = self.loss_func(L, alpha, y)
        if self.prior is not None:
            # Out-of-place so the subtraction does not depend on the loss and
            # the prior term having identical shapes.
            loss = loss - self.prior(self.sn).sum()
        if val is None:
            return loss
        X_val, y_val = val
        k_star = self.kernel(X, X_val)
        mu = k_star.t() @ alpha
        mse = nn.MSELoss()(mu, y_val)
        return loss, mse

    def forward(self, X):
        """
        Gaussian process regression predictions.

        Parameters:
            X: m x d points to predict
        Returns:
            mu: m x 1 predicted means
            var: m x m predicted covariance
        Follows Algorithm 2.1 from GPML.
        """
        k_star = self.kernel(self.X, X)
        mu = k_star.t() @ self.alpha
        v = torch.linalg.solve_triangular(self.L, k_star, upper=False)
        k_ss = self.kernel(X, X)
        var = k_ss - v.t() @ v
        return mu, var

    def fit(self, X, y, its=100, jitter=1e-6, verbose=True, val=None, chkpt=None):
        """
        Optimise the parameters for `its` iterations, then cache the
        factorisation used by `forward`.

        Parameters:
            X, y: training inputs and targets (stored on the module)
            its: number of optimiser steps
            jitter: constant added to the kernel diagonal together with `sn`
            verbose: print a progress line per iteration
            val: optional `(X_val, y_val)` pair used to report validation MSE
            chkpt: optional target for `torch.save`; written whenever the
                validation MSE improves (only used when `val` is given)
        Returns:
            The history list: one `(loss, sn)` or `(loss, sn, mse)` tuple
            per iteration.
        """
        self.X = X
        self.y = y
        self._fit(X, y, its, jitter, verbose, val, chkpt)
        self._set_pars(jitter)
        return self.history

    def _fit(self, X, y, its, jitter, verbose, val, chkpt):
        """Run the optimisation loop and fill `self.history`."""
        self.history = []
        best_mse = float('inf')
        for it in range(its):
            mse = None
            if val is not None:
                loss, mse = self.loss(X, y, jitter, val=val)
                mse = mse.item()
                if chkpt is not None and mse < best_mse:
                    # Track the best score so the checkpoint is only
                    # overwritten on a genuine improvement.
                    best_mse = mse
                    torch.save(self.state_dict(), chkpt)
            else:
                loss = self.loss(X, y, jitter)

            # backward
            self.optimizer.zero_grad()
            loss.backward(retain_graph=False)

            # update parameters
            self.optimizer.step()
            # keep the noise term strictly positive
            self.sn.data.clamp_(min=1e-6)

            loss_value = loss.item()
            sn_value = self.sn.item()

            if verbose:
                update = '\rIteration %d of %d\tNLML: %.4f\tsn: %.6f\t' \
                    % (it + 1, its, loss_value, sn_value)
                print(update, end='')
                if val is not None:
                    print('val mse: %.4f' % mse, end='')

            if val is None:
                self.history.append((loss_value, sn_value))
            else:
                self.history.append((loss_value, sn_value, mse))
            # drop the graph before the next iteration
            del loss

    def _set_pars(self, jitter):
        """Cache `self.L` and `self.alpha` for the stored training data."""
        self.L, self.alpha = self._factorize(self.X, self.y, jitter)