Beispiel #1
0
class Q_ReLU6(Q_ReLU):
    """ReLU6-style variant of ``Q_ReLU``.

    Identical to the parent except that ``initialize`` and
    ``initialize_qonly`` cap the initial clipping range ``offset + diff`` at 6
    before storing ``log(exp(range) - 1)`` (the inverse of softplus) in the
    ``a``/``c`` parameters.
    """

    def __init__(self, act_func=True, inplace=False):
        super(Q_ReLU6, self).__init__(act_func, inplace)

    @staticmethod
    def _inv_softplus(x):
        """Return ``log(exp(x) - 1)`` as a Python float.

        Evaluated as ``x + log(-expm1(-x))``, which is algebraically equal but
        cannot overflow in ``exp`` and avoids cancellation in ``exp(x) - 1``
        for small ``x``. Non-positive ``x`` yields ``-inf``/``nan`` just like
        the naive formula.
        """
        return float(x + np.log(-np.expm1(-x)))

    def initialize(self, bits, offset, diff):
        """(Re)create ``bits``, ``nlvs``, ``a`` and ``c`` for the given bit-width(s).

        Args:
            bits: a bit-width or a sequence of bit-widths. A bare int is
                treated as a one-element list (``Tensor(int)`` would instead
                allocate an uninitialized tensor of that length).
            offset: added to ``diff`` to form the initial clipping range.
            diff: see ``offset``; the range ``offset + diff`` is capped at 6.

        ``a`` and ``c`` are both set, per bit-width entry, to
        ``log(exp(range) - 1) / (2**bits - 1)``.

        Raises:
            ValueError: if any bit-width gives fewer than two levels.
        """
        bits = torch.as_tensor(bits, dtype=torch.get_default_dtype()).reshape(-1).clone()
        nlvs = torch.pow(2, bits)
        # Validate before mutating so a bad call leaves the module untouched.
        if not bool((nlvs > 1).all()):
            raise ValueError('bit-widths must be >= 1, got {}'.format(bits.tolist()))

        self.bits = Parameter(bits, requires_grad=False)
        self.nlvs = Parameter(nlvs, requires_grad=True)
        self.a = Parameter(Tensor(len(self.bits)))
        self.c = Parameter(Tensor(len(self.bits)))

        qmax = self._inv_softplus(min(offset + diff, 6))
        # One step size per bit-width entry (the former ``.item()`` call only
        # worked for a single entry); float64 matches the former scalar math.
        stepsize = qmax / (self.nlvs.detach().double() - 1)
        self.a.data.copy_(stepsize)
        self.c.data.copy_(stepsize)

    def initialize_qonly(self, offset, diff):
        """Reset ``a`` and ``c`` to ``log(exp(min(offset + diff, 6)) - 1)``."""
        value = self._inv_softplus(min(offset + diff, 6))
        self.a.data.fill_(value)
        self.c.data.fill_(value)
Beispiel #2
0
class Q_Sym(nn.Module):
    """Symmetric learned-step quantizer.

    ``a`` and ``c`` are stored in inverse-softplus space and decoded with
    ``softplus`` in ``forward``. Inputs are divided by the step ``a``, clipped
    to ``[-nsteps, nsteps]`` with ``nsteps = nlvs / 2 - 1``, rounded, and
    rescaled to ``[-c, c]``. With the default ``bits == [32]`` the module is
    the identity.
    """

    def __init__(self):
        super(Q_Sym, self).__init__()
        self.bits = Parameter(Tensor([32]))
        self.nlvs = Parameter(Tensor([2**32]))

        self.a = Parameter(Tensor(1))
        self.c = Parameter(Tensor(1))

    @staticmethod
    def _inv_softplus(x):
        """Return ``log(exp(x) - 1)`` as a Python float.

        Evaluated as ``x + log(-expm1(-x))`` so large ranges do not overflow
        ``exp`` (the naive form gives ``inf`` for x > ~709) and small ranges
        keep precision. Non-positive ``x`` yields ``-inf``/``nan`` like the
        naive formula.
        """
        return float(x + np.log(-np.expm1(-x)))

    def initialize(self, bits, offset, diff):
        """(Re)create ``bits``, ``nlvs``, ``a`` and ``c`` for the given bit-width(s).

        Args:
            bits: a bit-width or a sequence of bit-widths. A bare int is
                treated as a one-element list (``Tensor(int)`` would instead
                allocate an uninitialized tensor of that length).
            offset: added to ``diff`` to form the initial clipping range.
            diff: see ``offset``.

        ``a`` and ``c`` are both set, per bit-width entry, to
        ``log(exp(offset + diff) - 1) / (2**bits / 2 - 1)``.

        Raises:
            ValueError: if any bit-width is below 2 (no positive levels).
        """
        bits = torch.as_tensor(bits, dtype=torch.get_default_dtype()).reshape(-1).clone()
        nlvs = torch.pow(2, bits)
        nsteps = nlvs.double() / 2 - 1
        # Validate before mutating; 1 bit gives zero steps and a division by 0.
        if not bool((nsteps > 0).all()):
            raise ValueError('bit-widths must be >= 2, got {}'.format(bits.tolist()))

        self.bits = Parameter(bits, requires_grad=False)
        self.nlvs = Parameter(nlvs, requires_grad=True)
        self.a = Parameter(Tensor(len(self.bits)))
        self.c = Parameter(Tensor(len(self.bits)))

        # One step size per bit-width entry (the former ``.item()`` call only
        # worked for a single entry).
        stepsize = self._inv_softplus(offset + diff) / nsteps
        self.a.data.copy_(stepsize)
        self.c.data.copy_(stepsize)

    def initialize_qonly(self, offset, diff):
        """Reset ``a`` and ``c`` to ``log(exp(offset + diff) - 1)``."""
        value = self._inv_softplus(offset + diff)
        self.a.data.fill_(value)
        self.c.data.fill_(value)

    def forward(self, x):
        # Full precision: pass the input through untouched.
        if len(self.bits) == 1 and self.bits[0] == 32:
            return x
        else:
            a = F.softplus(self.a)  # positive step size
            c = F.softplus(self.c)  # positive output scale
            nsteps = self.nlvs / 2 - 1
            x = F.hardtanh(x / a, -nsteps, nsteps)
            # Round_fn is defined elsewhere; presumably rounding with a
            # straight-through gradient — confirm.
            x_bar = Round_fn.apply(x).div_(nsteps) * c

            return x_bar
Beispiel #3
0
class AddSine(AddCoords):
    """Append a learnable sinusoidal coordinate channel to a 4-D input.

    The extra channel is ``sin(alpha * u + beta * v + phase)``, where ``u``
    and ``v`` run linearly from 0 to 1 along the ``x_dim`` and ``y_dim``
    axes of the input.
    """

    def __init__(self, alpha=0.5, beta=None, phase_shift=0.):
        super(AddSine, self).__init__(False)
        if beta is None:
            beta = alpha
        self.alpha = Parameter(torch.FloatTensor([alpha]))
        self.beta = Parameter(torch.FloatTensor([beta]))
        self.phase = Parameter(torch.FloatTensor([phase_shift]))

    def generate_xy(self, input_tensor):
        """Return the sine channel for ``input_tensor``.

        Args:
            input_tensor: shape (batch, channel, x_dim, y_dim); only its size
                is read.

        Returns:
            Tensor of shape (batch, 1, x_dim, y_dim) on the parameters' device.
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()
        device = self.phase.device

        # Build the ramps directly on the parameter device instead of
        # allocating them on the CPU and copying them over on every call.
        x_ramp = torch.linspace(0., 1., x_dim, device=device) * self.alpha  # (x_dim,)
        y_ramp = torch.linspace(0., 1., y_dim, device=device) * self.beta  # (y_dim,)

        # Broadcast to (x_dim, y_dim), then to (batch, 1, x_dim, y_dim); adding
        # the phase materializes a dense tensor before taking the sine.
        grid = (x_ramp[:, None] + y_ramp[None, :]).expand(batch_size, 1, x_dim, y_dim)
        return torch.sin(grid + self.phase)

    def __repr__(self):
        return (f"{self.__class__.__name__} (alpha={self.alpha.item()}"
                f" beta={self.beta.item()} phase={self.phase.item()})")

    def forward(self, input_tensor):
        """Concatenate the sine channel to ``input_tensor`` along dim 1.

        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)

        Returns:
            Tensor of shape (batch, channel + 1, x_dim, y_dim).
        """
        sine_channel = self.generate_xy(input_tensor)
        return torch.cat(
            [input_tensor, sine_channel.type_as(input_tensor)], dim=1)
class ActQuantBuffers(ActQuant):  # This class exists to allow multi-gpu runs
    """``ActQuant`` variant that keeps its running statistics in buffers.

    ``running_mean``/``running_std`` are registered buffers, so they move with
    the module on ``.to(...)``; ``clamp_val`` is a trainable parameter.
    """

    def __init__(self,
                 quatize_during_training=False,
                 noise_during_training=False,
                 quant=False,
                 bitwidth=32):
        super(ActQuantBuffers,
              self).__init__(quatize_during_training=quatize_during_training,
                             noise_during_training=noise_during_training,
                             quant=quant,
                             bitwidth=bitwidth)

        self.register_buffer('running_mean', torch.zeros(1))
        self.register_buffer('running_std', torch.zeros(1))
        self.clamp_val = Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, input):
        """Apply ReLU, clamping, or clamping + quantization to ``input``.

        - ``pre_training_statistics``: update the running mean/std with an
          exponential moving average (weight ``momentum`` on the old value),
          then apply ReLU.
        - ``quant`` and (eval mode or ``quatize_during_training``): clamp with
          ``clamp_val`` and quantize to ``bitwidth`` bits via ``act_quant``.
        - otherwise: clamp (if ``quant``) or ReLU, and optionally hand the
          result to ``plot_statistic`` once.
        """
        if self.pre_training_statistics:
            # ``buffer.to(input.device)`` returns a copy when devices differ,
            # which used to drop the update silently. Move the statistic to
            # the buffer instead and update in place outside of autograd.
            with torch.no_grad():
                mean = input.mean().to(self.running_mean.device)
                std = input.std().to(self.running_std.device)
                self.running_mean.mul_(self.momentum).add_(
                    mean * (1 - self.momentum))
                self.running_std.mul_(self.momentum).add_(
                    std * (1 - self.momentum))

            x = F.relu(input)

        elif self.quant and (not self.training or self.quatize_during_training):
            c_x = self.act_clamp(input, self.clamp_val)
            x = act_quant(c_x, self.clamp_val, self.bitwidth)

        else:
            if self.quant:
                x = self.act_clamp(input, self.clamp_val)
            else:
                x = F.relu(input)

            if not self.saved_stats and self.gather_stats:
                self.plot_statistic(x)
                self.saved_stats = True

        return x

    def print_clamp(self):
        """Print this layer's index and current clamp value."""
        print('Activation layer {}  has clamp value {}'.format(
            self.layer_num, self.clamp_val.item()))
Beispiel #5
0
class RecallAtPrecision(nn.Module):
    """Two-layer binary classifier trained for recall at a fixed precision.

    The loss is a Lagrangian relaxation of ``precision >= alpha``. Hinge
    losses on the positives (``Lp``) and negatives (``Ln``) give a lower
    bound on true positives (``|Y+| - Lp``) and an upper bound on false
    positives (``Ln``). The multiplier ``lam`` weights the constraint term
    ``alpha / (1 - alpha) * Ln + Lp - |Y+|``, which must be ``<= 0``.
    """

    def __init__(self, input_size, hidden_size, alpha, dropout=0.0):
        super(RecallAtPrecision, self).__init__()

        # Kept for code that reads ``self.device``. ``lam`` is no longer pinned
        # to CUDA: like every other parameter it is created on the CPU and
        # follows ``module.to(...)`` / ``.cuda()``.
        self.device = torch.device('cuda')
        self.h1_weights = nn.Linear(input_size, hidden_size)
        self.h2_weights = nn.Linear(hidden_size, 2)
        self.dropout = dropout

        self.alpha = alpha
        # precision >= alpha  <=>  TP >= alpha / (1 - alpha) * FP
        self.alpha_term = alpha / (1 - alpha)
        self.lam = Parameter(torch.tensor([2.0]))
        self.result_dict = {}
        log.info('Optimize recall @ fixed precision=%.2f' % self.alpha)

        weights_init(self)

    def print_result_dict(self):
        """Log the batch statistics stored by the last ``forward`` call."""
        TP = self.result_dict['true_pos']
        FP = self.result_dict['false_pos']
        NYP = self.result_dict['num_Y_pos']
        TPL = self.result_dict['tp_lower']
        FPU = self.result_dict['fp_upper']
        precision, recall = TP / (TP + FP + 1e-10), TP / (NYP + 1e-10)

        if self.training:
            log.info('lambda = %.5f' % self.lam.item())

        log.info('TP = %.1f(>=%.1f), FP = %.1f(<=%.1f), |Y+| = %.1f' %
                 (TP, TPL, FP, FPU, NYP))
        log.info('precision = %.5f, recall = %.5f' % (precision, recall))
        log.info('inequality = %.5f(<=0)' % self.result_dict['inequality'])

    def forward(self, X, target=None):
        """Classify ``X`` and, when ``target`` is given, compute the loss.

        Args:
            X: input features accepted by ``h1_weights``.
            target: optional labels in {0, 1}, one per row of ``X``.

        Returns:
            ``pred_cls`` (int32, 1 where class 1 is more probable) when
            ``target`` is None; otherwise ``(L, accu, pred_cls)``, where ``L``
            is the Lagrangian loss tensor and ``accu`` the batch accuracy. In
            training mode with a trainable ``lam`` the sign of ``L`` is
            flipped.

        Side effects:
            When ``target`` is given, refreshes ``self.result_dict`` with
            Python-float batch statistics.
        """
        h1 = torch.sigmoid(self.h1_weights(X))
        h1 = F.dropout(h1, p=self.dropout, training=self.training)

        logits = F.softmax(self.h2_weights(h1), dim=1)
        is_pos = logits[:, 1] > logits[:, 0]
        pred_cls = is_pos.to(torch.int32)

        if target is None:
            return pred_cls

        target = target.to(torch.float32)
        y = 2 * target - 1  # labels mapped to {-1, 1}

        # NOTE(review): (p1 - p0) * 2 - 1 equals 4 * p1 - 3, which lies in
        # [-3, 1] and only crosses 0 at p1 = 0.75, unlike pred_cls (p1 > 0.5).
        # Kept as-is — confirm the intended margin.
        pred = (logits[:, 1] - logits[:, 0]) * 2 - 1
        hinge_loss = torch.clamp(1 - y * pred, min=0.0)
        Lp = (hinge_loss * target).sum()
        Ln = (hinge_loss * (1 - target)).sum()
        # Precision constraint (must be <= 0 at a feasible point).
        constraint = self.alpha_term * Ln + Lp - target.sum()
        L = Lp + self.lam * constraint

        pred_cls_float = is_pos.to(torch.float32)
        num_Y_pos = target.sum().item()  # positives in the labels, not predictions
        self.result_dict.update(
            true_pos=(target * pred_cls_float).sum().item(),
            false_pos=((1 - target) * pred_cls_float).sum().item(),
            num_Y_pos=num_Y_pos,
            tp_lower=(num_Y_pos - Lp).item(),
            fp_upper=Ln.item(),
            # Same constraint as in the loss: not scaled by lam. Stored as a
            # float so result_dict does not keep the autograd graph alive.
            inequality=constraint.item(),
        )

        correct = pred_cls.eq(target.to(torch.int32).view_as(pred_cls))
        accu = correct.sum().item() / float(correct.size(0))

        if self.lam.requires_grad and self.training:
            L = -L

        return L, accu, pred_cls
class CompProbModel(torch.nn.Module):
    """Pass completion-probability model evaluated on a discretized field.

    The field is a grid of 55 (y) x 120 (x) cells (F = 6600 locations,
    flattened row-major so that index = y_idx * 120 + x_idx, matching the
    index arithmetic in ``get_ppc_off`` and ``forward``). Candidate times of
    flight are ``T = 0.1, 0.2, ..., 4.0`` (40 values). Player arrival times
    come from a reaction-time / acceleration / max-speed kinematic model and
    are turned into interception probabilities with a logistic function.
    ``tuning`` selects both which parameters are trainable and what
    ``forward`` returns.

    NOTE(review): the meaning of ``frame`` columns is inferred from how they
    are indexed below — 1/2 position, 3/4 velocity, 5/6 acceleration, 7 team
    flag, 8:10 ball start (row 0), -4 player mask, -3/-2 ball end (row 0),
    -1 time of flight (row 0). Confirm against the data pipeline.
    """
    def __init__(self,
                 a_max=7.25,
                 s_max=9.25,
                 avg_ball_speed=20.0,
                 tti_sigma=0.5,
                 tti_lambda_off=1.0,
                 tti_lambda_def=1.0,
                 ppc_alpha=1.0,
                 tuning=None,
                 use_ppc=False):
        """Create the model parameters, the field/time grids and load T|L.

        Args:
            a_max: initial acceleration limit used in the arrival-time model.
            s_max: initial speed limit used in the arrival-time model.
            avg_ball_speed: stored as a non-trainable parameter; not read by
                the methods in this class.
            tti_sigma: spread of the logistic interception-time model.
            tti_lambda_off: weight applied to players with team flag 1 in
                ``get_ppc_off``.
            tti_lambda_def: weight applied to players with team flag 0 in
                ``get_ppc_off``.
            ppc_alpha: exponent on the completion term in the
                ``TuningParam.alpha`` branch of ``forward``.
            tuning: a ``TuningParam`` member (or None) choosing which
                parameters get ``requires_grad=True`` and what ``forward``
                returns.
            use_ppc: whether ``forward`` uses ``get_ppc_off``.

        Reads ``in/T_given_L.pkl`` relative to the current working directory.
        """
        super().__init__()
        # define self.tuning
        self.tuning = tuning

        # define parameters and whether or not to optimize
        # NOTE(review): ``.float()`` returns the same Parameter only when it is
        # already float32 (the default dtype for ``torch.tensor`` of Python
        # floats); for any other dtype it would return a plain tensor that is
        # not registered as a parameter.
        self.tti_sigma = Parameter(
            torch.tensor([tti_sigma]),
            requires_grad=(self.tuning == TuningParam.sigma)).float()
        self.tti_lambda_off = Parameter(
            torch.tensor([tti_lambda_off]),
            requires_grad=(self.tuning == TuningParam.lamb)).float()
        self.tti_lambda_def = Parameter(
            torch.tensor([tti_lambda_def]),
            requires_grad=(self.tuning == TuningParam.lamb)).float()
        self.ppc_alpha = Parameter(
            torch.tensor([ppc_alpha]),
            requires_grad=(self.tuning == TuningParam.alpha)).float()
        self.a_max = Parameter(
            torch.tensor([a_max]),
            requires_grad=(self.tuning == TuningParam.av)).float()
        self.s_max = Parameter(
            torch.tensor([s_max]),
            requires_grad=(self.tuning == TuningParam.av)).float()
        # reaction time; NOTE(review): trainable (Parameter default) for every
        # ``tuning`` value — confirm that is intended.
        self.reax_t = Parameter(torch.tensor([0.2])).float()
        self.avg_ball_speed = Parameter(torch.tensor([avg_ball_speed]),
                                        requires_grad=False).float()
        self.g = Parameter(torch.tensor([10.72468]),
                           requires_grad=False)  #y/s/s
        self.z_max = Parameter(torch.tensor([3.]), requires_grad=False)
        self.z_min = Parameter(torch.tensor([0.]), requires_grad=False)
        self.use_ppc = use_ppc
        # Despite the name this is created on the CPU; being a Parameter it
        # moves with the module.
        self.zero_cuda = Parameter(torch.tensor([0.0], dtype=torch.float32),
                                   requires_grad=False)

        # define field grid
        # x = 0.5, 1.5, ..., 119.5; y = -0.5, 0.5, ..., 53.5 with the first
        # y value replaced by -0.2.
        self.x = torch.linspace(0.5, 119.5, 120).float()
        self.y = torch.linspace(-0.5, 53.5, 55).float()
        self.y[0] = -0.2
        # default 'ij' indexing: yy and xx both have shape (55, 120)
        self.yy, self.xx = torch.meshgrid(self.y, self.x)
        self.field_locs = Parameter(torch.flatten(torch.stack(
            (self.xx, self.yy), dim=-1),
                                                  end_dim=-2),
                                    requires_grad=False)  # (F, 2)
        self.T = Parameter(torch.linspace(0.1, 4, 40),
                           requires_grad=False)  # (T,)

        # for hist trans prob
        # grid-cell offsets around the ball start bounding the L|t box;
        # hist_t_min / hist_t_max are not read by the methods in this class.
        self.hist_x_min, self.hist_x_max = -9, 70
        self.hist_y_min, self.hist_y_max = -39, 40
        self.hist_t_min, self.hist_t_max = 10, 63
        # NOTE: unpickling can execute arbitrary code — only load trusted files.
        self.T_given_Ls_df = pd.read_pickle('in/T_given_L.pkl')

    def get_hist_trans_prob(self, frame):
        """Return the (B, F, T) historical landing-location/time distribution.

        ``P(L|t)`` is uniform over a box of grid cells around the rounded ball
        start, clipped to the field. ``P(T|L)`` is looked up in
        ``T_given_Ls_df`` by rounded throw distance for distances in (1, 60]
        and is zero elsewhere. The product is normalized over (F, T) for each
        batch element.
        """
        B = len(frame)
        """ P(L|t) """
        ball_start = frame[:, 0, 8:10]  # (B, 2)
        ball_start_ind = torch.round(ball_start).long()
        reach_vecs = self.field_locs.unsqueeze(0) - ball_start.unsqueeze(
            1)  # (B, F, 2)
        # mask for zeroing out parts of the field that are too far to be thrown to per the L_given_t model
        L_t_mask = torch.zeros(B, *self.xx.shape)  # (B, Y, X)
        # NOTE(review): b_zeros / b_ones are never used.
        b_zeros = torch.zeros(ball_start_ind.shape[0])
        b_ones = torch.ones(ball_start_ind.shape[0])
        # NOTE(review): slice ends are exclusive, so capping them at len - 1
        # never marks the last row / column of the field — possibly an
        # off-by-one; confirm.
        for bb in range(B):
            L_t_mask[bb, max(0, ball_start_ind[bb,1]+self.hist_y_min):\
                        min(len(self.y)-1, ball_start_ind[bb,1]+self.hist_y_max),\
                     max(0, ball_start_ind[bb,0]+self.hist_x_min):\
                        min(len(self.x)-1, ball_start_ind[bb,0]+self.hist_x_max)] = 1.
        L_t_mask = L_t_mask.flatten(1)  # (B, F)
        L_given_t = L_t_mask  #changed L_given_t to uniform after discussion
        # renormalize since part of L|t may have been off field
        L_given_t /= L_given_t.sum(1, keepdim=True)  # (B, F)
        """ P(T|L) """
        # we find T|L for sufficiently close spots (1 < L <= 60)
        reach_dist_int = torch.round(torch.linalg.norm(
            reach_vecs, dim=-1)).long()  # (B, F)
        reach_dist_in_bounds_idx = (reach_dist_int > 1) & (reach_dist_int <=
                                                           60)
        reach_dist_in_bounds = reach_dist_int[
            reach_dist_in_bounds_idx]  # 1d tensor
        # NOTE(review): the reshape below assumes T_given_Ls_df holds exactly
        # len(self.T) rows per 'pass_dist', ordered by time — confirm.
        T_given_L_subset = torch.from_numpy(self.T_given_Ls_df.set_index('pass_dist').loc[reach_dist_in_bounds, 'p'].to_numpy()).float()\
            .reshape(-1, len(self.T))  # (BF~, T) ; BF~ is subset of B*F that is in [1, 60] yds from ball
        T_given_L = torch.zeros(B * len(self.field_locs),
                                len(self.T))  # (B, F, T)
        # fill in the subset of values computed above
        T_given_L[reach_dist_in_bounds_idx.flatten()] = T_given_L_subset
        T_given_L = T_given_L.reshape(B, len(self.field_locs), -1)  # (B, F, T)

        L_T_given_t = L_given_t[..., None] * T_given_L  # (B, F, T)
        L_T_given_t /= L_T_given_t.sum(
            (1, 2), keepdim=True
        )  # normalize all passes after some have been chopped off
        return L_T_given_t  # (B, F, T)

    def get_ppc_off(self, frame, p_int):
        """Accumulate interception probability along each candidate ball path.

        For every (batch, field location, time of flight T) the ball moves in a
        straight line from the ball start, sampled at the times in ``self.T``.
        Its height is ``2.0 + vz_0 * t - 0.5 * g * t**2`` with
        ``vz_0 = T * g / 2``, so it is back at height 2.0 at ``t = T``. Only
        samples with height strictly between ``z_min`` and ``z_max``
        contribute. Per-player probabilities at each sample are divided by
        their sum when that sum exceeds 1, then multiplied by the probability
        that no interception happened at earlier samples.

        Args:
            frame: batch of frames (see the class docstring for the inferred
                column layout).
            p_int: per-player interception probabilities indexed as
                ``p_int[b, field_idx, time_idx, j]``, i.e. (B, F, T, J) as
                produced by ``forward``.

        Returns:
            ``(ppc_off, ppc_def, ppc_ind)``: totals over players with team flag
            1 and 0, each (B, F, T), and the per-player values (B, F, T, J).
            All are weighted by ``tti_lambda_off`` / ``tti_lambda_def``
            according to team.

        Raises:
            AssertionError: if ``use_ppc`` is False.
        """
        assert self.use_ppc, 'Call made to get_ppc_off while use_ppc setting is False'
        B = frame.shape[0]
        J = p_int.shape[-1]
        ball_start = frame[:, 0, 8:10]  # (B, 2)
        player_teams = frame[:, :, 7]  # (B, J)
        reach_vecs = self.field_locs.unsqueeze(0) - ball_start.unsqueeze(
            1)  # B, F, 2
        # trajectory integration
        # NOTE(review): several shape comments below omit the leading batch
        # dim, e.g. vx / vy are (B, F, T) and the traj_locs_*_idx are
        # (B, F, T, T).
        dx = reach_vecs[:, :, 0]  #B, F
        dy = reach_vecs[:, :, 1]  #B, F
        vx = dx[:, :, None] / self.T[None, None, :]  #F, T
        vy = dy[:, :, None] / self.T[None, None, :]  #F, T
        vz_0 = (self.T * self.g) / 2  #T

        # note that idx (i, j, k) into below arrays is invalid when j < k
        traj_ts = self.T.repeat(len(self.field_locs), len(self.T),
                                1)  #(F, T, T)
        traj_locs_x_idx = torch.round(
            torch.clip((ball_start[:, 0, None, None, None] +
                        vx.unsqueeze(-1) * self.T), 0,
                       len(self.x) - 1)).int()  # B, F, T, T
        traj_locs_y_idx = torch.round(
            torch.clip((ball_start[:, 1, None, None, None] +
                        vy.unsqueeze(-1) * self.T), 0,
                       len(self.y) - 1)).int()  # B, F, T, T
        traj_locs_z = 2.0 + vz_0.view(
            1, -1, 1) * traj_ts - 0.5 * self.g * traj_ts * traj_ts  #F, T, T
        # 1 where the ball height is inside (z_min, z_max), else 0
        lambda_z = torch.where(
            (traj_locs_z < self.z_max) & (traj_locs_z > self.z_min), 1,
            0)  #F, T, T
        # flattened field index (row-major over (y, x), as in field_locs)
        path_idxs = (traj_locs_y_idx * self.x.shape[0] +
                     traj_locs_x_idx).long().reshape(B, -1)  # (B, F*T*T)
        # 10*traj_ts - 1 converts the times into indices - hacky
        # NOTE(review): relies on T being exactly 0.1, 0.2, ..., 4.0; .long()
        # truncates, so a float value just below an integer would shift the
        # index down by one (torch.round would be safer) — verify.
        traj_t_idxs = (10 * traj_ts - 1).long().repeat(B, 1, 1, 1).reshape(
            B, -1)  # (B, F*T*T)
        p_int_traj = torch.stack([p_int[bb, path_idxs[bb], traj_t_idxs[bb], :] for bb in range(B)])\
                        .reshape(*traj_locs_x_idx.shape, -1) * lambda_z.unsqueeze(-1)  # B, F, T, T, J
        p_int_traj_sum = p_int_traj.sum(dim=-1, keepdim=True)  # B, F, T, T, J
        norm_factor = torch.maximum(torch.ones_like(p_int_traj_sum),
                                    p_int_traj_sum)  # B, F, T, T
        p_int_traj_norm = p_int_traj / norm_factor  # B, F, T, T, J

        # independent int probs at each point on trajectory
        all_p_int_traj = torch.sum(p_int_traj_norm, dim=-1)  # B, F, T, T
        # off_p_int_traj = torch.sum((player_teams == 1)[:,None,None,None] * p_int_traj_norm, dim=-1)  # B, F, T, T
        # def_p_int_traj = torch.sum((player_teams == 0)[:,None,None,None] * p_int_traj_norm, dim=-1)  # B, F, T, T
        ind_p_int_traj = p_int_traj_norm  #use for analyzing specific players; # B, F, T, T, J

        # calc decaying residual probs after you take away p_int on earlier times in the traj
        compl_all_p_int_traj = 1 - all_p_int_traj  # B, F, T, T
        remaining_compl_p_int_traj = torch.cumprod(compl_all_p_int_traj,
                                                   dim=-1)  # B, F, T, T
        # maximum 0 because if it goes negative the pass has been caught by then and theres no residual probability
        # shift by one sample so each point sees only the survival probability
        # of the samples strictly before it (first sample: 1)
        shift_compl_cumsum = torch.roll(remaining_compl_p_int_traj, 1,
                                        dims=-1)  # B, F, T, T
        shift_compl_cumsum[:, :, :, 0] = 1

        # multiply residual prob by p_int at that location and lambda
        lambda_all = self.tti_lambda_off * player_teams + self.tti_lambda_def * (
            1 - player_teams)  # B, J
        # off_completion_prob_dt = shift_compl_cumsum * off_p_int_traj  # B, F, T, T
        # def_completion_prob_dt = shift_compl_cumsum * def_p_int_traj  # B, F, T, T
        # all_completion_prob_dt = off_completion_prob_dt + def_completion_prob_dt  # B, F, T, T
        ind_completion_prob_dt = shift_compl_cumsum.unsqueeze(
            -1) * ind_p_int_traj  # F, T, T, J

        # now accumulate values over total traj for each team and take at T=t
        # all_completion_prob = torch.cumsum(all_completion_prob_dt, dim=-1)  # B, F, T, T
        # off_completion_prob = torch.cumsum(off_completion_prob_dt, dim=-1)  # B, F, T, T
        # def_completion_prob = torch.cumsum(def_completion_prob_dt, dim=-1)  # B, F, T, T
        ind_completion_prob = torch.cumsum(ind_completion_prob_dt,
                                           dim=-2)  # B, F, T, T, J

        # this einsum takes the diagonal values over the last two axes where T = t
        # this takes care of the t > T issue.
        # ppc_all = torch.einsum('...ii->...i', all_completion_prob)  # B, F, T
        # ppc_off = torch.einsum('...ii->...i', off_completion_prob)  # B, F, T
        # ppc_def = torch.einsum('...ii->...i', def_completion_prob)  # B, F, T
        ppc_ind = torch.einsum('...iij->...ij',
                               ind_completion_prob)  # B, F, T, J
        ppc_ind *= lambda_all[:, None, None, :]
        # no_p_int_pass = 1-ppc_all  # B, F, T

        ppc_off = torch.sum(ppc_ind * player_teams[:, None, None, :],
                            dim=-1)  # B, F, T
        ppc_def = torch.sum(ppc_ind * (1 - player_teams)[:, None, None, :],
                            dim=-1)  # B, F, T

        # assert torch.allclose(all_p_int_pass, off_p_int_pass + def_p_int_pass, atol=0.01)
        # assert torch.allclose(all_p_int_pass, ind_p_int_pass.sum(-1), atol=0.01)
        # return off_p_int_pass, def_p_int_pass, ind_p_int_pass
        return ppc_off, ppc_def, ppc_ind

    def forward(self, frame):
        """Run the kinematic interception model and return the tuning target.

        Args:
            frame: batch of frames (see the class docstring for the inferred
                column layout).

        Returns, depending on ``self.tuning``:
            ``TuningParam.av``: player (x, y) positions projected toward the
                true ball-end cell at the true time of flight, multiplied by
                the player mask column and stacked on the last axis.
            ``TuningParam.sigma``: interception probabilities at the true time
                of flight and ball-end cell.
            ``TuningParam.alpha``: probability of the true throw under the
                normalized historical x completion transition model, (B,).
            ``TuningParam.lamb``: per-player ``ppc_ind`` at the true throw,
                (B, J).
            Any other value (including the default ``None``): there is no
            explicit return, so the result is ``None``.

        NOTE(review): several results go through ``.squeeze()``, which also
        drops the batch dimension when B == 1.
        """
        # player velocity and position after the reaction time reax_t
        v_x_r = frame[:, :, 5] * self.reax_t + frame[:, :, 3]
        v_y_r = frame[:, :, 6] * self.reax_t + frame[:, :, 4]
        v_r_mag = torch.norm(torch.stack([v_x_r, v_y_r], dim=-1), dim=-1)
        v_r_theta = torch.atan2(v_y_r, v_x_r)

        x_r = frame[:, :,
                    1] + frame[:, :,
                               3] * self.reax_t + 0.5 * frame[:, :,
                                                              5] * self.reax_t**2
        y_r = frame[:, :,
                    2] + frame[:, :,
                               4] * self.reax_t + 0.5 * frame[:, :,
                                                              6] * self.reax_t**2

        # get each player's team, location, and velocity
        # NOTE(review): shape comments from here on omit the batch dim, e.g.
        # reaction_player_locs is (B, J, 2), int_d_vec is (B, F, J, 2) and
        # int_s0 / t_tot are (B, F, J).
        player_teams = frame[:, :, 7]  # B, J
        reaction_player_locs = torch.stack([x_r, y_r], dim=-1)  # (J, 2)
        reaction_player_vels = torch.stack([v_x_r, v_y_r], dim=-1)  #(J, 2)

        # calculate each player's distance from each field location
        int_d_vec = self.field_locs.unsqueeze(1).unsqueeze(
            0) - reaction_player_locs.unsqueeze(1)  #F, J, 2
        int_d_mag = torch.norm(int_d_vec, dim=-1)  # F, J
        int_d_theta = torch.atan2(int_d_vec[..., 1], int_d_vec[..., 0])  # F, J

        # take dot product of velocity and direction
        # (speed component toward each location, clipped to [-s_max, s_max])
        int_s0 = torch.clamp(
            torch.sum(int_d_vec * reaction_player_vels.unsqueeze(1), dim=-1) /
            int_d_mag, -1 * self.s_max.item(), self.s_max.item())  #F, J

        # calculate time it takes for each player to reach each field position accounting for their current velocity and acceleration
        t_lt_smax = (self.s_max - int_s0) / self.a_max  #F, J,
        d_lt_smax = t_lt_smax * ((int_s0 + self.s_max) / 2)  #F, J,

        # if accelerating would overshoot, then t = -v0/a + sqrt(v0^2/a^2 + 2x/a) (from kinematics)
        t_lt_smax = torch.where(d_lt_smax > int_d_mag, -int_s0 / self.a_max + \
                torch.sqrt((int_s0 / self.a_max) ** 2 + 2 * int_d_mag / self.a_max), t_lt_smax) # F, J
        d_lt_smax = torch.max(torch.min(d_lt_smax, int_d_mag),
                              torch.zeros_like(d_lt_smax))  # F, J

        d_at_smax = int_d_mag - d_lt_smax  #F, J,
        t_at_smax = d_at_smax / self.s_max  #F, J,
        t_tot = self.reax_t + t_lt_smax + t_at_smax  # F, J,

        # get true pass (tof and ball_end) to tune on (subtract 1 from tof, add 1 to y for correct indexing)
        tof = torch.round(frame[:, 0, -1]).long().view(-1, 1, 1, 1).repeat(
            1, t_tot.size(1), 1, t_tot.size(-1)) - 1

        # ball ind
        # flattened field index of the ball-end cell (row-major over (y, x))
        ball_end_x = frame[:, 0, -3].int()
        ball_end_y = frame[:, 0, -2].int() + 1
        ball_field_ind = (ball_end_y * self.x.shape[0] +
                          ball_end_x).long().view(-1, 1, 1).repeat(
                              1, 1, t_tot.size(-1))

        if self.tuning == TuningParam.av:
            # collapse extra dims
            tof = self.T[tof[:, 0, 0, 0]].float()

            # select field in for all the position and velocity values calculated previously
            t_lt_smax = torch.gather(t_lt_smax, 1,
                                     ball_field_ind).squeeze()  # J,
            d_lt_smax = torch.gather(d_lt_smax, 1,
                                     ball_field_ind).squeeze()  # J,
            d_at_smax = torch.gather(d_at_smax, 1, ball_field_ind).squeeze()
            t_at_smax = torch.gather(t_at_smax, 1, ball_field_ind).squeeze()
            t_tot = torch.gather(t_tot, 1, ball_field_ind).squeeze()
            int_s0 = torch.gather(int_s0, 1, ball_field_ind).squeeze()

            int_d_theta = torch.gather(int_d_theta, 1,
                                       ball_field_ind).squeeze()
            int_d_mag = torch.gather(int_d_mag, 1, ball_field_ind).squeeze()

            # projected locations at t = tof, f = ball_field_ind
            # piecewise distance covered: 0 during the reaction time, then the
            # accelerating phase, then the constant-speed phase, capped at the
            # full distance
            d_proj = torch.where(tof.unsqueeze(-1) <= self.reax_t, self.zero_cuda,
                    torch.where(tof.unsqueeze(-1) <= (t_lt_smax + self.reax_t),
                    (int_s0 * (tof.unsqueeze(-1) - self.reax_t)) + 0.5 * self.a_max \
                            * (tof.unsqueeze(-1) - self.reax_t) ** 2,
                    torch.where(tof.unsqueeze(-1) <= (t_lt_smax + t_at_smax + self.reax_t),
                    (d_lt_smax + (d_at_smax * (tof.unsqueeze(-1) - t_lt_smax - self.reax_t))),
                    int_d_mag))) # J,

            d_proj = torch.minimum(d_proj, int_d_mag)

            x_proj = reaction_player_locs[..., 0] + d_proj * torch.cos(
                int_d_theta)  # J
            y_proj = reaction_player_locs[..., 1] + d_proj * torch.sin(
                int_d_theta)  # J

            # mask x_proj and y_proj (only want loss on closest off and def players)
            player_mask = frame[:, :, -4]
            masked_x = player_mask * x_proj
            masked_y = player_mask * y_proj

            return torch.stack([masked_x, masked_y], dim=-1)  # J, 2

        # subtract the arrival time (t_tot) from time of flight of ball
        int_dT = self.T.view(1, 1, -1, 1) - t_tot.unsqueeze(2)  #F, T, J

        # calculate interception probability for each player, field loc, time of flight (logistic function)
        # (3.14 / 1.732 approximates pi / sqrt(3))
        p_int = torch.sigmoid(
            (3.14 / (1.732 * self.tti_sigma)) * int_dT)  # (B, F, T, J)

        if self.tuning == TuningParam.sigma:
            p_int = torch.gather(p_int, 2, tof).squeeze()
            p_int = torch.gather(p_int, 1, ball_field_ind).squeeze()
            return p_int

        elif self.tuning == TuningParam.alpha:
            h_trans_prob = self.get_hist_trans_prob(frame)  # (B, F, T)
            if self.use_ppc:
                ppc_off, *_ = self.get_ppc_off(frame, p_int)
                trans_prob = h_trans_prob * torch.pow(
                    ppc_off, self.ppc_alpha)  # (B, F, T)
            else:
                # p_int summed over all offensive players
                p_int_off = torch.sum(p_int * (player_teams == 1),
                                      dim=-1)  # (B, F, T)
                trans_prob = h_trans_prob * torch.pow(p_int_off,
                                                      self.ppc_alpha)  # (B,)
            trans_prob /= trans_prob.sum(dim=(1, 2), keepdim=True)  # (B, F, T)
            # index into true pass. [...,0] necessary on indices because no J dimension
            trans_prob_throw = torch.gather(trans_prob, 2, tof[...,
                                                               0]).squeeze()
            trans_prob_throw = torch.gather(
                trans_prob_throw, 1, ball_field_ind[..., 0]).squeeze()  # (B,)
            return trans_prob_throw

        elif self.tuning == TuningParam.lamb:
            assert self.use_ppc, 'need to use ppc to tune lambda'
            *_, ppc_ind = self.get_ppc_off(frame,
                                           p_int)  # ppc_ind: (B, F, T, J)
            ppc_ind_throw = torch.gather(ppc_ind, 2, tof).squeeze()  # B, F, J
            ppc_ind_throw = torch.gather(ppc_ind_throw, 1,
                                         ball_field_ind).squeeze()  # B, J
            return ppc_ind_throw
Beispiel #7
0
class StochasticCNN(nn.Module):
    """CNN with stochastic conv/linear layers trained with Catoni's PAC-Bayes objective.

    A single learnable prior log standard deviation (``prior_log_std``) is shared by every
    stochastic layer; the prior means are the layers' ``weight_prior`` / ``bias_prior`` tensors.
    """

    def __init__(
            self,
            num_training_samples: int,
            rnd: np.random.RandomState,
            num_last_units=100,
            trained_deterministic_model=None,
            prior_log_std=-3.,
            catoni_lambda=1.,
            delta=0.05,
            b=100,
            c=0.1,
            init_weights=True
    ):
        """
        :param num_training_samples: the number of training data.
        :param rnd: `np.random.RandomState` instance for reproducibility.
        :param num_last_units: The number of units of the last linear layer.
        :param trained_deterministic_model: optional network whose weights initialise both the posterior
            means and the prior means. Its modules are paired positionally with this network's modules.
        :param prior_log_std: initial value of prior's log std value.
        :param catoni_lambda: Catoni's Lambda Parameter. It must be positive.
        :param delta: Confidence parameter.
        :param b: Prior variance's precision parameter.
        :param c: Prior variance's upper bound.
        :param init_weights: If true, weights are initialized (truncated Gaussian, or copied from
            `trained_deterministic_model`). It must be true to calculate the KL divergence, because
            it is also what counts `num_weights`.
        """
        upper_log_std = 0.5 * np.log(np.float32(c))
        assert upper_log_std > prior_log_std, 'c is the upper bound of the prior\'s variance.'

        super(StochasticCNN, self).__init__()

        self.features = self.create_features()
        self.f_last = StochasticLinear(1600, num_last_units)

        self.num_weights = 0
        self.prior_log_std = Parameter(torch.Tensor([prior_log_std]))
        self.previous_prior_log_std = torch.Tensor([prior_log_std])  # last feasible value, used by `constraint`

        if init_weights:
            if trained_deterministic_model is not None:
                self._initialize_weights_from_deterministic_model(trained_deterministic_model)
            else:
                self._initialize_weights(rnd)

        self.num_training_samples = Parameter(torch.Tensor([num_training_samples]), requires_grad=False)

        self.delta = Parameter(torch.Tensor([delta]), requires_grad=False)
        self.b = Parameter(torch.Tensor([b]), requires_grad=False)
        self.c = Parameter(torch.Tensor([c]), requires_grad=False)
        self.catoni_lambda = Parameter(torch.Tensor([catoni_lambda]), requires_grad=False)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.f_last(x)
        return x

    def sample_noise(self) -> None:
        """
        Sample weights and biases from the posterior.

        :return: None
        """
        for m in self.modules():
            if isinstance(m, (StochasticLinear, StochasticConv2D)):
                m.sample_noise()

    @staticmethod
    def create_features() -> torch.nn.modules.container.Sequential:
        return nn.Sequential(
            StochasticConv2D(3, 64, kernel_size=5, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            StochasticConv2D(64, 64, kernel_size=5, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

    def _initialize_weights_from_deterministic_model(self, deterministic_model) -> None:
        # Modules are matched by position in `modules()` order, so both networks must share a layout.
        for m, premodel_module in zip(self.modules(), deterministic_model.modules()):
            if isinstance(m, (StochasticLinear, StochasticConv2D)):
                m.weight.data.copy_(premodel_module.weight.data)
                m.weight_prior.data.copy_(premodel_module.weight.data)

                m.bias.data.copy_(premodel_module.bias.data)
                m.bias_prior.data.copy_(premodel_module.bias.data)

                nn.init.constant_(m.weight_log_std, self.prior_log_std.item())
                nn.init.constant_(m.bias_log_std, self.prior_log_std.item())

                self.num_weights += np.prod(m.weight.size())
                self.num_weights += np.prod(m.bias.size())

    def _initialize_weights(self, rnd: np.random.RandomState) -> None:
        conv_upper = 2 * 5e-2
        for m in self.modules():
            if isinstance(m, StochasticLinear):
                # Draw exactly as many values as the layer holds, so `num_last_units` other than 100 works.
                m.weight_prior.data = torch.from_numpy(
                    truncnorm.rvs(
                        a=-1. / 800., b=1. / 800., size=tuple(m.weight.size()), random_state=rnd
                    ).astype(np.float32)
                )

            elif isinstance(m, StochasticConv2D):
                m.weight_prior.data = torch.from_numpy(
                    truncnorm.rvs(
                        a=-conv_upper, b=conv_upper, size=tuple(m.weight.size()), random_state=rnd
                    ).astype(np.float32)
                )

            if isinstance(m, (StochasticLinear, StochasticConv2D)):
                # Posterior means start at the prior means.
                m.weight.data.copy_(m.weight_prior.data)

                nn.init.constant_(m.bias_prior, 0.)
                nn.init.constant_(m.bias, 0.)

                nn.init.constant_(m.weight_log_std, self.prior_log_std.item())
                nn.init.constant_(m.bias_log_std, self.prior_log_std.item())

                self.num_weights += np.prod(m.weight.size())
                self.num_weights += np.prod(m.bias.size())

    def kl(self, prior_log_std) -> torch.FloatTensor:
        """
        Calculate the KL divergence between the diagonal Gaussian posterior and the prior.

        The prior's means are each layer's `weight_prior` / `bias_prior`, and its log std is
        `prior_log_std` for every parameter.

        :param prior_log_std: log standard deviation of the prior (tensor).

        :return: KL divergence value.
        """
        num_weights = self.num_weights
        assert num_weights > 0
        mean_norm_list = []
        log_std_sum_list = []
        variance_l1_norm_list = []
        prior_variance = torch.exp(2. * prior_log_std)
        prior_log_variance = 2. * prior_log_std

        for m in self.modules():
            if isinstance(m, (StochasticLinear, StochasticConv2D)):
                mean_norm_list.append(torch.sum((m.weight - m.weight_prior) ** 2))
                mean_norm_list.append(torch.sum((m.bias - m.bias_prior) ** 2))

                log_std_sum_list.append(torch.sum(m.weight_log_std))
                log_std_sum_list.append(torch.sum(m.bias_log_std))

                # Subtracting the prior log-variance inside the exponent is more accurate than
                # computing `torch.exp(2. * log_std)` first and dividing by `prior_variance`.
                variance_l1_norm_list.append(torch.sum(torch.exp(2. * m.weight_log_std - prior_log_variance)))
                variance_l1_norm_list.append(torch.sum(torch.exp(2. * m.bias_log_std - prior_log_variance)))

        norm_weights = torch.sum(torch.stack(mean_norm_list))
        mean_part = norm_weights / prior_variance

        norm_log_std = 2. * torch.sum(torch.stack(log_std_sum_list))
        sum_variance_l1_norm = torch.sum(torch.stack(variance_l1_norm_list))

        std_part = sum_variance_l1_norm - norm_log_std + 2. * num_weights * prior_log_std

        kl = 0.5 * (mean_part + std_part - num_weights)
        return kl

    def union_bound(self, prior_log_std) -> torch.FloatTensor:
        """
        Calculate union bound value related to prior.

        Equals `log(pi^2 / (6 delta)) + 2 log(b (log c - 2 prior_log_std))`.

        :param prior_log_std: Float parameter contains prior's log std.

        :return: FloatTensor
        """
        return 2. * torch.log(self.b) + 2. * torch.log(torch.log(self.c) - 2. * prior_log_std) \
               + torch.log(np.pi ** 2 / (6. * self.delta))

    def pac_bayes_objective(self, contrastive_loss: torch.FloatTensor) -> tuple:
        """
        Catoni's PAC-Bayes bound with union bound

        :param contrastive_loss: empirical risk: FloatTensor
        :return: PAC-Bayes upper bound, KL, and complexity term; FloatTensors
        """
        kl = self.kl(prior_log_std=self.prior_log_std)

        # union bound term
        union_bound = self.union_bound(prior_log_std=self.prior_log_std)

        # KL term easily becomes large, so it is divided by Catoni's lambda
        objective = contrastive_loss + (kl + union_bound) / self.catoni_lambda

        return objective, kl, union_bound

    def compute_complexity_terms_with_discretized_prior_variance(self) -> tuple:
        r"""
        Compute kl divergence and union bound terms by using discretized_prior_variance.
        Note `log (2 \sqrt{m})` is added to the union bound term in `contrastive.eval.common.pb_parameter_selection`.

        :return: tuple of kl divergence and union bound without sqrt{m}.
        """

        # discretize prior's variance parameter
        # https://github.com/gkdziugaite/pacbayes-opt/blob/master/snn/core/network.py#L398
        discretized_j = (self.b * (torch.log(self.c) - 2. * self.prior_log_std))

        discretized_j_up = torch.ceil(discretized_j)
        discretized_j_down = torch.floor(discretized_j)

        constant_in_log_delta = torch.log(np.pi ** 2 / (6 * self.delta))
        union_up = (constant_in_log_delta + 2 * torch.log(discretized_j_up)).item()
        union_down = (constant_in_log_delta + 2 * torch.log(discretized_j_down)).item()

        prior_log_std_up = (torch.log(self.c) - discretized_j_up / self.b) / 2.
        prior_log_std_down = (torch.log(self.c) - discretized_j_down / self.b) / 2.

        kl_up = self.kl(prior_log_std_up).item()
        kl_down = self.kl(prior_log_std_down).item()

        up_complexity = kl_up + union_up
        down_complexity = kl_down + union_down

        # j_down == 0 gives log(0) = -inf in the union term; fall back to the rounded-up choice then.
        if up_complexity < down_complexity or np.isinf(down_complexity):
            return kl_up, union_up
        else:
            return kl_down, union_down

    def deterministic(self) -> None:
        """
        Set mean values to weights for feed forwarding.

        :return: None
        """
        for m in self.modules():
            if isinstance(m, (StochasticLinear, StochasticConv2D)):
                m.realised_weight = m.weight.detach()
                m.realised_bias = m.bias.detach()

    def constraint(self) -> None:
        """
        Constraint for prior's variance.

        Keeps `log c - 2 prior_log_std > 0` (prior variance below `c`). A feasible value is
        remembered; an infeasible one is reverted to the last feasible value.

        :return: None
        """
        if (torch.log(self.c) - 2. * self.prior_log_std).item() > 0.:
            self.previous_prior_log_std.data.copy_(self.prior_log_std.data)
        else:
            self.prior_log_std.data.copy_(self.previous_prior_log_std.data)
class _LearnableFakeQuantize(nn.Module):
    r""" This is an extension of the FakeQuantize module in fake_quantize.py, which
    supports more generalized lower-bit quantization and support learning of the scale
    and zero point parameters through backpropagation. For literature references,
    please see the class _LearnableFakeQuantizePerTensorOp.

    In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize
    module also includes the following attributes to support quantization parameter learning.

    * :attr: `channel_len` defines the length of the channel when initializing scale and zero point
             for the per channel case.

    * :attr: `use_grad_scaling` defines the flag for whether the gradients for scale and zero point are
              normalized by the constant, which is proportional to the square root of the number of
              elements in the tensor. The related literature justifying the use of this particular constant
              can be found here: https://openreview.net/pdf?id=rkgO66VKDS.

    * :attr: `fake_quant_enabled` defines the flag for enabling fake quantization on the output.

    * :attr: `static_enabled` defines the flag for using observer's static estimation for
             scale and zero point.

    * attr: `learning_enabled` defines the flag for enabling backpropagation for scale and zero point.
    """
    def __init__(self, observer, quant_min=0, quant_max=255, scale=1., zero_point=0., channel_len=-1,
                 use_grad_scaling=False, **observer_kwargs):
        super(_LearnableFakeQuantize, self).__init__()
        assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.use_grad_scaling = use_grad_scaling

        if channel_len == -1:
            self.scale = Parameter(torch.tensor([scale]))
            self.zero_point = Parameter(torch.tensor([zero_point]))
        else:
            assert isinstance(channel_len, int) and channel_len > 0, "Channel size must be a positive integer."
            self.scale = Parameter(torch.tensor([scale] * channel_len))
            self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))

        self.activation_post_process = observer(**observer_kwargs)
        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
               'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
               'quant_max out of bound'
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('static_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('learning_enabled', torch.tensor([0], dtype=torch.uint8))

        bitrange = torch.tensor(quant_max - quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())

    @torch.jit.export
    def enable_param_learning(self):
        r"""Enables learning of quantization parameters and
        disables static observer estimates. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=True) \
            .toggle_fake_quant(enabled=True) \
            .toggle_observer_update(enabled=False)
        return self

    @torch.jit.export
    def enable_static_estimate(self):
        r"""Enables static observer estimates and disbales learning of
        quantization parameters. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=False) \
            .toggle_fake_quant(enabled=True) \
            .toggle_observer_update(enabled=True)

    @torch.jit.export
    def enable_static_observation(self):
        r"""Enables static observer accumulating data from input but doesn't
        update the quantization parameters. Forward path returns the original X.
        """
        self.toggle_qparam_learning(enabled=False) \
            .toggle_fake_quant(enabled=False) \
            .toggle_observer_update(enabled=True)

    @torch.jit.export
    def toggle_observer_update(self, enabled=True):
        self.static_enabled[0] = int(enabled)
        return self

    @torch.jit.export
    def toggle_qparam_learning(self, enabled=True):
        self.learning_enabled[0] = int(enabled)
        self.scale.requires_grad = enabled
        self.zero_point.requires_grad = enabled
        return self

    @torch.jit.export
    def toggle_fake_quant(self, enabled=True):
        self.fake_quant_enabled[0] = int(enabled)
        return self

    @torch.jit.export
    def observe_quant_params(self):
        print('_LearnableFakeQuantize Scale: {}'.format(self.scale.detach()))
        print('_LearnableFakeQuantize Zero Point: {}'.format(self.zero_point.detach()))

    @torch.jit.export
    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        self.activation_post_process(X.detach())
        _scale, _zero_point = self.calculate_qparams()
        _scale = _scale.to(self.scale.device)
        _zero_point = _zero_point.to(self.zero_point.device)

        if self.static_enabled[0] == 1:
            self.scale.data.copy_(_scale)
            self.zero_point.data.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.learning_enabled[0] == 1:
                if self.use_grad_scaling:
                    grad_factor = 1.0 / (self.weight.numel() * self.quant_max) ** 0.5
                else:
                    grad_factor = 1.0
                if self.qscheme in (
                        torch.per_channel_symmetric, torch.per_channel_affine):
                    X = _LearnableFakeQuantizePerChannelOp.apply(
                        X, self.scale, self.zero_point, self.ch_axis,
                        self.quant_min, self.quant_max, grad_factor)
                else:
                    X = _LearnableFakeQuantizePerTensorOp.apply(
                        X, self.scale, self.zero_point,
                        self.quant_min, self.quant_max, grad_factor)
            else:
                if self.qscheme == torch.per_channel_symmetric or \
                        self.qscheme == torch.per_channel_affine:
                    X = torch.fake_quantize_per_channel_affine(
                        X, self.scale, self.zero_point, self.ch_axis,
                        self.quant_min, self.quant_max)
                else:
                    X = torch.fake_quantize_per_tensor_affine(
                        X, float(self.scale.item()), int(self.zero_point.item()),
                        self.quant_min, self.quant_max)

        return X

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We will be saving the static state of scale (instead of as a dynamic param).
        super(_LearnableFakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale.data
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                if name == 'scale':
                    self.scale.data.copy_(val)
                else:
                    setattr(self, name, val)
            elif strict:
                missing_keys.append(key)
        super(_LearnableFakeQuantize, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys,
            unexpected_keys, error_msgs)

    with_args = classmethod(_with_args)
Beispiel #9
0
class GPRegressor(nn.Module):
    def __init__(self, kernel, sn=0.1, lr=1e-1, scheduler=False, prior=True):
        """Gaussian-process regressor with a learnable noise term ``sn``.

        Args:
            kernel: callable ``kernel(A, B)`` returning the covariance between the rows of ``A``
                and ``B``. If it is an ``nn.Module``, its trainable parameters are optimised too.
            sn: initial value of the noise term added to the kernel diagonal.
            lr: Adam learning rate.
            scheduler: if True, create a ``ReduceLROnPlateau`` scheduler for the optimiser.
            prior: if True, subtract a Beta(2, 2) log-density of ``sn`` from the loss.
        """
        super(GPRegressor, self).__init__()
        self.sn = Parameter(torch.Tensor([sn]))
        self.kernel = kernel
        self.loss_func = NLMLLoss()
        # Collected after `kernel` is registered so kernel hyperparameters are included.
        opt = [p for p in self.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(opt, lr=lr)
        if prior:
            self.prior = torch.distributions.Beta(2, 2).log_prob
        else:
            self.prior = None
        if scheduler:
            # The quantity to monitor is the NLML loss, which is minimised, hence mode='min'.
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, patience=2, verbose=True, mode='min')
        else:
            self.scheduler = None

    def loss(self, X, y, jitter, val=None):
        K = self.kernel(X, X)
        inds = list(range(len(K)))
        K[[inds], [inds]] += self.sn + jitter
        L = torch.potrf(K, upper=False)
        alpha = torch.trtrs(y, L, upper=False)[0]
        alpha = torch.trtrs(alpha, L.t(), upper=True)[0]
        loss = self.loss_func(L, alpha, y)
        if self.prior is not None:
            loss -= self.prior(self.sn)

        if val is not None:
            X_val, y_val = val
            k_star = self.kernel(X, X_val)
            mu = k_star.t() @ alpha
            mse = nn.MSELoss()(mu, y_val)
            return loss, mse
        else:
            return loss

    def forward(self, X):
        """ Gaussian process regression predictions.

        Parameters:
            X: m x d points to predict

        Returns:
            mu: m x 1 predicted means
            var: m x m predicted covariance

        Follows Algorithm 2.1 from GPML.
        """
        ### Implement prior ###
        ### Scaling
        k_star = self.kernel(self.X, X)
        mu = k_star.t() @ self.alpha
        v = torch.trtrs(k_star, self.L, upper=False)[0]
        k_ss = self.kernel(X, X)
        var = k_ss - v.t() @ v
        return mu, var

    def fit(self,
            X,
            y,
            its=100,
            jitter=1e-6,
            verbose=True,
            val=None,
            chkpt=None):
        self.X = X
        self.y = y
        self._fit(X, y, its, jitter, verbose, val, chkpt)
        self._set_pars(jitter)
        return self.history

    def _fit(self, X, y, its, jitter, verbose, val, chkpt):
        self.history = []
        if val is not None and chkpt is not None:
            best_mse = 1e14
        for it in range(its):
            if val is not None:
                loss, mse = self.loss(X, y, jitter, val=val)
                mse = mse.item()
                if chkpt is not None and mse < best_mse:
                    torch.save(self.state_dict(), chkpt)
            else:
                loss = self.loss(X, y, jitter)
            # backward
            self.optimizer.zero_grad()
            loss.backward(retain_graph=False)
            # update parameters
            self.optimizer.step()
            self.sn.data.clamp_(min=1e-6)
            # if self.scheduler is not None:
            #     self.scheduler.step(loss)
            if verbose:
                update = '\rIteration %d of %d\tNLML: %.4f\tsn: %.6f\t' \
                        %(it + 1, its, loss, self.sn.cpu().detach().numpy()[0])
                print(update, end='')
                if val is not None:
                    print('val mse: %.4f' % mse, end='')
            if val is None:
                h = (loss.item(), self.sn.item())
            else:
                h = (loss.item(), self.sn.item(), mse)
                del mse
            self.history.append(h)
            del loss

    def _set_pars(self, jitter):
        Ky = self.kernel(self.X, self.X)
        inds = list(range(len(Ky)))
        Ky[[inds], [inds]] += self.sn + jitter
        self.L = torch.potrf(Ky, upper=False)
        self.alpha = torch.trtrs(self.y, self.L, upper=False)[0]
        self.alpha = torch.trtrs(self.alpha, self.L.t(), upper=True)[0]