def main(input_option=0,
         out_path="/home/workspace2/dataset/3dhand/visual_test/test2/out_author/",
         data_pth="/home/workspace2/dataset/3dhand/visual_test/test2/image/"):
    """Load the ``data/model-<input_option>-module.pth`` weights and run ``test``.

    All arguments default to the values that used to be hard-coded, so
    calling ``main()`` with no arguments behaves as before.

    Args:
        input_option: 1 uses image and joint heat maps as network input,
            0 uses the image only.
        out_path: output directory handed to ``test`` as ``out_path``.
        data_pth: input image directory handed to ``test`` as ``data_pth``.
    """
    model = resnet34_Mano(input_option=input_option)
    model.cuda()
    # The state dict is loaded (strictly) into a bare resnet34_Mano, so its
    # keys must not carry a DataParallel "module." prefix.
    ckpt_path = os.path.join('data', 'model-%s-module.pth' % (input_option,))
    model.load_state_dict(torch.load(ckpt_path))

    test(input_option, model, out_path=out_path, data_pth=data_pth)


def main2(iterations=(91000,),
          ckpt_dir="/home/workspace2/checkpoints/3dhand/train/train_model0_3d_norm_no_detach/checkpoints",
          visual_root="/home/workspace2/dataset/3dhand/visual_test/test2"):
    """Run ``test`` for several training snapshots of the input_option=0 model.

    For each iteration number ``it`` in ``iterations``:

    1. Load ``<ckpt_dir>/model-0_<it:08d>.pth`` into a fresh ``resnet34_Mano``.
    2. Run ``test`` on the images in ``<visual_root>/image/``.
    3. Write the results to ``<visual_root>/<checkpoint file stem>/``.

    Other snapshots inspected previously: 10000, 13000, 15000, 17000, 19000,
    21000, 30000, 40000, 50000, 60000, 70000.

    Raises:
        FileNotFoundError: if a requested checkpoint file does not exist.
    """
    input_option = 0
    img_pth = os.path.join(visual_root, "image", "")
    for it in iterations:
        model_pth = os.path.join(ckpt_dir, "model-0_%08d.pth" % it)
        # Explicit raise instead of assert: asserts are stripped under -O.
        if not os.path.isfile(model_pth):
            raise FileNotFoundError("checkpoint not found: %s" % model_pth)
        model_name = os.path.splitext(os.path.basename(model_pth))[0]
        # Keep a trailing separator, like img_pth, in case ``test`` builds
        # output file paths by plain string concatenation.
        out_pth = os.path.join(visual_root, model_name, "")
        os.makedirs(out_pth, exist_ok=True)

        model = resnet34_Mano(input_option=input_option)
        model.load_state_dict(torch.load(model_pth))
        model.cuda()
        test(input_option, model, out_pth, data_pth=img_pth)
    def __init__(self, params, ispretrain):
        """Build the MANO encoder, its Adam optimizer, LR scheduler and losses.

        Args:
            params: training configuration. It is read both by key
                (``'input_option'``, ``'pretrain_loss_fn'``) and by attribute
                (``lr``, ``beta1``, ``beta2``, ``weight_decay``, ``init``), so
                it must support both access styles (presumably an
                attribute-dict such as EasyDict -- TODO confirm).
            ispretrain: when truthy, ``param_recon_criterion`` is also built
                from ``params['pretrain_loss_fn']``. Otherwise that attribute
                is not set.
        """
        super(EncoderTrainer, self).__init__()
        self.ispretrain = ispretrain
        self.input_option = params['input_option']
        # Whole config kept on the instance under the name ``weight``
        # (name suggests loss weights are read from it -- TODO confirm).
        self.weight = params

        # Wrapping in DataParallel and immediately taking ``.module`` yields
        # the bare network. The wrap is kept for its side effect: PyTorch's
        # DataParallel moves the module to the GPU when exactly one GPU is
        # visible.
        self.model = torch.nn.DataParallel(
            resnet34_Mano(input_option=params['input_option']))
        self.model = self.model.module
        self.mean_3d = torch.zeros(3)

        # Adam over the trainable parameters only.
        # The optimizer is deliberately NOT passed through
        # nn.DataParallel(...).module: that is a no-op on CPU-only and
        # multi-GPU hosts, and on single-GPU hosts DataParallel calls
        # .to()/.cuda() on the wrapped object, which optimizers lack
        # (AttributeError).
        self.encoder_opt = torch.optim.Adam(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=params.lr,
            betas=(params.beta1, params.beta2),
            weight_decay=params.weight_decay)

        self.encoder_scheduler = get_scheduler(self.encoder_opt, params)

        # Loss used only in pretraining mode.
        if self.ispretrain:
            self.param_recon_criterion = get_criterion(
                params['pretrain_loss_fn'])

        # Network weight initialization; Module.apply visits every submodule.
        self.model.apply(weights_init(params.init))

        self.transformer = mm2px.JointTransfomer('BB')
Example #4
0
# Per-image preprocessing: square bilinear rescale to image_scale, then tensor.
my_transform = Compose([
    Scale((image_scale, image_scale), Image.BILINEAR),
    ToTensor(),
])
root_dir = '/mnt/ext_toshiba/rgb2mesh/DataSets/panoptic-toolbox'
dataset_list = ['171026_pose3']

# Train/validation splits of the same Panoptic sequence(s).
train_dataset = PanopticSet(
    root_dir,
    dataset_list=dataset_list,
    image_size=image_scale,
    data_par='train',
    img_transform=my_transform,
)
valid_dataset = PanopticSet(
    root_dir,
    dataset_list=dataset_list,
    image_size=image_scale,
    data_par='valid',
    img_transform=my_transform,
)
# Only the training split is shuffled.
train_loader = data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = data.DataLoader(
    dataset=valid_dataset, batch_size=batch_size, shuffle=False)
loaders = dict(train=train_loader, valid=valid_loader)

model = torch.nn.DataParallel(resnet34_Mano(input_option=input_option))
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Multiply the learning rate by 0.1 every 7 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

if not init_model:
    # Resume: the checkpoint dict must hold 'epoch' plus model, optimizer and
    # scheduler state dicts.
    load_model_path = os.path.join(dataset_list[0], 'best_model_ep10.pth')
    loaded_state = torch.load(load_model_path)
    prev_epoch = loaded_state['epoch']
    print('prev_epoch = ', prev_epoch)
    model.load_state_dict(loaded_state['model_state_dict'])
    optimizer.load_state_dict(loaded_state['optimizer_state_dict'])
    scheduler.load_state_dict(loaded_state['scheduler_state_dict'])
else:
    # Fresh start from pretrained weights. They are loaded into the
    # DataParallel wrapper, so their keys must carry its "module." prefix.
    load_model_path = os.path.join('data', 'model-%s.pth' % (input_option,))
    model.load_state_dict(torch.load(load_model_path))
    prev_epoch = 0