示例#1
0
 def __init__(self):
     """Parse training options, build model and visualizer, then train if data is available.

     ``setup_train_test_sets()`` and ``train()`` are only invoked when
     ``get_training_data()`` returns a truthy value.
     """
     self.opt = TrainOptions().parse()
     self.model = ModelsFactory.get_by_name(self.opt)
     self.tb_visualizer = TBVisualizer(self.opt)

     has_training_data = self.get_training_data()
     if not has_training_data:
         # Nothing to train on: leave the object constructed but idle.
         return
     self.setup_train_test_sets()
     self.train()
示例#2
0
def main():
    """Run attribute edits on the 'wild_images' test split and report the mean SRE.

    For each test batch, four edits are applied to the original image and
    attributes, in this order: hair colour, skin colour, beard, one attribute.
    The model output for each edit is saved to
    ``<output_dir>/results_wild_ganimation``. The batch is then restored to
    its original image and attributes, and the value returned by
    ``self_rec`` is accumulated.

    The mean of that value over all batches is printed as ``sre``. An empty
    loader is reported instead of raising ZeroDivisionError.
    """
    opts = TestOptions().parse()
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(opts.output_dir, exist_ok=True)

    result_dir = ospj(opts.output_dir, "results_wild_ganimation")
    os.makedirs(result_dir, exist_ok=True)

    model = ModelsFactory.get_by_name(opts.model, opts)
    model.set_eval()

    test_loader = get_dataloader(opts.data_dir, 'wild_images', 'test',
                                 opts.image_size, opts.selected_attrs, 1)

    # (target-construction function, output file suffix), applied in this order.
    edits = (
        (change_hair_color_target, "hair_color"),
        (change_skin_color_target, "skin_color"),
        (change_beard_target, "beard"),
        (change_one_attr_target, "one"),
    )

    test_epoch = 0
    all_sre_val = 0.0
    num_batches = 0
    for test_batch_idx, test_data_batch in enumerate(test_loader):
        org_img = test_data_batch['real_img'].clone().detach()
        org_attr = test_data_batch['real_cond'].clone().detach()

        for make_target, suffix in edits:
            # Each edit starts from a fresh copy of the unedited inputs.
            img, attr, _ = make_target(org_img.clone().detach(), org_attr.clone().detach(), opts.selected_attrs)
            test_data_batch['real_img'] = img
            test_data_batch['desired_cond'] = attr
            test_data_batch['real_cond'] = attr
            save_file = ospj(result_dir, f"test_epoch_{test_epoch}_wild_batch_{test_batch_idx}_test_{suffix}.png")
            test_and_save(model, test_data_batch, save_file)

        # Restore the unedited inputs before measuring self-reconstruction.
        test_data_batch['real_img'] = org_img
        test_data_batch['desired_cond'] = org_attr
        test_data_batch['real_cond'] = org_attr
        sre = self_rec(model, test_data_batch)
        all_sre_val += sre.item()
        num_batches += 1

    if num_batches == 0:
        # Guard against an empty loader, which would otherwise divide by zero.
        print("sre: no test batches found")
    else:
        print("sre:", all_sre_val / num_batches)
示例#3
0
 def __init__(self):
     """Build the multitask trainer, print dataset statistics and start training.

     Steps:

     - Parse ``TrainOptions`` and build the model named by ``opt.model_name``.
     - Create the training and validation dataloaders, using the
       ``augment_transforms`` / ``compose_transforms`` exposed by
       ``model.resnet50.backbone``.
     - Print per-split image counts.
     - Initialise ``visual_dict`` with empty DataFrames and call ``_train()``.
     """
     self._opt = TrainOptions().parse()
     # NOTE(review): PRESET_VARS is never read in this method. The call is kept
     # in case constructing PATH(...) has side effects; confirm and drop if not.
     PRESET_VARS = PATH(self._opt)
     self._model = ModelsFactory.get_by_name(self._opt.model_name,
                                             self._opt)
     backbone = self._model.resnet50.backbone
     train_transforms = backbone.augment_transforms
     val_transforms = backbone.compose_transforms
     self.training_dataloaders = Multitask_DatasetDataLoader(
         self._opt, train_mode='Train',
         transform=train_transforms).load_multitask_train_data()
     self.validation_dataloaders = Multitask_DatasetDataLoader(
         self._opt, train_mode='Validation',
         transform=val_transforms).load_multitask_val_test_data()

     num_tasks = len(self._opt.tasks)
     print("Training Tasks:{}".format(self._opt.tasks))
     # Effective batch = batch_size per task times the number of tasks.
     actual_bs = self._opt.batch_size * num_tasks
     print("The actual batch size is {}*{}={}".format(
         self._opt.batch_size, num_tasks, actual_bs))
     num_train_batches = len(self.training_dataloaders)
     print("Training sets: {} images ({} images per task)".format(
         num_train_batches * actual_bs,
         num_train_batches * self._opt.batch_size))
     print("Validation sets")
     for task in self._opt.tasks:
         data_loader = self.validation_dataloaders[task]
         # NOTE(review): a single task's loader length is multiplied by the
         # number of tasks here; confirm this is the intended image count.
         print("{}: {} images".format(
             task,
             len(data_loader) * self._opt.batch_size * num_tasks))
     self.visual_dict = {
         'training': pd.DataFrame(),
         'validation': pd.DataFrame()
     }
     self._train()
示例#4
0
文件: test.py 项目: iGuaZi/GANimation
 def __init__(self, opt):
     """Store the options, load the model in eval mode and build the input transform.

     The transform applies ``ToTensor`` and then normalises each of the
     three channels with mean 0.5 and std 0.5.
     """
     self._opt = opt
     self._model = ModelsFactory.get_by_name(opt.model, opt)
     self._model.set_eval()

     channel_mean = [0.5, 0.5, 0.5]
     channel_std = [0.5, 0.5, 0.5]
     pipeline = [
         transforms.ToTensor(),
         transforms.Normalize(mean=channel_mean, std=channel_std),
     ]
     self._transform = transforms.Compose(pipeline)
示例#5
0
文件: test.py 项目: hanezu/GANimation
 def __init__(self, opt):
     """Keep the options, put the model into eval mode and prepare the image transform.

     Images are converted with ``ToTensor`` and then normalised per channel
     using mean 0.5 / std 0.5 for all three channels.
     """
     self._opt = opt
     model = ModelsFactory.get_by_name(self._opt.model, self._opt)
     self._model = model
     self._model.set_eval()
     to_tensor = transforms.ToTensor()
     normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
     self._transform = transforms.Compose([to_tensor, normalize])
示例#6
0
    def __init__(self, prob_threshold=0.5, do_filter_prob=True):
        """Store detection settings, parse test options and load the model in eval mode.

        Args:
            prob_threshold: stored as ``self._prob_treshold`` (attribute name
                spelled this way; kept for callers that read it).
            do_filter_prob: stored as ``self._do_filter_prob``.
        """
        # Detection / filtering state.
        self._prob_treshold = prob_threshold
        self._do_filter_prob = do_filter_prob
        self._detected_in_previous_frame = False

        # Options must be parsed before the transform and model are built.
        opt = TestOptions().parse()
        self._opt = opt
        # Transform mapping an RGB cv2 image to a PyTorch tensor.
        self._img2tensor = self._create_img_transform()
        # Build the model and switch it to test mode.
        self._model = ModelsFactory.get_by_name(opt.model, opt)
        self._model.set_eval()
示例#7
0
 def __init__(self):
     """Load test options, the model and the pickled test-data index, then run ``_test()``.

     The test-data file path comes from ``PRESET_VARS.Aff_wild2.test_data_file``.
     ``opt.save_dir`` is created if it does not exist.

     Raises:
         ValueError: if ``opt.mode`` is not ``'Test'``.
     """
     self._opt = TestOptions().parse()
     self._model = ModelsFactory.get_by_name(self._opt.model_name, self._opt)
     test_data_file = PRESET_VARS.Aff_wild2.test_data_file
     # Use a context manager so the file handle is closed deterministically.
     # NOTE: unpickling can execute arbitrary code; only load trusted files.
     with open(test_data_file, 'rb') as f:
         self.test_data_file = pickle.load(f)
     self.save_dir = self._opt.save_dir
     # exist_ok avoids the check-then-create race of exists() + makedirs().
     os.makedirs(self.save_dir, exist_ok=True)
     if self._opt.mode == 'Test':
         self._test()
     else:
         raise ValueError("do not call test.py with validation mode.")
示例#8
0
    def __init__(self):
        """Parse training options, load train/test datasets, build model and visualizer, then train."""
        opt = TrainOptions().parse()
        self._opt = opt

        # Loaders are kept on the instance; their load_data() results are the datasets.
        train_loader = CustomDatasetDataLoader(opt, is_for_train=True)
        self.data_loader_train = train_loader
        test_loader = CustomDatasetDataLoader(opt, is_for_train=False)
        self.data_loader_test = test_loader

        self._dataset_train = train_loader.load_data()
        self._dataset_test = test_loader.load_data()

        self._dataset_train_size = len(train_loader)
        self._dataset_test_size = len(test_loader)
        print('#train images = %s' % self._dataset_train_size)
        print('#test images = %s' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(opt.model, opt)
        self._tb_visualizer = TBVisualizer(opt)

        self._train()
示例#9
0
    def __init__(self, opt):
        """Load the model in eval mode, the inputs to process and the input transform.

        Image ids are read from ``<data_dir>/<input_file>`` and conditions from
        ``<data_dir>/<aus_file>``. Only ids that have an entry in the conditions
        are kept, and ``opt.output_dir`` is created if missing. The transform
        applies ``ToTensor`` followed by Normalize(mean=0.5, std=0.5) per channel.
        """
        self._opt = opt
        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._model.set_eval()
        self.ids = self.read_ids(os.path.join(self._opt.data_dir, self._opt.input_file))
        self.conds = self.read_conds(os.path.join(self._opt.data_dir, self._opt.aus_file))
        # Keep ids that have a condition, de-duplicated, in their original
        # order. A set intersection yields an arbitrary order; for string ids
        # it can differ between runs because str hashing is randomised per
        # process.
        cond_keys = set(self.conds.keys())
        self.ids = [i for i in dict.fromkeys(self.ids) if i in cond_keys]
        self.layers = self._opt.layers
        print('#images: ', len(self.ids))
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self._opt.output_dir, exist_ok=True)

        self._transform = transforms.Compose([transforms.ToTensor(),
                                              transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                                   std=[0.5, 0.5, 0.5])
                                              ])
示例#10
0
    def __init__(self):
        """Parse training options, load train/test datasets, build model and visualizer, then train."""
        opt = TrainOptions().parse()
        self._opt = opt
        # Train loader first, then test loader.
        loader_train, loader_test = (
            CustomDatasetDataLoader(opt, is_for_train=flag) for flag in (True, False)
        )

        self._dataset_train = loader_train.load_data()
        self._dataset_test = loader_test.load_data()

        self._dataset_train_size = len(loader_train)
        self._dataset_test_size = len(loader_test)
        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(opt.model, opt)
        self._tb_visualizer = TBVisualizer(opt)

        self._train()
示例#11
0
def main():
    """Evaluate a trained model with FID and attribute-prediction scores.

    Steps:

    - Build the ``eval_part1`` / ``eval_part2`` loaders of the
      ``celebahq_ffhq_fake`` data and the model under test.
    - Build an InceptionV3 feature model for FID.
    - Build a resnet18 attribute predictor from a fixed checkpoint.
    - Print the FID, the average attribute score and the per-attribute scores.

    Raises:
        FileNotFoundError: if the attribute-predictor checkpoint is missing.
    """
    opts = TestOptions().parse()
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(opts.output_dir, exist_ok=True)

    num_attrs = len(opts.selected_attrs)
    eval_part1_loader = get_dataloader(opts.data_dir, 'celebahq_ffhq_fake', 'eval_part1',
                                       opts.image_size, opts.selected_attrs, num_attrs)
    eval_part2_loader = get_dataloader(opts.data_dir, 'celebahq_ffhq_fake', 'eval_part2',
                                       opts.image_size, opts.selected_attrs, num_attrs)

    model = ModelsFactory.get_by_name(opts.model, opts)
    model.set_eval()

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[opts.dims]
    fid_model = InceptionV3([block_idx])
    fid_model = torch.nn.DataParallel(fid_model).cuda()
    fid_model = fid_model.eval()

    # NOTE: here we hard code to resnet18, we construct the resnet with selected attributes
    attr_pred_model = resnet18(pretrained=True, num_attributes=num_attrs)
    attr_model_ckpt = 'eval_attr/checkpoints_select_no_extra/model_best.pth.tar'  # for local
    # An explicit raise, unlike `assert`, is not stripped under `python -O`.
    if not os.path.isfile(attr_model_ckpt):
        raise FileNotFoundError(f"checkpoint file {attr_model_ckpt} for attribute prediction not found!")
    print(f"=> loading attribute checkpoint '{attr_model_ckpt}'")
    # torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(attr_model_ckpt, map_location=torch.device("cpu"))
    attr_pred_model.load_state_dict(checkpoint['state_dict'])
    attr_pred_model = torch.nn.DataParallel(attr_pred_model).cuda()
    attr_pred_model = attr_pred_model.eval()
    print(f"=> loaded attribute checkpoint '{attr_model_ckpt}' (epoch {checkpoint['epoch']})")

    fid_score = predict_fid_score(opts, eval_part1_loader, eval_part2_loader, fid_model, model)
    all_attrs_avg, each_attr_avg = predict_attr_score(opts, eval_part2_loader, attr_pred_model, model)

    eval_dict = {"FID": fid_score, "Attribute_Average": all_attrs_avg}
    for k, v in eval_dict.items():
        print(f"Eval {k}: {v}")

    for attr_name, attr_pred in zip(opts.selected_attrs, each_attr_avg):
        print(f"Eval {attr_name}: {attr_pred}")
示例#12
0
    def __init__(self):
        """Parse training options, build train/val dataloaders, model and visualizer, then train.

        ``get_dataloader`` is called with ``mode='train'`` for training and
        ``mode='val'`` for the split stored as the "test" set.
        """
        self._opt = TrainOptions().parse()
        opt = self._opt
        shared_kwargs = dict(img_size=opt.image_size,
                             selected_attrs=opt.selected_attrs,
                             batch_size=opt.batch_size)
        train_loader = get_dataloader(opt.data_dir, mode='train', **shared_kwargs)
        val_loader = get_dataloader(opt.data_dir, mode='val', **shared_kwargs)

        # The dataloaders themselves serve as the datasets (no load_data() step).
        self._dataset_train = train_loader
        self._dataset_test = val_loader

        # len() of a dataloader counts batches, not images.
        self._dataset_train_size = len(train_loader)
        self._dataset_test_size = len(val_loader)
        print('#train image batches = %d' % self._dataset_train_size)
        print('#test image batches = %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(opt.model, opt)
        self._tb_visualizer = TBVisualizer(opt)

        self._train()
示例#13
0
    def __init__(self):
        """Build model and visualizer, load the train/val datasets, then train."""
        opt = TrainOptions().parse()
        self._opt = opt

        self._model = ModelsFactory.get_by_name(opt.model, opt)
        self._tb_visualizer = TBVisualizer(opt)

        # To debug on the test split, pass mode='test' for both loaders instead.
        train_loader, val_loader = (
            CustomDatasetDataLoader(opt, mode=split) for split in ('train', 'val')
        )

        self._dataset_train = train_loader.load_data()
        self._dataset_val = val_loader.load_data()

        self._dataset_train_size = len(train_loader)
        self._dataset_val_size = len(val_loader)
        print('#train images = %d' % self._dataset_train_size)
        print('#val images = %d' % self._dataset_val_size)

        self._train()
示例#14
0
文件: test.py 项目: Lynn0306/LEDVDI
    def __init__(self):
        """Build the test loader, model and a fresh output directory, then run ``_test()``.

        Any existing path at ``opt.output_dir`` (after ``~`` expansion) is
        removed with ``shutil.rmtree`` before the directory is recreated.
        """
        self._opt = TestOpt().parse()
        self.cal_time = 0.0

        test_loader = CustomDatasetDataLoader(self._opt, is_for_train=False)
        self._dataset_test = test_loader.load_data()
        self._dataset_test_size = len(test_loader)
        print('# Test videos : %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)

        out_dir = os.path.expanduser(self._opt.output_dir)
        self.output_dir = out_dir
        # Start from an empty directory: wipe whatever is already there.
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)
        print('# Output dir: ', out_dir)

        self._test()
示例#15
0
 def __init__(self, opt):
     """Load the model in eval mode and build forward/inverse image transforms.

     ``_transform``: ToTensor -> bicubic Resize(64) -> Normalize(mean, std).
     ``_de_transform``: inverse Normalize -> ToPILImage -> bicubic Resize(85).
     """
     self._opt = opt
     self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
     self._model.set_eval()
     # The local names and the attributes share the same list objects.
     self._mean = mean = [0.5, 0.5, 0.5]
     self._std = std = [0.5, 0.5, 0.5]

     bicubic = transforms.InterpolationMode.BICUBIC
     self._transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Resize(64, bicubic),
         transforms.Normalize(mean=mean, std=std),
     ])

     # Normalize computes (x - m) / s; mean -m/s and std 1/s undo it.
     inverse_mean = [-m / s for m, s in zip(self._mean, self._std)]
     inverse_std = [1. / s for s in self._std]
     self._de_transform = transforms.Compose([
         transforms.Normalize(mean=inverse_mean, std=inverse_std),
         transforms.ToPILImage(),
         transforms.Resize(85, bicubic),
     ])
 def __init__(self):
     """Build validation dataloaders, print per-task image counts, then run ``_validate()``.

     Raises:
         ValueError: if ``opt.mode`` is not ``'Validation'``.
     """
     self._opt = TestOptions().parse()
     # NOTE(review): PRESET_VARS is not read here; kept in case PATH() has side effects.
     PRESET_VARS = PATH()
     self._model = ModelsFactory.get_by_name(self._opt.model_name, self._opt)
     val_transforms = self._model.resnet50_GRU.backbone.backbone.compose_transforms
     loader = Multitask_DatasetDataLoader(
         self._opt, train_mode=self._opt.mode, transform=val_transforms)
     self.validation_dataloaders = loader.load_multitask_val_test_data()
     print("{} sets".format(self._opt.mode))
     for task in self._opt.tasks:
         n_images = (len(self.validation_dataloaders[task]) * self._opt.batch_size *
                     len(self._opt.tasks) * self._opt.seq_len)
         print("{}: {} images".format(task, n_images))
     if self._opt.mode != 'Validation':
         raise ValueError("do not call val.py with test mode.")
     self._validate()
示例#17
0
    def __init__(self):
        """Load train/test data, build model, visualizers and empty image buffers, then train.

        Each image buffer starts as a separate empty ``(0, 3, image_size,
        image_size)`` tensor. The condition buffers start as empty lists.
        """
        self._opt = TrainOptions().parse()
        train_loader = CustomDatasetDataLoader(self._opt, is_for_train=True)
        test_loader = CustomDatasetDataLoader(self._opt, is_for_train=False)

        self._dataset_train = train_loader.load_data()
        self._dataset_test = test_loader.load_data()

        self._dataset_train_size = len(train_loader)
        self._dataset_test_size = len(test_loader)
        print('#train images = %d' % self._dataset_train_size)
        print('#test images = %d' % self._dataset_test_size)
        print('TRAIN IMAGES FOLDER = %s' % train_loader._dataset._imgs_dir)
        print('TEST IMAGES FOLDER = %s' % test_loader._dataset._imgs_dir)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)
        self._writer = SummaryWriter()

        side = self._opt.image_size

        def empty_images():
            # A new, distinct empty tensor on every call.
            return torch.empty(0, 3, side, side)

        self._input_imgs = empty_images()
        self._fake_imgs = empty_images()
        self._rec_real_imgs = empty_images()
        self._fake_imgs_unmasked = empty_images()
        self._fake_imgs_mask = empty_images()
        self._rec_real_imgs_mask = empty_images()
        self._cyc_imgs_unmasked = empty_images()
        self._real_conds = []
        self._desired_conds = []

        self._train()
示例#18
0
    def _train(self):
        """Run the epoch loop: (re)load or build the model, train one epoch, print timing, decay LR.

        NOTE(review): ``self._total_steps`` is computed from the incoming
        ``opt.load_epoch``, but ``opt.load_epoch`` is then forced to -1. The
        loop therefore always starts at epoch 0 regardless of any resume
        setting. Model saving is also commented out. This looks like a
        debugging configuration; confirm before relying on resume/checkpoints.
        """
        self._total_steps = self._opt.load_epoch * self._dataset_train_size
        # True division: this is a float.
        self._iters_per_epoch = self._dataset_train_size / self._opt.batch_size
        self._last_display_time = None
        self._last_save_latest_time = None
        self._last_print_time = time.time()

        # Forces the loop below to start at epoch 0 (see NOTE in the docstring).
        self._opt.load_epoch = -1
        for i_epoch in range(
                self._opt.load_epoch + 1,
                self._opt.nepochs_no_decay + self._opt.nepochs_decay + 1):
            # opt.load_epoch tracks the current epoch; presumably model.load()
            # reads it to pick the checkpoint. Confirm in the model class.
            self._opt.load_epoch = i_epoch
            if (i_epoch != 0):
                # NOTE(review): saving is disabled below, so it is unclear which
                # checkpoint load() finds here. Confirm.
                self._model.load()
            else:
                # Epoch 0: build a fresh model instance.
                self._model = ModelsFactory.get_by_name(
                    self._opt.model, self._opt)
            epoch_start_time = time.time()

            # train epoch
            self._train_epoch(i_epoch)

            # save model (disabled: only a message is printed)
            print('showing the model at the end of epoch %d, iters %d' %
                  (i_epoch, self._total_steps))
            #self._model.save(i_epoch)

            # print epoch info
            time_epoch = time.time() - epoch_start_time
            print(
                'End of 2 batches in epoch %d / %d \t Time Taken: %d sec (%d min or %d h)'
                %
                (i_epoch, self._opt.nepochs_no_decay + self._opt.nepochs_decay,
                 time_epoch, time_epoch / 60, time_epoch / 3600))

            # update learning rate: decay only once past the no-decay epochs
            if i_epoch > self._opt.nepochs_no_decay:
                self._model.update_learning_rate()
示例#19
0
文件: test.py 项目: wx-b/Multi-FinGAN
    def setup_multi_fingan(self):
        """Configure options for test-time grasp generation, then build the model and test data.

        Steps:

        - Select ``cuda:<gpu_ids[0]>`` when ``gpu_ids[0] != -1`` and CUDA is
          available; otherwise use the CPU.
        - Override several options for testing.
        - Build the model and the ``mode='val'`` test loader.
        - Put the model in eval mode and compute ``self.hand_in_parts`` from
          the model's ``barrett_layer``.
        """
        opt = self.opt
        if opt.gpu_ids[0] != -1 and torch.cuda.is_available():
            self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0]))
        else:
            self.device = torch.device('cpu')
        self.save_path = opt.save_folder + opt.test_set
        self.num_grasps_to_sample = opt.num_grasps_to_sample

        # Test-time overrides; these must be set before the model and loader are built.
        # Batch size is 1 because grasps are generated from one viewpoint only.
        opt.batch_size = 1
        opt.object_finger_contact_threshold = 0.004
        opt.optimize_finger_tip = True
        opt.num_viewpoints_per_object = opt.num_viewpoints_per_object_test
        opt.n_threads_train = opt.n_threads_test

        self.model = ModelsFactory.get_by_name(opt)

        test_loader = CustomDatasetDataLoader(opt, mode='val')
        self.dataset_test = test_loader.load_data()
        self.dataset_test_size = len(test_loader)
        print('#test images = %d' % self.dataset_test_size)
        self.model.set_eval()

        self.hand_in_parts = self.calculate_number_of_hand_vertices_per_part(self.model.barrett_layer)
示例#20
0
    def __init__(self):
        """Set up log paths and train/val/test loaders, build the model, then train.

        Log files live in ``<checkpoints_dir>/<name>``, which is created with
        ``util.mkdirs``.
        """
        self._opt = TrainOptions().parse()
        save_path = os.path.join(self._opt.checkpoints_dir, self._opt.name)
        self._save_path = save_path
        util.mkdirs(save_path)

        def in_save_dir(filename):
            return os.path.join(save_path, filename)

        self._log_path_train = in_save_dir('loss_log_train.txt')
        self._log_path_val = in_save_dir('loss_log_val.txt')
        self._log_path_val_metric = in_save_dir('loss_log_val_metric.txt')
        self._log_path_test = in_save_dir('loss_log_test.txt')
        self._log_path_test_metric = in_save_dir('loss_log_test_metric.txt')

        # The (is_for_train, is_for_val) flag pair selects the split.
        train_loader = CustomDatasetDataLoader(self._opt, is_for_train=True, is_for_val=False)
        val_loader = CustomDatasetDataLoader(self._opt, is_for_train=False, is_for_val=True)
        test_loader = CustomDatasetDataLoader(self._opt, is_for_train=False, is_for_val=False)

        self._dataset_train = train_loader.load_data()
        self._dataset_val = val_loader.load_data()
        self._dataset_test = test_loader.load_data()

        self._dataset_train_size = len(train_loader)
        self._dataset_val_size = len(val_loader)
        self._dataset_test_size = len(test_loader)
        print('#train sequences = %d' % self._dataset_train_size)
        print('#validate sequences = %d' % self._dataset_val_size)
        print('#test sequences = %d' % self._dataset_test_size)
        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)

        self._train()
示例#21
0
                if cross_mi_info['a_pose']:
                    used_a = True

                tgt_outdir = get_dir(os.path.join(src_outdir, src_name, cvt_name(cross_v_name))) if output_dir else ''

                is_self = v_name == cross_v_name
                imitator.imitate(tgt_paths=cross_mi_info['images'], tgt_smpls=cross_mi_info['smpls'],
                                 output_dir=tgt_outdir, visualizer=visualizer)


if __name__ == "__main__":

    # Parse test options.
    opt = TestOptions().parse()

    # set animator
    animator = ModelsFactory.get_by_name(opt.model, opt)

    # A Visdom visualizer is created only when opt.ip is set; otherwise none is used.
    if opt.ip:
        visualizer = VisdomVisualizer(env=opt.name, ip=opt.ip, port=opt.port)
    else:
        visualizer = None

    src_path = opt.src_path
    ref_path = opt.ref_path
    tgt_path = opt.tgt_path

    animator.animate_setup(src_path, ref_path)

    # When tgt_path is a directory, collect its *.jpg files (non-recursive, unsorted).
    imgs_paths = []
    if os.path.isdir(tgt_path):
        imgs_paths = glob.glob(os.path.join(tgt_path, '*.jpg'))
示例#22
0
    def __init__(self):
        """Set up per-finger limits and vertex indices, load test data and model, then display.

        Steps:

        - Store per-finger 3-value limit tensors, moved to CUDA when available.
        - Store hard-coded per-finger vertex index lists.
        - Parse ``TestOptions`` with batch size forced to 1.
        - Build the ``mode='test'`` loader, the model and a ``TBVisualizer``.
        - Call ``_display_visualizer_test(20, <len of test loader>)``.
        """

        # TO GET THEM (commented-out snippet kept for reference; not executed):
        # clusters_pose_map, clusters_rot_map, clusters_root_rot = self.get_rot_map(self._model.clusters_tensor, torch.zeros((25, 3)).cuda())
        #for i in range(25):
        #    import matplotlib.pyplot
        #    from mpl_toolkits.mplot3d import Axes3D
        #    ax = matplotlib.pyplot.figure().add_subplot(111, projection='3d')
        #    #i = 0
        #    add_group_meshs(ax, cluster_verts[i].cpu().data.numpy(), hand_faces, c='b')
        #    cam_equal_aspect_3d(ax, cluster_verts[i].cpu().data.numpy())
        #    print(i)
        #    matplotlib.pyplot.pause(1)
        #    matplotlib.pyplot.close()

        # FINGER LIMIT ANGLE:
        # NOTE(review): the trailing "# a:b" comments presumably give the slice of
        # the pose vector each limit applies to. Confirm against the consumer.
        #self.limit_bigfinger = torch.FloatTensor([1.0222, 0.0996, 0.7302]) # 36:39
        #self.limit_bigfinger = torch.FloatTensor([1.2030, 0.12, 0.25]) # 36:39
        #self.limit_bigfinger = torch.FloatTensor([1.2, -0.4, 0.25]) # 36:39
        self.limit_bigfinger = torch.FloatTensor([1.2, -0.6, 0.25])  # 36:39
        self.limit_index = torch.FloatTensor([-0.0827, -0.4389, 1.5193])  # 0:3
        self.limit_middlefinger = torch.FloatTensor(
            [-2.9802e-08, -7.4506e-09, 1.4932e+00])  # 9:12
        self.limit_fourth = torch.FloatTensor([0.1505, 0.3769,
                                               1.5090])  # 27:30
        self.limit_small = torch.FloatTensor([-0.6235, 0.0275,
                                              1.0519])  # 18:21
        # Move the limit tensors to the default CUDA device when one is available.
        if torch.cuda.is_available():
            self.limit_bigfinger = self.limit_bigfinger.cuda()
            self.limit_index = self.limit_index.cuda()
            self.limit_middlefinger = self.limit_middlefinger.cuda()
            self.limit_fourth = self.limit_fourth.cuda()
            self.limit_small = self.limit_small.cuda()

        # Hard-coded vertex indices per finger (this one is the contiguous run 697..768).
        self._bigfinger_vertices = [
            697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709,
            710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722,
            723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735,
            736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748,
            749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761,
            762, 763, 764, 765, 766, 767, 768
        ]

        self._indexfinger_vertices = [
            46, 47, 48, 49, 56, 57, 58, 59, 86, 87, 133, 134, 155, 156, 164,
            165, 166, 167, 174, 175, 189, 194, 195, 212, 213, 221, 222, 223,
            224, 225, 226, 237, 238, 272, 273, 280, 281, 282, 283, 294, 295,
            296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308,
            309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321,
            322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334,
            335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347,
            348, 349, 350, 351, 352, 353, 354, 355
        ]

        self._middlefinger_vertices = [
            356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 372,
            373, 374, 375, 376, 377, 381, 382, 385, 386, 387, 388, 389, 390,
            391, 392, 393, 394, 395, 396, 397, 398, 400, 401, 402, 403, 404,
            405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417,
            418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430,
            431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443,
            444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456,
            457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467
        ]

        self._fourthfinger_vertices = [
            468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 482,
            483, 484, 485, 486, 487, 491, 492, 495, 496, 497, 498, 499, 500,
            501, 502, 503, 504, 505, 506, 507, 508, 511, 512, 513, 514, 515,
            516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528,
            529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541,
            542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554,
            555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567,
            568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578
        ]

        self._smallfinger_vertices = [
            580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 598,
            599, 600, 601, 602, 603, 609, 610, 613, 614, 615, 616, 617, 618,
            619, 620, 621, 622, 623, 624, 625, 626, 628, 629, 630, 631, 632,
            633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645,
            646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658,
            659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671,
            672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684,
            685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695
        ]

        self._opt = TestOptions().parse()
        #assert self._opt.load_epoch > 0, 'Use command --load_epoch to indicate the epoch you want to load - and choose a trained model'

        # Force batch size 1: only one image is processed at a time so far.
        self._opt.batch_size = 1

        # Use the test thread count for the loader (presumably the loader reads
        # n_threads_train; confirm).
        self._opt.n_threads_train = self._opt.n_threads_test
        data_loader_test = CustomDatasetDataLoader(self._opt, mode='test')
        self._dataset_test = data_loader_test.load_data()
        self._dataset_test_size = len(data_loader_test)
        print('#test images = %d' % self._dataset_test_size)

        self._model = ModelsFactory.get_by_name(self._opt.model, self._opt)
        self._tb_visualizer = TBVisualizer(self._opt)

        self._total_steps = self._dataset_test_size
        # NOTE(review): the meaning of the first argument (20) is defined by
        # _display_visualizer_test, which is not visible here. Confirm.
        self._display_visualizer_test(20, self._total_steps)
示例#23
0
 def _train(self):
     """Train a teacher model, or load one and then train ``n_students`` student models.

     Without a pretrained teacher (``opt.pretrained_teacher_model`` empty):

     - Train the teacher for epochs ``load_epoch + 1 .. teacher_nepochs``.
     - Save it as ``'teacher'`` whenever validation accuracy improves.
     - Return without training any students.

     With a pretrained teacher:

     - Load its weights into ``model.resnet50``.
     - Keep a deep copy (eval mode) as ``self._teacher_model``.
     - Train each student from a freshly built model with ``_train_epoch_kd``
       for ``student_nepochs`` epochs, saving it as ``'student_<i>'`` on
       validation-accuracy improvement.
     """
     # NOTE(review): _total_steps is not updated in this method; presumably the
     # _train_epoch* helpers advance it. Confirm.
     self._total_steps = self._opt.load_epoch * len(
         self.training_dataloaders) * self._opt.batch_size
     self._last_display_time = None
     self._last_save_latest_time = None
     self._last_print_time = time.time()
     self._current_val_acc = 0.
     if len(self._opt.pretrained_teacher_model) == 0:
         # No pretrained teacher: train it, keep the best checkpoint by
         # validation accuracy, then return (students are not trained here).
         for i_epoch in range(self._opt.load_epoch + 1,
                              self._opt.teacher_nepochs + 1):
             epoch_start_time = time.time()
             self._model.get_current_LR()
             # train epoch
             self._train_epoch(i_epoch)
             self.training_dataloaders.reset()
             val_acc = self._validate(i_epoch)
             if val_acc > self._current_val_acc:
                 print("validation acc improved, from {:.4f} to {:.4f}".
                       format(self._current_val_acc, val_acc))
                 print('saving the model at the end of epoch %d, steps %d' %
                       (i_epoch, self._total_steps))
                 self._model.save('teacher')
                 self._current_val_acc = val_acc
             self.save_visual_dict('teacher')
             self.save_logging_image('teacher')
             # print epoch info
             time_epoch = time.time() - epoch_start_time
             print(
                 'End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)'
                 % (i_epoch, self._opt.teacher_nepochs, time_epoch,
                    time_epoch / 60, time_epoch / 3600))
         return
     else:
         # Load pretrained teacher weights into the model's resnet50.
         # torch.load unpickles the file; only load trusted checkpoints.
         self._model.resnet50.load_state_dict(
             torch.load(self._opt.pretrained_teacher_model))
     # record the teacher_model
     self._teacher_model = deepcopy(self._model)
     del self._model
     self._model = None
     self._teacher_model.set_eval()
     for i_student in range(self._opt.n_students):
         # Each student gets a fresh model, best-accuracy tracker and visual_dict.
         self._current_val_acc = 0.
         self._model = ModelsFactory.get_by_name(self._opt.model_name,
                                                 self._opt)  # re-initialize
         self.visual_dict = {
             'training': pd.DataFrame(),
             'validation': pd.DataFrame()
         }
         for i_epoch in range(1, self._opt.student_nepochs + 1):
             epoch_start_time = time.time()
             self._model.get_current_LR()
             self._train_epoch_kd(i_epoch)
             self.training_dataloaders.reset()
             val_acc = self._validate(i_epoch)
             if val_acc > self._current_val_acc:
                 print("validation acc improved, from {:.4f} to {:.4f}".
                       format(self._current_val_acc, val_acc))
                 print('saving the model at the end of epoch %d, steps %d' %
                       (i_epoch, self._total_steps))
                 self._model.save('student_{}'.format(i_student))
                 self._current_val_acc = val_acc
             self.save_visual_dict('student_{}'.format(i_student))
             self.save_logging_image('student_{}'.format(i_student))
             # print epoch info
             time_epoch = time.time() - epoch_start_time
             print(
                 'End of epoch %d / %d \t Time Taken: %d sec (%d min or %d h)'
                 % (i_epoch, self._opt.student_nepochs, time_epoch,
                    time_epoch / 60, time_epoch / 3600))