Пример #1
0
 def __init__(self,
              weight_fn,
              camK,
              res_x,
              res_y,
              obj_param,
              th_ransac=3.0,
              th_outlier=None,
              th_inlier=0.1,
              box_size=1.5,
              dist_coeff=None,
              backbone="paper",
              **kwargs):
     """Store pose-estimation settings and build/load the generator network.

     Args:
         weight_fn: Path to the weight file. For ``backbone='paper'`` it is
             passed to ``load_weights`` on a freshly built
             ``ae.aemodel_unet_prob(p=1.0)``; for ``backbone='resnet50'``
             it is passed to ``load_model``.
         camK: Stored as ``self.camK`` (presumably the camera intrinsic
             matrix -- TODO confirm against callers).
         res_x: Stored as ``self.res_x``.
         res_y: Stored as ``self.res_y``.
         obj_param: Sliceable sequence; entries ``[:3]`` are stored as
             ``self.obj_scale`` and entries ``[3:]`` as ``self.obj_ct``
             (each annotated x, y, z in the original code).
         th_ransac: Stored as ``self.th_ransac``.
         th_outlier: Stored as ``self.th_o``. ``None`` (the default) gives
             this instance its own fresh ``[0.1, 0.2, 0.3]`` list.
         th_inlier: Stored as ``self.th_i``.
         box_size: Stored as ``self.box_size``.
         dist_coeff: Stored as ``self.dist_coeff``.
         backbone: ``'paper'`` or ``'resnet50'``; selects how
             ``self.generator_train`` is created.
         **kwargs: Accepted for call-site compatibility; not used here.

     Raises:
         ValueError: If ``backbone`` is neither ``'paper'`` nor
             ``'resnet50'``.
     """
     # A list literal as the default would be evaluated once and shared by
     # every instance that relies on it; create a new list per call instead.
     if th_outlier is None:
         th_outlier = [0.1, 0.2, 0.3]

     self.camK = camK
     self.res_x = res_x
     self.res_y = res_y
     self.th_ransac = th_ransac
     self.th_o = th_outlier
     self.th_i = th_inlier
     self.obj_scale = obj_param[:3]  # x, y, z
     self.obj_ct = obj_param[3:]  # x, y, z
     self.box_size = box_size
     self.dist_coeff = dist_coeff

     if backbone == 'paper':
         # Build the architecture locally, then load only its weights.
         self.generator_train = ae.aemodel_unet_prob(p=1.0)  # output: 3gae
         self.generator_train.load_weights(weight_fn)
     elif backbone == 'resnet50':
         # The weight file holds a complete saved model.
         self.generator_train = load_model(weight_fn)
     else:
         # Fail fast: otherwise generator_train would be left unset and the
         # problem would only appear later as an AttributeError.
         raise ValueError(
             "Unsupported backbone %r; expected 'paper' or 'resnet50'"
             % (backbone,))
            if (fn_temp.startswith("pix2pose" + ".")
                    and fn_temp.endswith("hdf5")):
                temp_split = fn_temp.split(".")
                epoch_split = temp_split[1].split("-")
                epoch_split2 = epoch_split[0].split("_")
                epoch_temp = int(epoch_split2[0])
                if (epoch_temp > recent_epoch):
                    recent_epoch = epoch_temp
                    weight_fn = fn_temp
        if os.path.exists(os.path.join(root,
                                       "inference.hdf5")) and pass_exists:
            print("A converted file exists in ",
                  os.path.join(root, "inference.hdf5"))
            continue
        if (weight_fn != ""):
            generator_train = ae.aemodel_unet_prob(p=1.0)
            discriminator = ae.DCGAN_discriminator()
            imsize = 128
            dcgan_input = Input(shape=(imsize, imsize, 3))
            dcgan_target = Input(shape=(imsize, imsize, 3))
            prob_gt = Input(shape=(imsize, imsize, 1))
            gen_img, prob = generator_train(dcgan_input)
            recont_l = ae.transformer_loss(
                [np.eye(3)])([gen_img, dcgan_target, prob_gt, prob_gt])
            disc_out = discriminator(gen_img)
            dcgan = Model(inputs=[dcgan_input, dcgan_target, prob_gt],
                          outputs=[recont_l, disc_out, prob])
            print("load recent weights from ", os.path.join(root, weight_fn))
            dcgan.load_weights(os.path.join(root, weight_fn))

            print("save recent weights to ",