Example #1
    def load_cache(self):
        """
        load the cache file if it exists, otherwise build the dataset and write the cache file
        """
        cache = os.path.join(
            self.dataroot,
            'cache_%s_%d.db' % (self.mode, self.video_len * self.every_nth))
        if cache is not None and os.path.exists(cache):
            # load the cache
            with open(cache, 'rb') as f:
                self.lines, self.lengths, self.actor_set, self.action_set = pickle.load(f)
        else:
            # build the data list from scratch
            self.lines, self.lengths, self.actor_set, self.action_set = self.build_dataset()
            makedir(self.textroot)
            text_file = os.path.join(
                self.textroot,
                '%s_list_%d.txt' % (self.mode, self.video_len * self.every_nth))
            with open(text_file, 'w+') as f:
                f.writelines(self.lines)

            if cache is not None:
                with open(cache, 'wb') as f:
                    pickle.dump((self.lines, self.lengths, self.actor_set,
                                 self.action_set), f)

        self.cumsum = np.cumsum([0] + self.lengths)
        print("Total number of frames {}".format(np.sum(self.lengths)))
Example #2
    def save_batch(self, current_iter, names=None, start_idx=0):
        """
        save the batch for the generation conditioned on the first frame
        :param current_iter: int, the current iteration
        :param names: the name of videos where the first frame is from
        :param start_idx: int, the start index of the current batch
        :return: output_dir: the path of the output folder
        """
        output_dir = os.path.join(self.output_dir, 'evaluation',
                                  str(current_iter))
        makedir(output_dir)

        video = self.video_x_p.cpu().data.numpy().transpose((0, 2, 1, 3, 4))
        self.save_video(video,
                        output_dir,
                        self.categories.cpu().data.numpy(),
                        self.actors.cpu().data.numpy(),
                        names=names,
                        start_idx=start_idx)
        return output_dir

    def full_test(self, cls_id, batch_size, video_len, current_iter, var_name, start_idx=0, is_eval=False, rm_npy=False,
                  get_seq=False, get_mask=False):
        """
        :param cls_id:  int, the action index at test
        :param batch_size: int
        :param video_len: int, the desired length of the video
        :param current_iter: int, the current iteration so far
        :param var_name: str, the variable name for saving or tensorboard visualizing
        :param start_idx: int, the start index of the current batch
        :param is_eval: bool, specify when evaluating
        :param rm_npy: bool, specify to remove all npy files in the output folder
        :param get_seq: bool, specify to save the video sequence
        :param get_mask: bool, specify to visualize the mask
        :return: output_dir: str, the output path
        """
        # create the category matrix for the test class
        cat = cls_id * np.ones((batch_size,)).astype('int')
        if torch.cuda.is_available():
            self.categories = Variable(torch.from_numpy(cat)).cuda()
        else:
            self.categories = Variable(torch.from_numpy(cat))

        # generate the video with size [batch_size, video_len, c, h, w]
        torch.set_grad_enabled(False)
        video, masks = self.networks['generator'].full_test(self.categories, video_len+2)
        torch.set_grad_enabled(True)
        # the first two frames only warm up the generator, so drop them
        video = video[:, 2:]

        # create the output directory
        if is_eval:
            output_dir = os.path.join(self.output_dir, 'evaluation', str(current_iter))
        else:
            output_dir = os.path.join(self.output_dir, 'validation', str(current_iter))
        makedir(output_dir)

        # remove the existing npy files
        if rm_npy:
            os.system('rm %s' % os.path.join(output_dir, '*.npy'))

        # save original output to npy file
        # video_np [batch_size, video_len, c,  h, w]
        video_np = video.cpu().data.numpy().clip(-1, 1)
        self.save_video(video_np, output_dir, self.categories, start_idx=start_idx)

        # saving to tensorboard during the validation
        if not is_eval:
            # save to tensorboard
            # [batch_size, video_len, c, h, w]
            video = torch.clamp((video.permute(0, 2, 1, 3, 4) + 1)/2, 0, 1)
            self.writer.add_video(var_name, video, current_iter)

        # save the video sequences to the output folder
        if get_seq:
            video_seqs = ((video_np.transpose(0, 1, 3, 4, 2) + 1)/2 * 255).astype('uint8')
            video_seqs = np.concatenate(np.split(video_seqs, video_len, axis=1), axis=3).squeeze()

            img_dir = os.path.join(output_dir, 'imgs')
            makedir(img_dir)
            for v_idx, seq in enumerate(video_seqs):
                filename = os.path.join(img_dir, '%s_%03d.png' % (var_name, start_idx + v_idx))
                cv2.imwrite(filename, seq[:, :, ::-1])

        # save masks to the output folder
        if get_mask:
            mask_8 = []
            mask_16 = []
            mask_32 = []
            mask_64 = []
            for frame_mask in masks:
                if self.layers >= 4:
                    mask_8.append(frame_mask[0].cpu().numpy().squeeze().clip(0, 1))
                if self.layers >= 3:
                    mask_16.append(frame_mask[1].cpu().numpy().squeeze().clip(0, 1))
                if self.layers >= 2:
                    mask_32.append(frame_mask[2].cpu().numpy().squeeze().clip(0, 1))
                if self.layers >= 1:
                    mask_64.append(frame_mask[3].cpu().numpy().squeeze().clip(0, 1))
            if self.layers >= 4:
                mask_8 = np.concatenate(mask_8[2:], axis=2)
            if self.layers >= 3:
                mask_16 = np.concatenate(mask_16[2:], axis=2)
            if self.layers >= 2:
                mask_32 = np.concatenate(mask_32[2:], axis=2)
            if self.layers >= 1:
                mask_64 = np.concatenate(mask_64[2:], axis=2)
            mask_dir = os.path.join(output_dir, 'masks')
            makedir(mask_dir)
            for v_idx in range(batch_size):
                if self.layers >= 4:
                    filename = os.path.join(mask_dir, '%s_%03d_mask_8.png' % (var_name, start_idx + v_idx))
                    cv2.imwrite(filename, (mask_8[v_idx] * 255).astype('uint8'))
                if self.layers >= 3:
                    filename = os.path.join(mask_dir, '%s_%03d_mask_16.png' % (var_name, start_idx + v_idx))
                    cv2.imwrite(filename, (mask_16[v_idx] * 255).astype('uint8'))
                if self.layers >= 2:
                    filename = os.path.join(mask_dir, '%s_%03d_mask_32.png' % (var_name, start_idx + v_idx))
                    cv2.imwrite(filename, (mask_32[v_idx] * 255).astype('uint8'))
                if self.layers >= 1:
                    filename = os.path.join(mask_dir, '%s_%03d_mask_64.png' % (var_name, start_idx + v_idx))
                    cv2.imwrite(filename, (mask_64[v_idx] * 255).astype('uint8'))

        return output_dir
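
The per-resolution branches in the get_mask block above are repetitive; assuming each frame_mask is ordered coarse-to-fine as [8, 16, 32, 64] and gated by self.layers exactly as shown, an equivalent loop-based sketch (a hypothetical helper, not part of the original class) could look like this:

import os

import cv2
import numpy as np


def save_masks(masks, layers, batch_size, mask_dir, var_name, start_idx=0):
    # collect each resolution that the model actually produces
    resolutions = [8, 16, 32, 64]
    collected = {r: [] for r in resolutions}
    for frame_mask in masks:
        for slot, res in enumerate(resolutions):
            if layers >= len(resolutions) - slot:
                collected[res].append(
                    frame_mask[slot].cpu().numpy().squeeze().clip(0, 1))
    # drop the two warm-up frames and tile the remaining frames side by side
    stacked = {res: np.concatenate(m[2:], axis=2)
               for res, m in collected.items() if m}
    # write one image per video and per resolution
    for v_idx in range(batch_size):
        for res, m in stacked.items():
            filename = os.path.join(
                mask_dir, '%s_%03d_mask_%d.png' % (var_name, start_idx + v_idx, res))
            cv2.imwrite(filename, (m[v_idx] * 255).astype('uint8'))
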
Example #4
def main():
    # ********************************************************************
    # ****************** create folders and print options ****************
    # ********************************************************************
    opt, gen_args = TestOptions().parse()
    makedir(opt.output_dir)
    makedir(opt.log_dir)
    listopt(opt)
    with open(os.path.join(opt.log_dir, 'test_opt.txt'), 'w+') as f:
        listopt(opt, f)

    # ********************************************************************
    # ******************** Prepare the dataloaders ***********************
    # ********************************************************************
    image_transforms = transforms.Compose([
        transforms.ToTensor(),
        lambda x: x[:opt.n_channels, ::],
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    video_transforms = functools.partial(video_transform,
                                         image_transform=image_transforms)

    if opt.dataset == 'Weizmann':
        valset = WeizmannDataset(opt.dataroot,
                                 opt.textroot,
                                 opt.video_length,
                                 opt.image_size,
                                 opt.every_nth,
                                 False,
                                 'Test',
                                 mini_clip=opt.miniclip)
    elif opt.dataset == 'MUG':
        valset = MUGDataset(opt.dataroot, opt.textroot, opt.video_length,
                            opt.image_size, opt.every_nth, False, 'Test')
    elif opt.dataset == 'SynAction':
        valset = SynActionDataset(opt.dataroot, opt.textroot, opt.video_length,
                                  opt.image_size, False, 'Test')
    else:
        raise NotImplementedError('%s dataset is not supported' % opt.dataset)

    # get the validation dataloader
    video_valset = VideoDataset(valset,
                                opt.video_length,
                                every_nth=opt.every_nth,
                                transform=video_transforms)
    video_val_loader = DataLoader(video_valset,
                                  batch_size=opt.batch_size,
                                  drop_last=False,
                                  num_workers=2,
                                  shuffle=False)

    # ********************************************************************
    # ********************Create the environment *************************
    # ********************************************************************
    gen_args['num_categories'] = len(valset.action_set)
    if opt.model == 'SGVAN':
        environ = SGVAN(gen_args,
                        opt.checkpoint_dir,
                        opt.log_dir,
                        opt.output_dir,
                        opt.video_length,
                        valset.action_set,
                        valset.actor_set,
                        is_eval=True)
    elif opt.model == 'TwoStreamVAN':
        environ = TwoStreamVAN(gen_args,
                               opt.checkpoint_dir,
                               opt.log_dir,
                               opt.output_dir,
                               opt.video_length,
                               valset.action_set,
                               valset.actor_set,
                               is_eval=True)
    else:
        raise ValueError('Model %s is not implemented' % opt.model)

    current_iter = environ.load(opt.which_iter, is_eval=True)
    environ.eval()

    # ********************************************************************
    # ***************************  Full test  ****************************
    # ********************************************************************
    rm_npy = True
    result_file = os.path.join(opt.log_dir, 'results.txt')
    for idx, cls_name in enumerate(valset.action_set):
        for c in range(10):
            prefix = 'none_%s' % cls_name
            if opt.model == 'SGVAN':
                output_dir = environ.full_test(idx,
                                               90,
                                               opt.video_length,
                                               current_iter,
                                               var_name=prefix,
                                               start_idx=c * 90,
                                               is_eval=True,
                                               rm_npy=rm_npy)
            elif opt.model == 'TwoStreamVAN':
                output_dir = environ.full_test(idx,
                                               90,
                                               opt.video_length,
                                               current_iter,
                                               start_idx=c * 90,
                                               var_name=prefix,
                                               is_eval=True,
                                               rm_npy=rm_npy,
                                               get_seq=opt.get_seq,
                                               get_mask=opt.get_mask)
            else:
                raise ValueError('Model %s is not implemented' % opt.model)
            rm_npy = False
    full_metrics = eval(opt, output_dir)
    with open(result_file, 'a') as f:
        f.writelines(print_dict(full_metrics, 'full_metric') + '\n')

    # ********************************************************************
    # ************************  Conditional Test *************************
    # ********************************************************************
    # provide the first frame
    for c in tqdm.tqdm(range(opt.val_num)):
        for idx, batch in enumerate(video_val_loader):
            environ.set_inputs(batch)
            environ.video_forward(eplison=0, ae_mode='mean', is_eval=True)
            names = batch['names']
            output_dir = environ.save_batch(current_iter,
                                            names=names,
                                            start_idx=c)
    metrics = eval(opt, output_dir, data_is_actor=True)

    print(full_metrics)
    print(metrics)
    with open(result_file, 'a') as f:
        f.writelines(print_dict(metrics, 'avg_metric') + '\n')
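
print_dict is not defined in this snippet; a plausible minimal sketch, based only on how it is called above (turning a metrics dict into a single labelled line), might be:

def print_dict(metrics, name):
    # hypothetical helper: e.g. 'avg_metric: acc 91.200, I_score 7.530, ...'
    items = ', '.join('%s %.3f' % (k, v) for k, v in metrics.items())
    return '%s: %s' % (name, items)
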
Example #5
    def __init__(self,
                 dataroot,
                 textroot,
                 video_len,
                 image_size,
                 every_nth,
                 crop,
                 mode="Train",
                 mini_clip=False):
        """
        Prepare for the data list
        :param dataroot: str, the path for the stored data
        :param textroot: str, the path to write the data list
        :param video_len: int, the length of the generated video
        :param image_size: int, the spatial size of the image
        :param every_nth: int, the frequency to sample frames
        :param crop: bool, true if random cropping
        :param mode: ['Train', 'Test']
        :param mini_clip: bool, specify if the original video is divided into mini-clips
        """
        print(self.name())

        # parse the args
        self.dataroot = dataroot
        self.textroot = textroot
        self.video_len = video_len
        self.every_nth = every_nth
        self.crop = crop
        self.mode = mode
        self.action_set = [
            'bend', 'jack', 'pjump', 'wave1', 'wave2', 'jump', 'run', 'side',
            'skip', 'walk'
        ]
        self.actor_set = [
            'daria', 'denis', 'eli', 'ido', 'ira', 'lena', 'lyova', 'moshe',
            'shahar'
        ]
        self.image_size = image_size
        self.mini_clip = mini_clip

        # get the cache name
        if mini_clip:
            cache = os.path.join(
                self.dataroot,
                'cache_mini_%s_%d.db' % (mode, video_len * every_nth))
        else:
            cache = os.path.join(
                self.dataroot,
                'cache_%s_%d.db' % (mode, video_len * every_nth))

        # read the cache file or build the cache file
        if cache is not None and os.path.exists(cache):
            # read the cache file
            with open(cache, 'rb') as f:
                self.lines, self.lengths = pickle.load(f)
        else:
            # build the list
            self.lines, self.lengths = self.build_dataset()

            # write the readable text file
            makedir(textroot)
            if mini_clip:
                text_file = os.path.join(
                    textroot,
                    'miniclip_%s_list_%d.txt' % (mode, video_len * every_nth))
            else:
                text_file = os.path.join(
                    textroot, '%s_list_%d.txt' % (mode, video_len * every_nth))
            with open(text_file, 'w+') as f:
                f.writelines(self.lines)

            # write the cache file
            if cache is not None:
                with open(cache, 'wb') as f:
                    pickle.dump((self.lines, self.lengths), f)

        # get the total frames
        self.cumsum = np.cumsum([0] + self.lengths)
        print("Total number of frames {}".format(np.sum(self.lengths)))
Example #6
def eval(eval_args=None, output_dir=None, data_is_actor=False):

    # ********************************************************************
    # ****************** create folders and print options ****************
    # ********************************************************************
    if eval_args is None:
        opt = TestOptions().parse()
        # make up all dirs and print options
        listopt(opt)
        makedir(opt.log_dir)
        with open(os.path.join(opt.log_dir, 'test_options.txt'), 'w+') as f:
            listopt(opt, f)
    else:
        opt = make_opt(eval_args, output_dir, data_is_actor)

    # ********************************************************************
    # ******************** Prepare the dataloaders ***********************
    # ********************************************************************
    print('define image and video transformation')
    # get the dataloader
    image_transforms = transforms.Compose([
        transforms.ToTensor(),
        lambda x: x[:opt.n_channels, ::],
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    video_transforms = functools.partial(video_transform,
                                         image_transform=image_transforms)

    if opt.dataset == 'Generated_video':
        if opt.test_dataset == 'Weizmann':
            test_dataset = WeizmannDataset(opt.test_dataroot, opt.textroot,
                                           opt.video_length, opt.image_size,
                                           opt.every_nth, False, 'Test')
        elif opt.test_dataset == 'MUG':
            test_dataset = MUGDataset(opt.test_dataroot, opt.textroot,
                                      opt.video_length, opt.image_size,
                                      opt.every_nth, False, 'Test')
        elif opt.test_dataset == 'SynAction':
            test_dataset = SynActionDataset(opt.test_dataroot, opt.textroot,
                                            opt.video_length, opt.image_size,
                                            False, 'Test')
        else:
            raise NotImplementedError('%s is not implemented' % opt.test_dataset)
        action_set = test_dataset.action_set
        actor_set = test_dataset.actor_set
        valset = GeneratedDataset(opt.dataroot, opt.data_is_actor, actor_set,
                                  action_set)
        video_valset = valset
    else:
        if opt.dataset == 'Weizmann':
            print('create the video dataloader')
            # dataset, val_dataset = None, None
            valset = WeizmannDataset(opt.dataroot, opt.textroot,
                                     opt.video_length, opt.image_size,
                                     opt.every_nth, False, 'Test')
        elif opt.dataset == 'MUG':
            valset = MUGDataset(opt.dataroot, opt.textroot, opt.video_length,
                                opt.image_size, opt.every_nth, False, 'Test')
        elif opt.dataset == 'MUG2':
            trainset = MUGDataset_2(opt.dataroot, opt.textroot,
                                    opt.video_length, opt.image_size,
                                    opt.every_nth, opt.crop, 'Train')
            print(trainset.action_set)
            valset = MUGDataset_2(opt.dataroot, opt.textroot, opt.video_length,
                                  opt.image_size, opt.every_nth, False, 'Test')
            print(valset.action_set)
        elif opt.dataset == 'SynAction':
            valset = SynActionDataset(opt.dataroot, opt.textroot,
                                      opt.video_length, opt.image_size, False,
                                      'Test')
        else:
            raise NotImplementedError('%s is not implemented' % opt.dataset)
        video_valset = VideoDataset(valset,
                                    opt.video_length,
                                    every_nth=opt.every_nth,
                                    transform=video_transforms)
        action_set = valset.action_set
        actor_set = valset.actor_set
    video_val_loader = DataLoader(video_valset,
                                  batch_size=opt.batch_size,
                                  drop_last=False,
                                  num_workers=2,
                                  shuffle=False)

    # ********************************************************************
    # ******************** Create the Environment ************************
    # ********************************************************************
    # calculate the number of classes
    if opt.model_is_actor:
        num_class = len(action_set) * len(actor_set)
    else:
        num_class = len(action_set)

    print('create the environment')
    # build the training environment
    environ = Classifier_Environ(opt.n_channels, num_class, opt.log_dir,
                                 opt.checkpoint_dir, 1, opt.ndf)
    iter, _ = environ.load('latest', path=opt.ckpt_path)
    print('using iter %d' % iter)

    environ.eval()
    loss, gt_cat, pred_dist = [], [], []
    for idx, batch in enumerate(video_val_loader):
        if opt.data_is_actor and opt.model_is_actor:
            batch['categories'] = batch['categories'] + batch['actors'] * len(
                action_set)
        environ.set_inputs(batch)
        tmp_loss, tmp_gt, tmp_pred = environ.val()
        loss.append(tmp_loss)
        gt_cat.append(tmp_gt)
        pred_dist.append(tmp_pred)

    # get accuracy,intra_E, inter_E, class_intra_E
    pred_dist = np.concatenate(pred_dist, axis=0)
    pred_cat = np.argmax(pred_dist, axis=-1)
    gt_cat = np.concatenate(gt_cat)
    if opt.model_is_actor:
        pred_cat = pred_cat % len(action_set)
    if opt.data_is_actor:
        gt_cat = gt_cat % len(action_set)
    I_score, intra_E, inter_E, class_intra_E = quant(pred_dist, action_set)
    acc = float(len(gt_cat) -
                np.count_nonzero(pred_cat != gt_cat)) / len(gt_cat) * 100

    print('acc: %.3f%%, I_score: %.3f, intra_E: %.3f, inter_E: %.3f' %
          (acc, I_score, intra_E, inter_E))
    print('class intra-E', class_intra_E)
    print(action_set)

    metrics = {
        'acc': acc,
        'I_score': I_score,
        'intra_E': intra_E,
        'inter_E': inter_E
    }
    return metrics
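
quant is not shown here; judging only from the names it returns, it computes Inception-score-style statistics over the predicted class distributions. A hedged sketch of that computation (not necessarily the repository's exact implementation):

import numpy as np


def quant_sketch(pred_dist, action_set, eps=1e-12):
    # pred_dist: [N, num_classes] softmax outputs of the classifier
    p = pred_dist.clip(eps, 1.0)
    p = p / p.sum(axis=1, keepdims=True)
    marginal = p.mean(axis=0)

    # I_score: exp of the mean KL between per-sample and marginal distributions
    kl = (p * (np.log(p) - np.log(marginal))).sum(axis=1)
    I_score = np.exp(kl.mean())

    # intra_E: mean entropy of per-sample distributions; inter_E: entropy of the marginal
    sample_entropy = -(p * np.log(p)).sum(axis=1)
    intra_E = sample_entropy.mean()
    inter_E = -(marginal * np.log(marginal)).sum()

    # class_intra_E: mean per-sample entropy grouped by predicted class
    pred_cls = p.argmax(axis=1)
    class_intra_E = [sample_entropy[pred_cls == c].mean() if (pred_cls == c).any() else 0.0
                     for c in range(len(action_set))]
    return I_score, intra_E, inter_E, class_intra_E
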
Example #7
def main():
    torch.set_num_threads(1)
    # ********************************************************************
    # ****************** create folders and print options ****************
    # ********************************************************************
    opt, gen_args, dis_args, loss_weights = TrainOptions().parse()
    makedir(opt.checkpoint_dir)
    makedir(opt.log_dir)
    makedir(opt.output_dir)
    listopt(opt)
    with open(os.path.join(opt.log_dir, 'train_opt.txt'), 'w+') as f:
        listopt(opt, f)

    # ********************************************************************
    # ******************** Prepare the dataloaders ***********************
    # ********************************************************************
    image_transforms = transforms.Compose([
        transforms.ToTensor(),
        lambda x: x[:opt.n_channels, ::],
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    video_transforms = functools.partial(video_transform, image_transform=image_transforms)

    if opt.dataset == 'Weizmann':
        trainset = WeizmannDataset(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, opt.every_nth, opt.crop,
                                   'Train', mini_clip=opt.miniclip)
        valset = WeizmannDataset(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, opt.every_nth, False,
                                 'Test', mini_clip=opt.miniclip)
    elif opt.dataset == 'MUG':
        trainset = MUGDataset(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, opt.every_nth, opt.crop,
                              'Train')
        valset = MUGDataset(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, opt.every_nth, False, 'Test')
    elif opt.dataset == 'SynAction':
        trainset = SynActionDataset(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, opt.crop, 'Train')
        valset = SynActionDataset(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, False, 'Test')
    else:
        raise NotImplementedError('%s dataset is not supported' % opt.dataset)

    # get the train and validation dataloaders
    video_trainset = VideoDataset(trainset, opt.video_length, every_nth=opt.every_nth, transform=video_transforms)
    video_train_loader = DataLoader(video_trainset, batch_size=opt.batch_size, drop_last=True, num_workers=2, shuffle=True)

    video_valset = VideoDataset(valset, opt.video_length, every_nth=opt.every_nth, transform=video_transforms)
    video_val_loader = DataLoader(video_valset, batch_size=opt.batch_size, drop_last=True, num_workers=2, shuffle=False)

    # ********************************************************************
    # ********************Create the environment *************************
    # ********************************************************************
    gen_args['num_categories'] = len(trainset.action_set)
    dis_args['num_categories'] = len(trainset.action_set)

    if opt.model == 'SGVAN':
        environ = SGVAN(gen_args, opt.checkpoint_dir, opt.log_dir, opt.output_dir, opt.video_length,
                        trainset.action_set, trainset.actor_set, is_eval=False, dis_args=dis_args,
                        loss_weights=loss_weights, pretrain_iters=opt.pretrain_iters)
    elif opt.model == 'TwoStreamVAN':
        environ = TwoStreamVAN(gen_args, opt.checkpoint_dir, opt.log_dir, opt.output_dir, opt.video_length,
                               trainset.action_set, trainset.actor_set, is_eval=False, dis_args=dis_args,
                               loss_weights=loss_weights, pretrain_iters=opt.pretrain_iters)
    else:
        raise ValueError('Model %s is not implemented' % opt.model)

    current_iter = 0
    if opt.resume:
        current_iter = environ.load(opt.which_iter)
    else:
        environ.weight_init()
    environ.train()

    # ********************************************************************
    # ******************** Set the training ratio ************************
    # ********************************************************************
    # content vs motion
    cont_scheduler = Scheduler(opt.cont_ratio_start, opt.cont_ratio_end,
                               opt.cont_ratio_iter_start + opt.pretrain_iters, opt.cont_ratio_iter_end + opt.pretrain_iters,
                               mode='linear')
    # easier vs harder motion
    m_img_scheduler = Scheduler(opt.motion_ratio_start, opt.motion_ratio_end,
                                opt.motion_ratio_iter_start + opt.pretrain_iters, opt.motion_ratio_iter_end + opt.pretrain_iters,
                                mode='linear')

    # ********************************************************************
    # ***************************  Training  *****************************
    # ********************************************************************
    recons_c, pred_c, vid_c = 0, 0, 0
    video_enumerator = enumerate(video_train_loader)
    while current_iter < opt.total_iters:
        start_time = time.time()
        current_iter += 1
        batch_idx, batch = next(video_enumerator)
        environ.set_inputs(batch)

        if current_iter < opt.pretrain_iters:
            # ********************** Pre-train the Content Stream **************
            environ.optimize_recons_pretrain_parameters(current_iter)

            # print loss to the screen and save intermediate results to tensorboard
            if current_iter % opt.print_freq == 0:
                environ.print_loss(current_iter, start_time)
                environ.visual_batch(current_iter, name='%s_current_batch' % environ.task)

            # validation
            if current_iter % opt.val_freq == 0:
                environ.eval()
                # validation of the content generation
                for idx, batch in enumerate(video_val_loader):
                    environ.set_inputs(batch)
                    environ.reconstruct_forward(ae_mode='mean', is_eval=True)
                    if idx == 0:
                        environ.visual_batch(current_iter, name='val_recons')
                # save the current checkpoint
                environ.save('latest', current_iter)
                environ.train()
        else:
            # ********************* Jointly train the Content & Motion *************
            ep1 = cont_scheduler.get_value(current_iter)
            ep2 = m_img_scheduler.get_value(current_iter)
            recons = (random.random() > ep1)
            img_level = (random.random() > ep2)
            if recons:
                # content training
                recons_c += 1
                environ.optimize_recons_parameters(current_iter)
            else:
                if img_level:
                    # easier motion training
                    pred_c += 1
                    environ.optimize_pred_parameters()
                else:
                    # harder motion training
                    vid_c += 1
                    environ.optimize_vid_parameters(current_iter)

            # print loss to the screen and save intermediate results to tensorboard
            if current_iter % opt.print_freq == 0:
                environ.print_loss(current_iter, start_time)
                environ.visual_batch(current_iter, name='%s_current_batch' % environ.task)
                print('recons: %d, pred: %d, vid: %d' % (recons_c, pred_c, vid_c))
                recons_c, pred_c, vid_c = 0, 0, 0

            # validation and save checkpoint
            if current_iter % opt.val_freq == 0:
                environ.eval()
                for idx, batch in enumerate(video_val_loader):
                    environ.set_inputs(batch)

                    # content stream validation
                    environ.reconstruct_forward(ae_mode='mean', is_eval=True)
                    if idx == 0:
                        environ.visual_batch(current_iter, name='val_recons')

                    # easier motion stream validation
                    environ.predict_forward(ae_mode='mean', is_eval=True)
                    if idx == 0:
                        environ.visual_batch(current_iter, name='val_pred')

                    # harder motion stream validation
                    environ.video_forward(eplison=0, ae_mode='mean', is_eval=True)
                    if idx == 0:
                        environ.visual_batch(current_iter, name='val_video')

                # generate videos for different class
                for idx, cls_name in enumerate(valset.action_set):
                    environ.get_category(cls_id=idx)
                    output_dir = environ.full_test(idx, 32, 10, current_iter, cls_name)
                metrics = eval(opt, output_dir)
                environ.print_loss(current_iter, start_time, metrics)

                # remove the generated video
                rm_cmd = 'rm -r %s' % output_dir
                os.system(rm_cmd)

                # save the latest checkpoint
                environ.save('latest', current_iter)
                environ.train()

        # save the checkpoint
        if current_iter % opt.save_freq == 0:
            environ.save(current_iter, current_iter)

        # get a new enumerator
        if batch_idx == len(video_train_loader) - 1:
            video_enumerator = enumerate(video_train_loader)
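
Scheduler is also not shown in this example; a minimal linear-ramp sketch consistent with how it is constructed and queried above (interpolating between two values over an iteration window) could be:

class Scheduler(object):
    # hypothetical sketch: linearly ramp from start_value to end_value
    # between start_iter and end_iter, clamping outside that window
    def __init__(self, start_value, end_value, start_iter, end_iter, mode='linear'):
        assert mode == 'linear'
        self.start_value, self.end_value = start_value, end_value
        self.start_iter, self.end_iter = start_iter, end_iter

    def get_value(self, it):
        if it <= self.start_iter:
            return self.start_value
        if it >= self.end_iter:
            return self.end_value
        frac = (it - self.start_iter) / float(self.end_iter - self.start_iter)
        return self.start_value + frac * (self.end_value - self.start_value)
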
Example #8
def main():
    # ************************************************************
    # ************** create folders and print options ************
    # ************************************************************
    opt = TrainOptions().parse()
    listopt(opt)
    print('create the directories')
    makedir(opt.checkpoint_dir)
    makedir(opt.log_dir)
    with open(os.path.join(opt.log_dir, 'train_options.txt'), 'w+') as f:
        listopt(opt, f)

    # ********************************************************************
    # ******************** Prepare the dataloaders ***********************
    # ********************************************************************
    print('define image and video transformation')
    # get the dataloader
    image_transforms = transforms.Compose([
        transforms.ToTensor(),
        lambda x: x[:opt.n_channels, ::],
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    video_transforms = functools.partial(video_transform, image_transform=image_transforms)

    print('create the video dataloader')
    if opt.dataset == 'Weizmann':
        trainset = WeizmannDataset(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, opt.every_nth, opt.crop, 'Train')
        valset = WeizmannDataset(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, opt.every_nth, False, 'Test')
    elif opt.dataset == 'MUG':
        trainset = MUGDataset(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, opt.every_nth,
                              opt.crop, 'Train')
        valset = MUGDataset(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, opt.every_nth, False,
                            'Test')
    elif opt.dataset == 'MUG2':
        trainset = MUGDataset_2(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, opt.every_nth,
                                opt.crop, 'Train')
        print(trainset.action_set)
        valset = MUGDataset_2(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, opt.every_nth, False,
                              'Test')
        print(valset.action_set)
    elif opt.dataset == 'SynAction':
        trainset = SynActionDataset(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, opt.crop, 'Train')
        valset = SynActionDataset(opt.dataroot, opt.textroot, opt.video_length, opt.image_size, False, 'Test')
    else:
        raise NotImplementedError('%s is not implemented' % opt.dataset)
    # get the train and validation dataloaders
    video_trainset = VideoDataset(trainset, opt.video_length, every_nth=opt.every_nth, transform=video_transforms)
    video_train_loader = DataLoader(video_trainset, batch_size=opt.batch_size, drop_last=True, num_workers=2, shuffle=True)

    video_valset = VideoDataset(valset, opt.video_length, every_nth=opt.every_nth, transform=video_transforms)
    video_val_loader = DataLoader(video_valset, batch_size=opt.batch_size, drop_last=False, num_workers=2, shuffle=False)

    # ********************************************************************
    # ******************** Create the Environment ************************
    # ********************************************************************
    # calculate the number of classes
    if opt.model_is_actor:
        num_class = len(trainset.action_set) * len(trainset.actor_set)
    else:
        num_class = len(trainset.action_set)

    print('create the environment')
    environ = Classifier_Environ(opt.n_channels, num_class, opt.log_dir, opt.checkpoint_dir, opt.lr, opt.ndf)
    current_iter = 0

    # load the checkpoints
    if opt.resume:
        current_iter, opt.lr = environ.load(opt.which_iter)
    else:
        environ.weight_init()

    # ********************************************************************
    # ***************************  Training  *****************************
    # ********************************************************************
    action_set = trainset.action_set
    max_acc, max_I_score, best_iter = 0, 0, 0
    print('begin training')
    video_enumerator = enumerate(video_train_loader)
    while current_iter < opt.total_iters:
        start_time = time.time()
        environ.train()
        current_iter += 1

        batch_idx, batch = next(video_enumerator)
        # modify the gt category if the model needs to distinguish actors
        if opt.data_is_actor:
            batch['categories'] = batch['categories'] + batch['actors'] * len(action_set)
        environ.set_inputs(batch)

        environ.optimize()

        # print losses
        if current_iter % opt.print_freq == 0:
            environ.print_loss(current_iter, start_time)

        # validation
        if current_iter % opt.val_freq == 0:
            environ.eval()
            loss, gt_cat, pred_dist = [], [], []

            # go through the validation set
            for idx, batch in enumerate(video_val_loader):
                # modify the gt category before handing the batch to the model
                if opt.data_is_actor:
                    batch['categories'] = batch['categories'] + batch['actors'] * len(action_set)
                environ.set_inputs(batch)
                tmp_loss, tmp_gt, tmp_pred = environ.val()
                loss.append(tmp_loss)
                gt_cat.append(tmp_gt)
                pred_dist.append(tmp_pred)

            pred_dist = np.concatenate(pred_dist, axis=0)
            pred_cat = np.argmax(pred_dist, axis=-1)
            gt_cat = np.concatenate(gt_cat)

            # calculate the metrics
            I_score, intra_E, inter_E, class_intra_E = quant(pred_dist, trainset.action_set)
            acc = float(len(gt_cat) - np.count_nonzero(pred_cat != gt_cat))/len(gt_cat) * 100
            loss = {'val_loss': np.mean(loss), 'acc': acc, 'I_score': I_score, 'intra_E': intra_E, 'inter_E': inter_E}
            environ.print_loss(current_iter, start_time, loss=loss)

            # save the checkpoint if the current gives the best performance
            if acc >= max_acc and I_score >= max_I_score:
                max_acc = acc
                max_I_score = I_score
                environ.save('best', current_iter)
                best_iter = current_iter
            print('max_I_score: %.3f, max_acc: %.3f, best iter: %d' % (max_I_score, max_acc, best_iter))
            environ.save(current_iter, current_iter)

        # save the current iteration
        if current_iter % opt.save_freq == 0:
            environ.save('latest', current_iter)

        # end of an epoch: get a new enumerator and decay the learning rate
        if batch_idx == len(video_train_loader) - 1:
            video_enumerator = enumerate(video_train_loader)
            opt.lr = opt.lr * opt.decay
            environ.adjust_learning_rate(opt.lr)
Example #9
import subprocess

from utils import util

config = util.Config('config.json')
sl = util.check_system()['sl']
video = util.VideoParams(config)

config.duration = '60'
original = f'..{sl}original'
yuv_folders_10s = f'..{sl}yuv-10s'
yuv_folders_60s = f'..{sl}yuv-full'
util.makedir(f'{yuv_folders_10s}')
util.makedir(f'{yuv_folders_60s}')
scale = config.scale
fps = config.fps

for name in config.videos_list:
    start_time = config.videos_list[name]['time']

    out_name = f'{name}_{scale}_{fps}.yuv'
    in_name = f'{original}{sl}{name}.mp4'

    par_in = f'-y -hide_banner -v quiet -ss {start_time} -i {in_name}'

    par_out_10s = f'-t 10 -r {fps} -vf scale={scale} -map 0:0 ..{sl}yuv-10s{sl}{out_name}'
    command = f'ffmpeg {par_in} {par_out_10s}'
    print(command)
    subprocess.run(command, shell=True, stderr=subprocess.STDOUT)

    par_out_60s = f'-t 60 -r {fps} -vf scale={scale} -map 0:0 ..{sl}yuv-full{sl}{out_name}'
    # run the 60 s extraction as well, mirroring the 10 s block above
    command = f'ffmpeg {par_in} {par_out_60s}'
    print(command)
    subprocess.run(command, shell=True, stderr=subprocess.STDOUT)
Example #10
import os
import glob
from utils.util import makedir
import pdb

phrase_list = [
    'clapping', 'waving', 'pick fruit', 'kick soccerball', 'wheelbarrow',
    'stall soccerball', 'baseball step', 'rifle side', 'kneeing',
    'jazz dancing', 'walking', 'running', 'jump', 'kick', 'hook',
    'cheering while', 'goalkeeper catch', 'throw', 'climb'
]

root = '/Users/sunxm/Downloads/mixamo/'
files = glob.glob(os.path.join(root, '*'))
output_path = '/Users/sunxm/Documents/mixamo/rough_pick/'
makedir(output_path)
for file_path in files:
    file_name = file_path.split('/')[-1]
    for phrase in phrase_list:
        words = phrase.split()
        flag = True
        for word in words:
            if word not in file_name.lower():
                flag = False
                break
        if flag:
            # pdb.set_trace()
            # tokens = file_path.split()

            # escape spaces and parentheses in the path
            file_path = file_path.replace(' ', '\\ ')
            file_path = file_path.replace('(', '\\(')