def load_part_model(self, action_model_path=None, rnn_path=None):
        """Create ``self.act_net`` and ``self.act_rnn``, optionally restoring weights.

        Args:
            action_model_path: Path handed to ``torch.load``; the result is
                treated as a state dict for ``ACT_net``.  A leading ``module.``
                on a key (the prefix ``nn.DataParallel`` adds when a wrapped
                model is saved) is stripped before loading; keys without the
                prefix are used unchanged.  ``None`` leaves the network as
                ``create_architecture()`` built it.
            rnn_path: Path handed to ``torch.load``; the result is passed to
                ``Act_RNN.load``.  ``None`` creates a fresh ``Act_RNN``.

        Side effects:
            Assigns ``self.act_net`` and ``self.act_rnn``.
        """
        # --- action net -----------------------------------------------------
        act_net = ACT_net(self.classes, self.sample_duration)
        act_net.create_architecture()

        if action_model_path is not None:
            act_data = torch.load(action_model_path)

            # Strip the DataParallel prefix only where it is actually present.
            # Blindly dropping the first seven characters would mangle the keys
            # of a checkpoint that was saved from an unwrapped model.
            prefix = 'module.'
            new_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in act_data.items())
            act_net.load_state_dict(new_state_dict)

        self.act_net = act_net

        # --- recurrent classifier -------------------------------------------
        if rnn_path is not None:
            act_rnn = Act_RNN(self.p_feat_size, int(self.p_feat_size / 2),
                              self.n_classes)
            act_rnn_data = torch.load(rnn_path)
            # NOTE(review): torch.nn.Module has no `load` method; unless
            # Act_RNN defines one this probably should be `load_state_dict`
            # -- confirm against the Act_RNN class.
            act_rnn.load(act_rnn_data)
            self.act_rnn = act_rnn
        else:
            # NOTE(review): the second argument here is p_feat_size, while the
            # checkpoint branch above uses p_feat_size / 2.  Kept as in the
            # original; confirm which size is intended.
            self.act_rnn = Act_RNN(self.p_feat_size, int(self.p_feat_size),
                                   self.n_classes)
# Example #2
# 0
    # Excerpt from a larger function (its header is not part of this view):
    # builds the input transforms, restores an ACT_net checkpoint and runs
    # validation() once.
    #
    # Spatial pipeline applied through Compose, in this order: Scale,
    # ToTensor, Normalize.
    # NOTE(review): Normalize receives `mean` and a constant [1, 1, 1] as its
    # second argument -- presumably a per-channel std of 1, i.e. mean
    # subtraction only; confirm against the Normalize implementation.
    spatial_transform = Compose([
        Scale(sample_size),  # alternative tried here: Resize(sample_size)
        ToTensor(),
        Normalize(mean, [1, 1, 1])
    ])
    # NOTE(review): presumably pads short clips to `sample_duration` frames by
    # looping -- confirm against LoopPadding.
    temporal_transform = LoopPadding(sample_duration)

    # Build the action net and wrap it in DataParallel *before* loading the
    # weights, so the checkpoint keys are expected to carry the `module.`
    # prefix that DataParallel adds.
    model = ACT_net(actions, sample_duration)
    model.create_architecture()
    model = nn.DataParallel(model)
    model.to(device)

    # Checkpoint selection: exactly one torch.load line is active; the
    # commented lines are other checkpoints that were tried.
    # model_data = torch.load('./actio_net_model_both.pwf')
    # model_data = torch.load('./action_net_model_both_without_avg.pwf')
    # model_data = torch.load('./action_net_model_16frm_64.pwf')
    # model_data = torch.load('./action_net_model_both_sgl_frm.pwf')
    model_data = torch.load('./action_net_model_both.pwf')
    #
    # model_data = torch.load('./action_net_model_part1_1_8frm.pwf')
    model.load_state_dict(model_data)

    # Disabled alternative: load weights into the `act_rpn` sub-module only.
    # model_data = torch.load('./region_net_8frm.pwf')
    # model.module.act_rpn.load_state_dict(model_data)

    # Switch to evaluation mode before validating.
    model.eval()

    # NOTE(review): the literals 0 and 4 are positional arguments whose
    # meaning is defined by validation() (possibly an epoch index and a batch
    # size) -- confirm against its signature.
    validation(0, device, model, dataset_folder, sample_duration,
               spatial_transform, temporal_transform, boxes_file,
               split_txt_path, cls2idx, 4, n_threads)
    # Excerpt from a larger function (header and tail are not part of this
    # view): prints a banner, restores an ACT_net checkpoint and sets up
    # training hyper-parameters.
    print(' -----------------------------------------------------')
    print('|          Part 1-1 - train TPN - without reg         |')
    print(' -----------------------------------------------------')

    # Init action_net
    act_model = ACT_net(actions, sample_duration)
    act_model.create_architecture()
    # The GPU count is only reported; the DataParallel wrap below is applied
    # regardless of how many devices are available.
    if torch.cuda.device_count() > 1:
        print('Using {} GPUs!'.format(torch.cuda.device_count()))

    act_model = nn.DataParallel(act_model)
    act_model.to(device)

    # Loaded into the wrapped model, so the checkpoint keys are expected to
    # carry the `module.` prefix that DataParallel adds.
    model_data = torch.load('./action_net_model_16frm.pwf')
    act_model.load_state_dict(model_data)


    # Learning-rate settings.
    # NOTE(review): presumably consumed by optimiser / scheduler set-up that
    # follows this excerpt -- not visible here.
    # lr = 0.1
    lr = 0.00001
    lr_decay_step = 10
    lr_decay_gamma = 0.1
    
    # Starts empty; nothing in the visible code appends to it.
    params = []

    # Disabled: freeze the `reg_layer` parameters.
    # for p in act_model.module.reg_layer.parameters() : p.requires_grad=False

    # Print the name of every parameter that still requires gradients.
    for key, value in dict(act_model.named_parameters()).items():
        # print(key, value.requires_grad)
        if value.requires_grad:
            print('key :',key)
    def load_part_model(self,
                        resnet_path=None,
                        action_model_path=None,
                        rnn_path=None):
        """Create ``self.act_net`` and ``self.act_rnn``, optionally restoring weights.

        Args:
            resnet_path: When not ``None``, forwarded to
                ``ACT_net.create_architecture`` as ``model_path``; otherwise
                ``create_architecture()`` is called without arguments.
            action_model_path: Path handed to ``torch.load``; the result is
                treated as a state dict for ``ACT_net``.  A leading ``module.``
                on a key (the prefix ``nn.DataParallel`` adds when a wrapped
                model is saved) is stripped before loading; keys without the
                prefix are used unchanged.  ``None`` skips this step.
            rnn_path: Path handed to ``torch.load``; the result is loaded into
                the classifier with ``load_state_dict``.  ``None`` leaves the
                classifier freshly initialised.

        Side effects:
            Assigns ``self.act_net`` and ``self.act_rnn``.
        """
        # --- action net -----------------------------------------------------
        act_net = ACT_net(self.classes, self.sample_duration)
        if resnet_path is not None:
            act_net.create_architecture(model_path=resnet_path)
        else:
            act_net.create_architecture()

        if action_model_path is not None:
            act_data = torch.load(action_model_path)

            # Strip the DataParallel prefix only where it is actually present.
            # Blindly dropping the first seven characters would mangle the keys
            # of a checkpoint that was saved from an unwrapped model.
            prefix = 'module.'
            new_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in act_data.items())
            act_net.load_state_dict(new_state_dict)

        self.act_net = act_net

        # --- classifier -----------------------------------------------------
        # A single linear layer, kept inside nn.Sequential so the state-dict
        # keys stay "0.weight" / "0.bias" as in existing checkpoints.
        act_rnn = nn.Sequential(
            nn.Linear(64 * self.sample_duration, self.n_classes),
        )
        if rnn_path is not None:
            act_rnn.load_state_dict(torch.load(rnn_path))

        # NOTE(review): the original followed the from-scratch case with a loop
        # that tested `m == nn.Linear` (a module instance compared with the
        # class, which is never true) around a truncated-normal init that used
        # the undefined names `stddev` and `mean`.  It therefore never ran, so
        # the layer has always kept nn.Linear's default initialisation; that
        # behaviour is preserved here.  To enable the custom init, test with
        # `isinstance(m, nn.Linear)` and supply explicit mean/stddev values.
        self.act_rnn = act_rnn