def __init__(self):
        """Build the TSN backbone and the linear head, then load a checkpoint.

        The backbone is configured from the module-level names ``num_class``,
        ``num_segments``, ``modality``, ``arch`` and ``crop_fusion_type``.
        After construction, ``_load_pretrained_model`` is called with the
        checkpoint file name stored in ``self.model_name``.
        """
        super(TSN_BIT, self).__init__()

        backbone_options = dict(
            num_segments=num_segments,
            modality=modality,
            base_model=arch,
            consensus_type=crop_fusion_type,
            dropout=0.7,
        )
        self.tsn = TSN(num_class, **backbone_options)

        # Head: LeakyReLU plus a 101 -> 32 -> 8 stack of linear layers.
        hidden_width = 32
        self.activation = nn.LeakyReLU()
        self.fc1 = nn.Linear(101, hidden_width)
        self.fc2 = nn.Linear(hidden_width, 8)

        checkpoint_file = 'TSN_TemResGen_Kalman_2019-03-11_15-28-57.pth'
        self.model_name = checkpoint_file
        self._load_pretrained_model(checkpoint_file)
    def __init__(self):
        """Build the TSN backbone and the linear head, then load a checkpoint.

        The backbone is configured from the module-level names ``num_class``,
        ``num_segments``, ``modality``, ``arch`` and ``crop_fusion_type``.
        This variant uses a 51 -> 32 -> 21 head and the
        ``TSN_RGB_2019-01-24_12-26-11.pth`` checkpoint.
        """
        super(TSN_BIT, self).__init__()

        tsn_kwargs = {
            'num_segments': num_segments,
            'modality': modality,
            'base_model': arch,
            'consensus_type': crop_fusion_type,
            'dropout': 0.7,
        }
        self.tsn = TSN(num_class, **tsn_kwargs)

        # Head layers are created in the same order as they are applied.
        in_width, mid_width, out_width = 51, 32, 21
        self.activation = nn.LeakyReLU()
        self.fc1 = nn.Linear(in_width, mid_width)
        self.fc2 = nn.Linear(mid_width, out_width)

        weights_file = 'TSN_RGB_2019-01-24_12-26-11.pth'
        self.model_name = weights_file
        self._load_pretrained_model(weights_file)
    def __init__(self):
        """Build the TSN backbone and the linear head, then load a checkpoint.

        The backbone is configured from the module-level names ``num_class``,
        ``num_segments``, ``modality``, ``arch`` and ``crop_fusion_type``.
        This variant uses a 101 -> 32 -> 8 head and the
        ``TSN_Flow_2019-01-23_17-06-15.pth`` checkpoint.
        """
        super(TSN_BIT, self).__init__()

        self.tsn = TSN(
            num_class,
            num_segments=num_segments,
            modality=modality,
            base_model=arch,
            consensus_type=crop_fusion_type,
            dropout=0.7,
        )

        head_sizes = (101, 32, 8)
        self.activation = nn.LeakyReLU()
        self.fc1 = nn.Linear(head_sizes[0], head_sizes[1])
        self.fc2 = nn.Linear(head_sizes[1], head_sizes[2])

        # `self._load_tsn_rgb_weight()` is not called here; only the full
        # checkpoint named below is loaded.
        flow_checkpoint = 'TSN_Flow_2019-01-23_17-06-15.pth'
        self.model_name = flow_checkpoint
        self._load_pretrained_model(flow_checkpoint)
class TSN_BIT(nn.Module):
    """TSN backbone followed by a small classification head.

    The head is ``LeakyReLU`` plus two linear layers (101 -> 32 -> 8).
    ``forward`` classifies the features generated by the backbone's
    ``fea_gen_forward`` via :meth:`gen_fea_cls`.

    The backbone is configured from the module-level names ``num_class``,
    ``num_segments``, ``modality``, ``arch`` and ``crop_fusion_type``.
    """

    # Default checkpoint locations; both loaders accept an override.
    _PRETRAINED_DIR = '/home/zhufl/Temporal-Residual-Motion-Generation/videoPrediction/BIT_train_test/'
    _FLOW_WEIGHTS = '/home/zhufl/Workspace/tsn-pytorch/ucf101_flow.pth'

    def __init__(self):
        """Create backbone and head, then load ``self.model_name`` weights."""
        super(TSN_BIT, self).__init__()
        self.tsn = TSN(num_class, num_segments=num_segments, modality=modality,
                       base_model=arch,
                       consensus_type=crop_fusion_type,
                       dropout=0.7)

        self.activation = nn.LeakyReLU()
        self.fc1 = nn.Linear(101, 32)
        self.fc2 = nn.Linear(32, 8)
        self.model_name = 'TSN_TemResGen_Kalman_2019-03-11_15-28-57.pth'
        self._load_pretrained_model(self.model_name)

    def _load_pretrained_model(self, model_name, model_dir=None):
        """Load every checkpoint entry whose key exists in this module.

        Checkpoint keys unknown to the module are ignored. Module parameters
        absent from the checkpoint keep their current values and are reported
        on stdout.

        Args:
            model_name: checkpoint file name, resolved against ``model_dir``.
            model_dir: directory holding the checkpoint; defaults to
                ``_PRETRAINED_DIR``.
        """
        import os

        if model_dir is None:
            model_dir = self._PRETRAINED_DIR

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(os.path.join(model_dir, model_name))
        print("Number of parameters recovered from modeo {} is {}".format(model_name, len(checkpoint)))

        model_state = self.state_dict()
        base_dict = {k: v for k, v in checkpoint.items() if k in model_state}

        for key in model_state:
            if key not in base_dict:
                print("Missing motion branch param {}".format(key))

        # Start from the current state so load_state_dict sees every key.
        model_state.update(base_dict)
        self.load_state_dict(model_state)

    def _load_tsn_rgb_weight(self, flow_weights=None):
        """Load a TSN checkpoint into ``self.tsn`` only.

        Despite the method name, the default checkpoint is the *flow* one
        (``ucf101_flow.pth``).

        Entries are selected by their position in the checkpoint:
            * entries 1-18 are copied under their own key;
            * entries 19-414 are copied with the first seven characters of
              the key removed (presumably a ``module.`` prefix - TODO confirm);
            * entries from 415 on are ignored.
        ``new_fc.weight`` and ``new_fc.bias`` are taken from
        ``base_model.fc-action.1.*``. A key that is already present is never
        overwritten.

        Args:
            flow_weights: checkpoint path; defaults to ``_FLOW_WEIGHTS``.
        """
        if flow_weights is None:
            flow_weights = self._FLOW_WEIGHTS
        checkpoint = torch.load(flow_weights)

        base_dict = {}
        for count, (k, v) in enumerate(checkpoint.items(), 1):
            # One pre-formatted argument prints "count key" on Python 2 and 3.
            print("{} {}".format(count, k))
            if count < 19:
                base_dict.setdefault(k, v)
                # Only the first pass inserts these; later passes are no-ops.
                base_dict.setdefault('new_fc.weight', checkpoint['base_model.fc-action.1.weight'])
                base_dict.setdefault('new_fc.bias', checkpoint['base_model.fc-action.1.bias'])
            elif count < 415:
                base_dict.setdefault(k[7:], v)

        self.tsn.load_state_dict(base_dict)

    def fea_gen_forward(self, input, batch_size, warmup_t, pred_t):
        """Run the backbone's feature generation and keep the two feature sets.

        ``self.tsn.fea_gen_forward`` yields five values (gen_fea, org_fea,
        gen_fea_grad, org_fea_grad, kalman_gain); only the first two are
        passed on.

        Returns:
            tuple: ``(gen_fea, org_fea)``.
        """
        gen_fea, org_fea, _, _, _ = self.tsn.fea_gen_forward(input, batch_size, warmup_t, pred_t)
        return gen_fea, org_fea

    def gen_fea_cls(self, input):
        """Classify intermediate features with the backbone and the head.

        Wrapper around ``self.tsn.gen_fea_cls``.

        Input: intermediate feature
            [batch * num_segments/data_length, 192, 28, 28].
        The backbone call outputs [batch, 101/51], depending on which weights
        are in use.
        Output: [batch, num_class].
        """
        x = self.tsn.gen_fea_cls(input)
        x = self.activation(x)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x

    def forward(self, input, warmup_t, pred_t):
        """Predict classes from the features generated for ``input``.

        Args:
            input: tensor whose first dimension is divided by the module-level
                ``num_segments`` to obtain the batch size.
            warmup_t: forwarded unchanged to ``self.tsn.fea_gen_forward``.
            pred_t: forwarded unchanged to ``self.tsn.fea_gen_forward``.

        Returns:
            Output of :meth:`gen_fea_cls` on the generated features.
        """
        b_shape = int(input.shape[0] // num_segments)
        # The original features are not needed for classification.
        gen_fea, _ = self.fea_gen_forward(input, b_shape, warmup_t, pred_t)
        gen_fea = torch.stack(gen_fea).transpose_(0, 1).contiguous().view(-1, 192, 28, 28)

        return self.gen_fea_cls(gen_fea)
class TSN_BIT(nn.Module):
    """TSN backbone followed by a small classification head.

    ``forward`` feeds the backbone output through ``LeakyReLU`` and two
    linear layers (101 -> 32 -> 8).

    The backbone is configured from the module-level names ``num_class``,
    ``num_segments``, ``modality``, ``arch`` and ``crop_fusion_type``.
    """

    # Default checkpoint locations; both loaders accept an override.
    _PRETRAINED_DIR = '/home/zhufl/videoPrediction/BIT_train_test/'
    _FLOW_WEIGHTS = '/home/zhufl/Workspace/tsn-pytorch/ucf101_flow.pth'

    def __init__(self):
        """Create backbone and head, then load ``self.model_name`` weights."""
        super(TSN_BIT, self).__init__()
        self.tsn = TSN(num_class, num_segments=num_segments, modality=modality,
                       base_model=arch,
                       consensus_type=crop_fusion_type,
                       dropout=0.7)

        self.activation = nn.LeakyReLU()
        self.fc1 = nn.Linear(101, 32)
        self.fc2 = nn.Linear(32, 8)
        self.model_name = 'TSN_Flow_2019-01-23_17-06-15.pth'
        # `_load_tsn_rgb_weight()` is not called here; only the full
        # checkpoint is loaded.
        self._load_pretrained_model(self.model_name)

    def _load_pretrained_model(self, model_name, model_dir=None):
        """Load every checkpoint entry whose key exists in this module.

        Checkpoint keys unknown to the module are ignored. Module parameters
        absent from the checkpoint keep their current values and are reported
        on stdout.

        Args:
            model_name: checkpoint file name, resolved against ``model_dir``.
            model_dir: directory holding the checkpoint; defaults to
                ``_PRETRAINED_DIR``.
        """
        import os

        if model_dir is None:
            model_dir = self._PRETRAINED_DIR

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(os.path.join(model_dir, model_name))
        print("Number of parameters recovered from modeo {} is {}".format(model_name, len(checkpoint)))

        model_state = self.state_dict()
        base_dict = {k: v for k, v in checkpoint.items() if k in model_state}

        for key in model_state:
            if key not in base_dict:
                print("Missing Param {}".format(key))

        # Start from the current state so load_state_dict sees every key.
        model_state.update(base_dict)
        self.load_state_dict(model_state)

    def _load_tsn_rgb_weight(self, flow_weights=None):
        """Load a TSN checkpoint into ``self.tsn`` only.

        Despite the method name, the default checkpoint is the *flow* one
        (``ucf101_flow.pth``).

        Entries are selected by their position in the checkpoint:
            * entries 1-18 are copied under their own key;
            * entries 19-414 are copied with the first seven characters of
              the key removed (presumably a ``module.`` prefix - TODO confirm);
            * entries from 415 on are ignored.
        ``new_fc.weight`` and ``new_fc.bias`` are taken from
        ``base_model.fc-action.1.*``. A key that is already present is never
        overwritten.

        Args:
            flow_weights: checkpoint path; defaults to ``_FLOW_WEIGHTS``.
        """
        if flow_weights is None:
            flow_weights = self._FLOW_WEIGHTS
        checkpoint = torch.load(flow_weights)

        base_dict = {}
        for count, (k, v) in enumerate(checkpoint.items(), 1):
            # One pre-formatted argument prints "count key" on Python 2 and 3.
            print("{} {}".format(count, k))
            if count < 19:
                base_dict.setdefault(k, v)
                # Only the first pass inserts these; later passes are no-ops.
                base_dict.setdefault('new_fc.weight', checkpoint['base_model.fc-action.1.weight'])
                base_dict.setdefault('new_fc.bias', checkpoint['base_model.fc-action.1.bias'])
            elif count < 415:
                base_dict.setdefault(k[7:], v)

        self.tsn.load_state_dict(base_dict)

    def forward(self, input):
        """Return head outputs for the backbone's output on ``input``.

        Applies ``tsn`` -> activation -> ``fc1`` -> activation -> ``fc2``.
        """
        x = self.activation(self.tsn(input))
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x

    def fea_gen_forward(self, input, batch_size, warmup_t, pred_t):
        """Pass through the nine results of ``self.tsn.fea_gen_forward``.

        Returns:
            tuple, in this order:
                1. fea_gen;
                2. org_fea;
                3. diff_gen;
                4. org_diff;
                5. diff_gen_grad;
                6. org_diff_grad;
                7. gen_fea_grad;
                8. org_fea_grad;
                9. kalman_gain.

        Raises:
            ValueError: if the backbone does not yield exactly nine values.
        """
        (fea_gen, org_fea, diff_gen, org_diff,
         diff_gen_grad, org_diff_grad,
         gen_fea_grad, org_fea_grad,
         kalman_gain) = self.tsn.fea_gen_forward(input, batch_size, warmup_t, pred_t)
        return (fea_gen, org_fea, diff_gen, org_diff,
                diff_gen_grad, org_diff_grad,
                gen_fea_grad, org_fea_grad,
                kalman_gain)