# Example #1
0
def get_transform(mode, args):
    """Build the clip-level augmentation controller for paired RGB/flow input.

    ``mode`` is accepted for interface compatibility but is not consulted.
    Every per-frame transform is configured for ``2 * args.seq_len`` frames,
    because RGB and flow frames are pushed through the same pipeline.
    The resulting controller is printed, then returned.
    """
    frames = args.seq_len * 2  # rgb frames + flow frames

    # "null" pipeline: random crop + horizontal flip only, then to tensor.
    light_crop = A.RandomSizedCrop(size=args.img_dim, consistent=False,
                                   seq_len=frames, bottom_area=0.2)
    light_flip = A.RandomHorizontalFlip(consistent=False, seq_len=frames)
    null_transform = transforms.Compose([light_crop, light_flip, A.ToTensor()])

    # "base" pipeline: crop, colour jitter (applied with prob 0.8), random
    # grayscale, Gaussian blur (applied with prob 0.5), flip, then to tensor.
    heavy_crop = A.RandomSizedCrop(size=args.img_dim, consistent=False,
                                   seq_len=frames, bottom_area=0.2)
    jitter = transforms.RandomApply(
        [A.ColorJitter(0.4, 0.4, 0.4, 0.1, p=1.0, consistent=False,
                       seq_len=frames)],
        p=0.8)
    gray = A.RandomGray(p=0.2, seq_len=frames)
    blur = transforms.RandomApply(
        [A.GaussianBlur([0.1, 2.0], seq_len=frames)], p=0.5)
    heavy_flip = A.RandomHorizontalFlip(consistent=False, seq_len=frames)
    base_transform = transforms.Compose(
        [heavy_crop, jitter, gray, blur, heavy_flip, A.ToTensor()])

    # oneclip: temporally take one clip, random augment twice
    # twoclip: temporally take two clips, random augment for each
    # the controller mixes the two strategies with weights 0.5 / 0.5
    two_clip = A.TwoClipTransform(base_transform, null_transform,
                                  seq_len=frames, p=0.3)
    one_clip = A.OneClipTransform(base_transform, null_transform,
                                  seq_len=frames)
    controller = A.TransformController([two_clip, one_clip],
                                       weights=[0.5, 0.5])
    print(controller)
    return controller
# Example #2
0
def get_transform(mode, args):
    """Return an ``A.TransformController`` over one-clip / two-clip augmentation.

    ``mode`` is unused. All per-frame transforms operate on
    ``2 * args.seq_len`` frames (RGB and flow together). The controller is
    built from a two-clip and a one-clip transform with weights [0.5, 0.5];
    presumably these are selection probabilities (see ``A.TransformController``).
    """
    n_frames = args.seq_len * 2  # rgb and flow frames are transformed together

    def make_crop():
        # Each frame gets its own random crop (consistent=False).
        return A.RandomSizedCrop(size=args.img_dim,
                                 consistent=False,
                                 seq_len=n_frames,
                                 bottom_area=0.2)

    def make_flip():
        return A.RandomHorizontalFlip(consistent=False, seq_len=n_frames)

    null_steps = [make_crop(), make_flip(), A.ToTensor()]
    null_transform = transforms.Compose(null_steps)

    base_steps = [
        make_crop(),
        transforms.RandomApply(
            [A.ColorJitter(0.4, 0.4, 0.4, 0.1, p=1.0, consistent=False,
                           seq_len=n_frames)],
            p=0.8),
        A.RandomGray(p=0.2, seq_len=n_frames),
        transforms.RandomApply(
            [A.GaussianBlur([0.1, 2.0], seq_len=n_frames)], p=0.5),
        make_flip(),
        A.ToTensor(),
    ]
    base_transform = transforms.Compose(base_steps)

    clip_strategies = [
        A.TwoClipTransform(base_transform, null_transform,
                           seq_len=n_frames, p=0.3),
        A.OneClipTransform(base_transform, null_transform, seq_len=n_frames),
    ]
    return A.TransformController(clip_strategies, weights=[0.5, 0.5])
# Example #3
0
def _check_label_options(dataset_name, return_label, hierarchical_label,
                         action_level_gt):
    """Reject label options that the named dataset cannot provide."""
    if hierarchical_label and dataset_name not in ('finegym', 'hollywood2'):
        raise ValueError(
            'Hierarchical information is only implemented in finegym and '
            'hollywood2 datasets')
    if return_label and not action_level_gt and dataset_name != 'finegym':
        # Without action_level_gt the requested labels are subaction labels.
        raise ValueError(
            'Subaction labels are only available in the finegym dataset')


def _build_clip_transform(mode, img_dim):
    """Return the augmentation pipeline: random augmentation for 'train', center crop otherwise."""
    if mode == 'train':
        return transforms.Compose([
            augmentation.RandomSizedCrop(size=img_dim,
                                         consistent=True,
                                         p=1.0),
            augmentation.RandomHorizontalFlip(consistent=True),
            augmentation.RandomGray(consistent=False, p=0.5),
            augmentation.ColorJitter(brightness=0.5,
                                     contrast=0.5,
                                     saturation=0.5,
                                     hue=0.25,
                                     p=1.0),
            augmentation.ToTensor(),
            augmentation.Normalize()
        ])
    return transforms.Compose([
        augmentation.CenterCrop(size=img_dim, consistent=True),
        augmentation.ToTensor(),
        augmentation.Normalize()
    ])


def _build_dataset(args, mode, transform, return_label, hierarchical_label,
                   action_level_gt, path_dataset, path_data_info):
    """Instantiate the dataset class selected by ``args.dataset``."""
    if args.dataset == 'kinetics':
        return Kinetics600(mode=mode,
                           transform=transform,
                           seq_len=args.seq_len,
                           num_seq=args.num_seq,
                           downsample=5,
                           return_label=return_label,
                           return_idx=False,
                           path_dataset=path_dataset,
                           path_data_info=path_data_info)
    if args.dataset == 'hollywood2':
        if return_label:
            # Already guaranteed by _check_label_options; kept as a guard.
            assert action_level_gt, 'hollywood2 does not have subaction labels'
        return Hollywood2(mode=mode,
                          transform=transform,
                          seq_len=args.seq_len,
                          num_seq=args.num_seq,
                          downsample=args.ds,
                          return_label=return_label,
                          hierarchical_label=hierarchical_label,
                          path_dataset=path_dataset,
                          path_data_info=path_data_info)
    if args.dataset == 'finegym':
        if hierarchical_label:
            assert not action_level_gt, 'finegym does not have hierarchical information at the action level'
        return FineGym(
            mode=mode,
            transform=transform,
            seq_len=args.seq_len,
            num_seq=args.num_seq,
            fps=int(25 / args.ds),  # approx
            return_label=return_label,
            hierarchical_label=hierarchical_label,
            action_level_gt=action_level_gt,
            path_dataset=path_dataset,
            return_idx=False,
            path_data_info=path_data_info)
    if args.dataset == 'movienet':
        assert not return_label, 'Not yet implemented (actions not available online)'
        assert args.seq_len == 3, 'We only have 3 frames per subclip/scene, but always 3'
        return MovieNet(mode=mode,
                        transform=transform,
                        num_seq=args.num_seq,
                        path_dataset=path_dataset,
                        path_data_info=path_data_info)
    raise ValueError('dataset not supported')


def get_data(args,
             mode='train',
             return_label=False,
             hierarchical_label=False,
             action_level_gt=False,
             num_workers=0,
             path_dataset='',
             path_data_info=''):
    """Build the DataLoader for ``args.dataset`` in the given ``mode``.

    Args:
        args: Options object. Reads ``dataset``, ``img_dim``, ``seq_len``,
            ``num_seq`` and ``batch_size``, plus ``ds`` for hollywood2 and
            finegym.
        mode: ``'train'`` enables random augmentation and random sampling.
            Any other value uses a center crop and sequential sampling.
            Only ``'test'`` keeps the final incomplete batch.
        return_label: Ask the dataset to return labels.
        hierarchical_label: Ask for hierarchical labels (finegym and
            hollywood2 only).
        action_level_gt: Use action-level labels instead of the finegym
            subaction labels.
        num_workers: Number of DataLoader worker processes.
        path_dataset: Forwarded to the dataset constructor.
        path_data_info: Forwarded to the dataset constructor.

    Returns:
        A ``data.DataLoader`` over the selected dataset.

    Raises:
        ValueError: The dataset is unknown or cannot provide the requested
            labels (a subclass of ``Exception``, so existing
            ``except Exception`` handlers still work).
        AssertionError: A dataset-specific option conflict was checked with
            ``assert`` (for example, movienet with ``return_label``).
    """
    _check_label_options(args.dataset, return_label, hierarchical_label,
                         action_level_gt)
    transform = _build_clip_transform(mode, args.img_dim)
    dataset = _build_dataset(args, mode, transform, return_label,
                             hierarchical_label, action_level_gt,
                             path_dataset, path_data_info)

    if mode == 'train':
        sampler = data.RandomSampler(dataset)
    else:
        sampler = data.SequentialSampler(dataset)

    return data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,  # ordering is controlled entirely by ``sampler``
        num_workers=num_workers,
        pin_memory=True,
        # test always sees the same examples, independently of batch size
        drop_last=(mode != 'test'))
# Example #4
0
def main(args):
    """Train a MemDPC model, validating and checkpointing after every epoch.

    Reads its configuration from ``args``: seed, gpu, batch_size, model,
    img_dim, num_seq, seq_len, net, pred_step, mem_size, lr, wd, dataset,
    resume, start_epoch and epochs. It writes run state back onto ``args``:
    batch_size, iteration, start_epoch, img_path, model_path, logger,
    writer_val and writer_train.

    Always terminates the process through ``sys.exit``. The exit status is 0
    after training and 1 when ``args.resume`` names a missing file.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = torch.device('cuda')
    # args.gpu may hold several comma-separated ids; scale the batch size by
    # the GPU count so each device keeps the configured per-GPU batch.
    num_gpu = len(str(args.gpu).split(','))
    args.batch_size = num_gpu * args.batch_size

    ### model ###
    if args.model == 'memdpc':
        model = MemDPC_BD(sample_size=args.img_dim,
                          num_seq=args.num_seq,
                          seq_len=args.seq_len,
                          network=args.net,
                          pred_step=args.pred_step,
                          mem_size=args.mem_size)
    else:
        raise NotImplementedError('wrong model!')

    model.to(device)
    model = nn.DataParallel(model)
    # Unwrapped module: checkpoints are saved/loaded through this handle.
    model_without_dp = model.module

    ### optimizer ###
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    criterion = nn.CrossEntropyLoss()

    ### data ###
    transform = transforms.Compose([
        A.RandomSizedCrop(size=224, consistent=True,
                          p=1.0),  # crop from 256 to 224
        A.Scale(size=(args.img_dim, args.img_dim)),
        A.RandomHorizontalFlip(consistent=True),
        A.RandomGray(consistent=False, p=0.25),
        A.ColorJitter(0.5, 0.5, 0.5, 0.25, consistent=False, p=1.0),
        A.ToTensor(),
        A.Normalize()
    ])

    # NOTE(review): this expects a get_data(transform, mode) helper; the
    # get_data(args, mode, ...) variant elsewhere in this file has a
    # different signature -- confirm which one is in scope here.
    train_loader = get_data(transform, 'train')
    val_loader = get_data(transform, 'val')

    if 'ucf' in args.dataset:
        lr_milestones_eps = [300, 400]
    elif 'k400' in args.dataset:
        lr_milestones_eps = [120, 160]
    else:
        lr_milestones_eps = [1000]  # NEVER
    # Milestones are converted from epochs to iterations.
    lr_milestones = [len(train_loader) * m for m in lr_milestones_eps]
    print('=> Use lr_scheduler: %s eps == %s iters' %
          (str(lr_milestones_eps), str(lr_milestones)))

    def lr_lambda(sched_step):
        return MultiStepLR_Restart_Multiplier(
            sched_step, gamma=0.1, step=lr_milestones, repeat=1)

    lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    best_acc = 0
    args.iteration = 1

    ### restart training ###
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading resumed checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))
            # The checkpoint stores the index of the last *completed* epoch
            # ('epoch': epoch in save_dict below), so resume at the next one
            # instead of retraining it.
            args.start_epoch = checkpoint['epoch'] + 1
            args.iteration = checkpoint['iteration']
            best_acc = checkpoint['best_acc']
            model_without_dp.load_state_dict(checkpoint['state_dict'])
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except Exception as err:  # best effort; never swallow Ctrl-C/exit
                print('[WARNING] Not loading optimizer states ({})'.format(err))
            # NOTE(review): lr_scheduler state is neither saved nor restored,
            # so its step counter restarts after resuming -- confirm intended.
            print("=> loaded resumed checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("[Warning] no checkpoint found at '{}'".format(args.resume))
            # A requested-but-missing checkpoint is a failure; report it to
            # the caller/scheduler with a non-zero exit status.
            sys.exit(1)

    # logging tools
    args.img_path, args.model_path = set_path(args)
    args.logger = Logger(path=args.img_path)
    args.logger.log('args=\n\t\t' + '\n\t\t'.join(
        ['%s:%s' % (str(k), str(v)) for k, v in vars(args).items()]))

    args.writer_val = SummaryWriter(logdir=os.path.join(args.img_path, 'val'))
    args.writer_train = SummaryWriter(
        logdir=os.path.join(args.img_path, 'train'))

    torch.backends.cudnn.benchmark = True

    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        # Re-seed per epoch so data augmentation is reproducible per epoch.
        np.random.seed(epoch)
        random.seed(epoch)

        train_loss, train_acc = train_one_epoch(train_loader, model, criterion,
                                                optimizer, lr_scheduler,
                                                device, epoch, args)
        val_loss, val_acc = validate(val_loader, model, criterion, device,
                                     epoch, args)

        # save check_point
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_dict = {
            'epoch': epoch,
            'state_dict': model_without_dp.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
            'iteration': args.iteration
        }
        save_checkpoint(save_dict,
                        is_best,
                        filename=os.path.join(args.model_path,
                                              'epoch%s.pth.tar' % str(epoch)),
                        keep_all=False)

    print('Training from ep %d to ep %d finished' %
          (args.start_epoch, args.epochs))
    sys.exit(0)