示例#1
0
文件: cpm_test.py 项目: phymucs/MSBR
def construct_model(pre_model_path):
    """Build a 14-keypoint CPM, load pretrained weights, and wrap it for GPU 0.

    Args:
        pre_model_path: Path to a checkpoint file whose ``'state_dict'`` entry
            holds the model parameters.

    Returns:
        The ``cpm_model.CPM`` model wrapped in ``torch.nn.DataParallel`` with
        ``device_ids=[0]`` and moved to CUDA.
    """
    from collections import OrderedDict

    model = cpm_model.CPM(k=14)
    # The bare model still lives on the CPU here, so map the checkpoint tensors
    # to CPU storage; the whole wrapped model is moved to CUDA below.
    state_dict = torch.load(pre_model_path,
                            map_location=lambda storage, loc: storage)['state_dict']
    # Checkpoints saved from a DataParallel wrapper prefix every key with
    # 'module.'; strip it only when present so plain checkpoints load too.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    model.load_state_dict(new_state_dict)
    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()

    return model
示例#2
0
        def build_detectors(self):
            """Create ``self.detector`` and optionally load pretrained weights.

            Reads ``self.k`` and ``self.stages`` to build the CPM, and
            ``self.args.pretrained_d`` (checkpoint path whose ``'state_dict'``
            entry holds the parameters, or the string ``'None'`` to skip
            loading) and ``self.args.gpu`` (device ids for DataParallel).
            Afterwards ``self.detector`` is a ``torch.nn.DataParallel`` wrapper
            moved to CUDA.
            """
            from collections import OrderedDict

            detector = cpm_model.CPM(k=self.k, stages=self.stages)

            # Reload parameters for the detector if a pretrained checkpoint is
            # available. Load into the bare model *before* wrapping it: a
            # DataParallel wrapper expects 'module.'-prefixed keys, so loading
            # the stripped keys into the wrapper would fail.
            pretrained = self.args.pretrained_d
            if pretrained is not None and pretrained != 'None':
                state_dict = torch.load(pretrained,
                                        map_location=lambda storage, loc: storage)['state_dict']
                # Strip the 'module.' prefix only when the checkpoint has it.
                prefix = 'module.'
                new_state_dict = OrderedDict(
                    (k[len(prefix):] if k.startswith(prefix) else k, v)
                    for k, v in state_dict.items()
                )
                detector.load_state_dict(new_state_dict)

            self.detector = torch.nn.DataParallel(detector, device_ids=self.args.gpu).cuda()
示例#3
0
def construct_model(args):
    """Return an untrained 14-keypoint CPM spread across ``args.gpu``.

    No pretrained weights are loaded. ``args.gpu`` is the list of device ids
    handed to ``torch.nn.DataParallel``; the wrapped model is moved to CUDA.
    """
    network = cpm_model.CPM(k=14)
    wrapped = torch.nn.DataParallel(network, device_ids=args.gpu)
    return wrapped.cuda()
示例#4
0
def construct_model(args):
    """Build a 14-keypoint CPM, optionally load pretrained weights, and place it.

    Args:
        args: Options object. Reads ``args.fine_tune`` (whether to load
            weights), ``args.pretrained`` (checkpoint path whose
            ``'state_dict'`` entry holds the parameters; only read when
            fine-tuning) and ``args.gpu`` (list of device ids; a negative first
            id selects the CPU).

    Returns:
        The bare ``cpm_model.CPM`` when ``args.gpu[0] < 0``; otherwise the model
        wrapped in ``torch.nn.DataParallel`` over ``args.gpu`` and moved to CUDA.
    """
    model = cpm_model.CPM(k=14)
    # load pretrained model
    if args.fine_tune:
        from collections import OrderedDict

        # Map tensors to CPU storage so a GPU-saved checkpoint can also be
        # loaded for the CPU-only branch below (the model is on the CPU here).
        state_dict = torch.load(args.pretrained,
                                map_location=lambda storage, loc: storage)['state_dict']
        # Checkpoints saved from a DataParallel wrapper prefix keys with
        # 'module.'; strip it only when present.
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)

    if args.gpu[0] < 0:
        return model
    return torch.nn.DataParallel(model, device_ids=args.gpu).cuda()
save_test_heatmap_path = utils.mkdir(os.path.join(path_root, "heatmaps/"))
if if_sum:
    save_sum_path = os.path.join(path_root, "sum_history.json")
if if_max:
    save_max_path = os.path.join(path_root, "max_history.json")

# NOTE(review): `transform` is built here but not passed to UCIHandPoseDataset
# below — confirm whether the dataset applies its own transform.
transform = transforms.Compose([transforms.ToTensor()])
# test data loader
data_dir = '/mnt/UCIHand/test/test_data'
label_dir = '/mnt/UCIHand/test/test_label'
dataset = UCIHandPoseDataset(data_dir=data_dir, label_dir=label_dir)
test_dataset = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# In[4]:

net = cpm_model.CPM(out_c=nb_joint, background=background).cuda()
state_dict = torch.load(ckpt_path)

# Create a new OrderedDict without the `module.` prefix that a DataParallel
# wrapper adds when saving. Keys without the prefix are kept unchanged, so
# checkpoints saved from a bare model load as well.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    namekey = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[namekey] = v
# load params
net.load_state_dict(new_state_dict)
net = torch.nn.DataParallel(net)

# In[5]:

test_history = {}
net.eval()
def train():
    """Train a CPM on the LSP dataset configured by the module-level ``args``.

    Reads from ``args`` (presumably an argparse namespace — TODO confirm):
    ``lsp_root``, ``weighted_loss``, ``weighted_bandwidth``, ``batch_size``,
    ``workers``, ``cuda``, ``lr``, ``momentum``, ``weight_decay``, ``resume``,
    ``summary_dir``, ``start_epoch``, ``epochs``, ``lr_decay_rate``,
    ``lr_epoch_per_decay``, ``print_freq`` and ``save_freq``. Also uses the
    module-level ``image_h`` / ``image_w``.

    Every ``args.print_freq`` steps it prints progress and writes parameter
    histograms, predicted / ground-truth keypoint images and two loss scalars
    to TensorBoard. When ``args.cuda`` is set, a checkpoint is saved at step 1
    and every ``args.save_freq`` steps.
    """
    train_data = cpm_data.LSPDataset(args.lsp_root,
                                     transform=transforms.Compose([cpm_data.Scale(image_h, image_w),
                                                                   cpm_data.RandomHSV((0.8, 1.2),
                                                                                      (0.8, 1.2),
                                                                                      (25, 25)),
                                                                   cpm_data.ToTensor()]),
                                     phase_train=True,
                                     weighted_loss=args.weighted_loss,
                                     bandwidth=args.weighted_bandwidth)
    # Pinned host memory is only useful when batches are copied to the GPU.
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=bool(args.cuda))
    num_alldata = len(train_data)

    model = cpm_model.CPM(train_data.num_keypoints)
    model.train()
    if args.cuda:
        model.cuda()
    lr = args.lr
    # Per-group learning rates: stage-1 weights 5x, stage-1 biases 10x, other
    # weights 1x and other biases 2x the base rate.
    # NOTE(review): groups are chosen by substring of the parameter name. A
    # parameter whose name contains 'stage1' but not 'stage1.weight' /
    # 'stage1.bias' (e.g. a nested 'stage1.<layer>.weight'), or contains
    # neither 'weight' nor 'bias', falls into no group and is never optimized
    # — confirm against CPM's parameter names.
    params = [
        {'params': [p for n, p in model.named_parameters() if 'stage1.weight' in n], 'lr': 5 * lr},
        {'params': [p for n, p in model.named_parameters() if 'stage1.bias' in n], 'lr': 10 * lr},
        {'params': [p for n, p in model.named_parameters() if 'weight' in n and 'stage1' not in n], 'lr': lr},
        {'params': [p for n, p in model.named_parameters() if 'bias' in n and 'stage1' not in n], 'lr': 2 * lr}
    ]
    optimizer = torch.optim.SGD(params, momentum=args.momentum, weight_decay=args.weight_decay)

    count = 0  # training samples seen so far
    global_step = 0  # optimizer steps taken so far
    if args.resume:
        count, global_step, args.start_epoch = load_ckpt(model, optimizer, args.resume)

    writer = SummaryWriter(args.summary_dir)
    # Reference point for the per-report time / sample deltas. These are set
    # explicitly instead of relying on a NameError at the first report, which
    # skipped that report and also hid any genuine NameError in the logging code.
    end_time = time.time()
    last_count = count

    try:
        for epoch in range(int(args.start_epoch), args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr, params,
                                 args.lr_decay_rate, args.lr_epoch_per_decay)
            for batch_idx, sample in enumerate(train_loader):

                image = Variable(sample['image'].cuda() if args.cuda else sample['image'])
                gt_map = Variable(sample['gt_map'].cuda() if args.cuda else sample['gt_map'])
                center_map = Variable(sample['center_map'].cuda() if args.cuda else sample['center_map'])
                weight = Variable(sample['weight'].cuda() if args.cuda else sample['weight'])
                optimizer.zero_grad()
                pred_6 = model(image, center_map)
                # `loss_log` uses mse_loss's default weighting and is only logged;
                # `loss` honours args.weighted_loss and drives the update.
                loss_log = cpm_model.mse_loss(pred_6, gt_map, weight)
                loss = cpm_model.mse_loss(pred_6, gt_map, weight, weighted_loss=args.weighted_loss)
                loss.backward()
                optimizer.step()
                count += image.data.shape[0]
                global_step += 1
                if global_step % args.print_freq == 0:
                    time_inter = time.time() - end_time
                    count_inter = count - last_count
                    print_log(global_step, epoch, count, count_inter,
                              num_alldata, loss_log, time_inter)
                    for name, param in model.named_parameters():
                        writer.add_histogram(name, param.clone().cpu().data.numpy(), global_step)
                    grid_image = make_grid(tools.keypoint_painter(image[:2], pred_6[:2],
                                                                  image_h, image_w), 3)
                    writer.add_image('Predicted image', grid_image, global_step)
                    grid_image = make_grid(tools.keypoint_painter(image[:2], gt_map[:2],
                                                                  image_h, image_w, phase_gt=True,
                                                                  center_map=center_map), 6)
                    writer.add_image('Groundtruth image', grid_image, global_step)
                    # NOTE(review): `.data[0]` is the pre-0.4 PyTorch idiom (matching
                    # the `Variable` usage here); newer releases reject indexing a
                    # 0-dim tensor and need `.item()` instead.
                    writer.add_scalar('MSELoss', loss_log.data[0], global_step=global_step)
                    writer.add_scalar('Weighted MSELoss', loss.data[0], global_step=global_step)
                    end_time = time.time()
                    last_count = count
                # Checkpoints are written only when training on CUDA.
                if (global_step % args.save_freq == 0 or global_step == 1) and args.cuda:
                    save_ckpt(model, optimizer, global_step, batch_idx, count, args.batch_size,
                              num_alldata, args.weighted_loss, args.weighted_bandwidth)
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()
    print("Training completed ")