Example #1
0
def main():
    """Evaluate a TimeCycle model on the DAVIS / JHMDB validation set.

    Configuration comes from the module-level ``args``, ``params``, ``vis``
    and ``use_cuda`` objects.

    The function:

    * overrides a fixed set of patch/augmentation options on ``args``;
    * infers the feature-map downscale factor with a dummy forward pass and
      stores it in ``params['mapScale']``;
    * optionally loads weights from ``args.resume``;
    * runs ``test`` over the validation loader.
    """
    # Fixed evaluation-time overrides of the command-line options.
    args.kldv_coef = 1
    args.long_coef = 1

    args.frame_transforms = 'crop'
    args.frame_aug = 'grid'
    args.npatch = 49
    args.img_size = 256
    args.pstride = [0.5, 0.5]
    args.patch_size = [64, 64, 3]

    args.visualize = False

    model = tc.TimeCycle(args, vis=vis).cuda()

    # Probe the spatial size of the features for a dummy 320x320 clip to
    # derive the downscale factor. Only the shape is used, so no autograd
    # graph is needed.
    with torch.no_grad():
        map_shape = model(torch.zeros(1, 10, 3, 320, 320).cuda(),
                          just_feats=True)[1].shape[-2:]
    params['mapScale'] = 320 // np.array(map_shape)

    dataset_cls = (jhmdb.JhmdbSet if 'jhmdb' in args.filelist
                   else davis.DavisSet)
    val_loader = torch.utils.data.DataLoader(
        dataset_cls(params, is_train=False),
        batch_size=int(params['batchSize']), shuffle=False,
        num_workers=args.workers, pin_memory=True)

    cudnn.benchmark = False
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Load checkpoint.
    if os.path.isfile(args.resume):
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.resume)
        # skip_keys=['head'] presumably leaves the 'head' parameters
        # untouched. TODO confirm against utils.partial_load.
        utils.partial_load(checkpoint['model'], model, skip_keys=['head'])
        del checkpoint

    model.eval()

    os.makedirs(args.save_path, exist_ok=True)

    print('\nTesting')
    # NOTE(review): gradients are not disabled here. A torch.no_grad()
    # wrapper was commented out in the original; confirm whether ``test``
    # needs autograd before adding one.
    test(val_loader, model, 1, use_cuda, args)
def main(args, vis):
    """Evaluate a CRW model on a VOS or JHMDB label-propagation set.

    Args:
        args: parsed options.
            Reads ``device``, ``model_type``, ``filelist``, ``batchSize``,
            ``workers``, ``resume`` and ``save_path``.
            Sets ``mapScale`` (from ``test_utils.infer_downscale``) and
            ``use_lab`` (True when ``model_type == 'uvc'``).
        vis: visualiser forwarded to ``CRW``.
    """
    gpus_count = torch.cuda.device_count()
    print('Available CUDA devices: ', gpus_count)
    # torch.cuda.current_device() raises when CUDA is unavailable. Guard it
    # so that CPU evaluation (args.device == 'cpu') still works.
    if gpus_count > 0:
        print('Current CUDA device: ', torch.cuda.current_device())

    model = CRW(args, vis=vis).to(args.device)
    args.mapScale = test_utils.infer_downscale(model)

    args.use_lab = args.model_type == 'uvc'
    dataset_cls = (jhmdb.JhmdbSet if 'jhmdb' in args.filelist
                   else vos.VOSDataset)
    val_loader = torch.utils.data.DataLoader(dataset_cls(args),
                                             batch_size=int(args.batchSize),
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    print('Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Load checkpoint.
    if os.path.isfile(args.resume):
        print('==> Resuming from checkpoint..')
        # map_location='cpu' lets checkpoints saved on a GPU load on any
        # machine. The model is moved back to args.device below.
        checkpoint = torch.load(args.resume, map_location='cpu')
        state = checkpoint['model']

        if args.model_type == 'scratch':
            # Rename 'conv1.1.weight' / 'conv2.1.weight' style keys to
            # '...weight' so they match the model's parameter names
            # (presumably a layout difference between checkpoint and model;
            # TODO confirm).
            state = {
                (k.replace('.1.weight', '.weight')
                 if 'conv1.1.weight' in k or 'conv2.1.weight' in k else k): v
                for k, v in state.items()
            }

        utils.partial_load(state, model, skip_keys=['head'])
        del checkpoint, state

    model.eval()
    model = model.to(args.device)

    os.makedirs(args.save_path, exist_ok=True)

    with torch.no_grad():
        test(val_loader, model, args)
Example #3
0
def main():
    """Evaluate a ``Wrap``-ped TimeCycle model on the DAVIS validation set.

    Configuration comes from the module-level ``args``, ``params`` and
    ``use_cuda`` objects.

    The function:

    * stores the feature-map downscale factor in ``params['mapScale']``;
    * optionally loads weights from ``args.resume`` into the wrapped model;
    * runs ``test`` under ``torch.no_grad()`` with a ``DataParallel`` model.
    """
    model = tc.TimeCycle(args).cuda()
    model = Wrap(model)

    # Probe the spatial size of the features for a dummy 320x320 clip to
    # derive the downscale factor. Only the shape is used, so no autograd
    # graph is needed.
    with torch.no_grad():
        map_shape = model(torch.zeros(1, 10, 3, 320, 320).cuda(),
                          None,
                          True,
                          func='forward')[1].shape[-2:]
    params['mapScale'] = 320 // np.array(map_shape)

    val_loader = torch.utils.data.DataLoader(
        davis.DavisSet(params, is_train=False),
        batch_size=int(params['batchSize']),
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    cudnn.benchmark = False
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Load checkpoint into the inner (unwrapped) model.
    if os.path.isfile(args.resume):
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(args.resume)
        utils.partial_load(checkpoint['model'], model.model)
        del checkpoint

    model.eval()
    model = torch.nn.DataParallel(model).cuda()

    os.makedirs(args.save_path, exist_ok=True)

    print('\nTesting')
    with torch.no_grad():
        test(val_loader, model, 1, use_cuda)
Example #4
0
def main(args):
    """Train a CRW model.

    Dataset selection:

    * Kinetics-400 when ``args.data_path`` contains "kinetics"
      (case-insensitive);
    * an ``ImageFolder`` when ``args.data_path`` is a directory;
    * otherwise a ``VideoList`` built from a file list.

    Kinetics clip metadata can be cached to / reloaded from
    ``_get_cache_path(traindir)`` when ``args.cache_dataset`` is set.
    Training uses Adam with a ``MultiStepLR`` schedule whose milestones are
    given in epochs (``args.lr_milestones``) and converted to iterations.
    A checkpoint-saving callback is passed to ``train_one_epoch``.

    Args:
        args: parsed command-line options.
    """
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    print("Preparing training dataloader")
    traindir = os.path.join(args.data_path,
                            'train_256' if not args.fast_test else 'val_256')
    valdir = os.path.join(args.data_path, 'val_256')

    st = time.time()
    cache_path = _get_cache_path(traindir)

    transform_train = utils.augs.get_train_transforms(args)

    def make_dataset(is_train, cached=None):
        """Build the training or validation dataset for ``args.data_path``."""
        # ``transform_test`` is not defined in this function. It is only
        # looked up when is_train is False, which this script never requests.
        _transform = transform_train if is_train else transform_test

        if 'kinetics' in args.data_path.lower():
            return Kinetics400(
                traindir if is_train else valdir,
                frames_per_clip=args.clip_len,
                step_between_clips=1,
                transform=_transform,
                # One-element tuple: ('mp4') would just be the string 'mp4'.
                extensions=('mp4',),
                frame_rate=args.frame_skip,
                _precomputed_metadata=cached)
        if os.path.isdir(args.data_path):
            # HACK: assume an image dataset if the data path is a directory.
            return torchvision.datasets.ImageFolder(root=args.data_path,
                                                    transform=_transform)
        return VideoList(
            filelist=args.data_path,
            clip_len=args.clip_len,
            is_train=is_train,
            frame_gap=args.frame_skip,
            transform=_transform,
            random_clip=True,
        )

    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_train from {}".format(cache_path))
        # NOTE: torch.load unpickles arbitrary objects; only load caches
        # produced by this script.
        dataset, _ = torch.load(cache_path)
        # Reuse the cached clip metadata instead of rescanning every video.
        cached = dict(video_paths=dataset.video_clips.video_paths,
                      video_fps=dataset.video_clips.video_fps,
                      video_pts=dataset.video_clips.video_pts)
        dataset = make_dataset(is_train=True, cached=cached)
        dataset.transform = transform_train
    else:
        dataset = make_dataset(is_train=True)
        if args.cache_dataset and 'kinetics' in args.data_path.lower():
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            # The transform is stripped before pickling. It must be restored
            # afterwards, otherwise training runs on untransformed clips.
            dataset.transform = None
            try:
                torch.save((dataset, traindir), cache_path)
            finally:
                dataset.transform = transform_train

    if hasattr(dataset, 'video_clips'):
        dataset.video_clips.compute_clips(args.clip_len,
                                          1,
                                          frame_rate=args.frame_skip)

    print("Took", time.time() - st)

    def make_data_sampler(is_train, dataset):
        """Return a clip sampler for video datasets, else a RandomSampler."""
        # Re-seed so the sampler's construction is reproducible.
        torch.manual_seed(0)
        if hasattr(dataset, 'video_clips'):
            _sampler = RandomClipSampler  # alternative: UniformClipSampler
            return _sampler(dataset.video_clips, args.clips_per_video)
        return torch.utils.data.sampler.RandomSampler(
            dataset) if is_train else None

    print("Creating data loaders")
    train_sampler = make_data_sampler(True, dataset)

    # Shuffling is handled by the sampler (DataLoader forbids passing both).
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers // 2,
        pin_memory=True,
        collate_fn=collate_fn)

    vis = utils.visualize.Visualize(args) if args.visualize else None

    print("Creating model")
    model = CRW(args, vis=vis).to(device)
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Milestones are given in epochs; the scheduler is stepped per iteration
    # (presumably inside train_one_epoch; TODO confirm).
    lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=lr_milestones, gamma=args.lr_gamma)

    model_without_ddp = model
    if args.data_parallel:
        model = torch.nn.parallel.DataParallel(model)
        model_without_ddp = model.module

    if args.partial_reload:
        checkpoint = torch.load(args.partial_reload, map_location='cpu')
        utils.partial_load(checkpoint['model'], model_without_ddp)
        optimizer.param_groups[0]["lr"] = args.lr
        args.start_epoch = checkpoint['epoch'] + 1

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    def save_model_checkpoint():
        """Save the current training state under args.output_dir (if set)."""
        # ``epoch`` is the training-loop variable below, read at call time.
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            torch.save(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            torch.save(checkpoint,
                       os.path.join(args.output_dir, 'checkpoint.pth'))

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        train_one_epoch(model,
                        optimizer,
                        lr_scheduler,
                        data_loader,
                        device,
                        epoch,
                        args.print_freq,
                        vis=vis,
                        checkpoint_fn=save_model_checkpoint)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def main(args):
    """Train a CRW (or teacher-student CRW) model with superpixel sampling.

    Dataset selection:

    * Kinetics-400 with superpixel options when ``args.data_path`` contains
      "kinetics" (case-insensitive);
    * an ``ImageFolder`` when ``args.data_path`` is a directory;
    * otherwise a ``VideoList``.

    Kinetics clip metadata is reloaded from ``args.cache_path`` when
    ``args.cache_dataset`` is set and the file exists. Otherwise it is
    written there. Training uses Adam with a ``MultiStepLR`` schedule.

    Args:
        args: parsed command-line options.

    Raises:
        AssertionError: if ``args.teacher_student`` is set and
            ``args.prob != 1`` (not checked when Python runs with ``-O``).
    """
    # Eager checks: fail before any expensive setup.
    if args.teacher_student:
        assert args.prob == 1, "Teacher-Student training is not yet compatible with probabistic sp | patch sampling"

    print("Arguments", end="\n" + "-" * 100 + "\n")
    for arg, value in vars(args).items():
        print(f"{arg} = {value}")
    print("-" * 100)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    print("Preparing training dataloader", end="\n" + "-" * 100 + "\n")
    traindir = os.path.join(args.data_path,
                            'train_256' if not args.fast_test else 'val_256')
    valdir = os.path.join(args.data_path, 'val_256')

    st = time.time()
    cache_path = args.cache_path

    transform_train = utils.augs.get_train_transforms(args)

    # Dataset
    def make_dataset(is_train, cached=None):
        """Build the training or validation dataset for ``args.data_path``."""
        # ``transform_test`` is not defined in this function. It is only
        # looked up when is_train is False, which this script never requests.
        _transform = transform_train if is_train else transform_test

        if 'kinetics' in args.data_path.lower():
            return Kinetics400(
                traindir if is_train else valdir,
                frames_per_clip=args.clip_len,
                step_between_clips=1,
                transform=_transform,
                # One-element tuple: ('mp4') would just be the string 'mp4'.
                extensions=('mp4',),
                frame_rate=args.frame_skip,
                _precomputed_metadata=cached,
                sp_method=args.sp_method,
                num_components=args.num_sp,
                prob=args.prob,
                randomise_superpixels=args.randomise_superpixels,
                randomise_superpixels_range=args.randomise_superpixels_range)
        if os.path.isdir(args.data_path):
            # HACK: assume an image dataset if the data path is a directory.
            return torchvision.datasets.ImageFolder(root=args.data_path,
                                                    transform=_transform)
        return VideoList(
            filelist=args.data_path,
            clip_len=args.clip_len,
            is_train=is_train,
            frame_gap=args.frame_skip,
            transform=_transform,
            random_clip=True,
        )

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_train from {cache_path}",
              end="\n" + "-" * 100 + "\n")
        # NOTE: torch.load unpickles arbitrary objects; only load caches
        # produced by this script.
        dataset, _ = torch.load(cache_path)
        # Reuse the cached clip metadata instead of rescanning every video.
        cached = dict(video_paths=dataset.video_clips.video_paths,
                      video_fps=dataset.video_clips.video_fps,
                      video_pts=dataset.video_clips.video_pts)
        dataset = make_dataset(is_train=True, cached=cached)
        dataset.transform = transform_train
    else:
        dataset = make_dataset(is_train=True)
        # NOTE: the cache is written for Kinetics regardless of
        # args.cache_dataset (that condition is disabled).
        if 'kinetics' in args.data_path.lower():
            print(f"Saving dataset_train to {cache_path}",
                  end="\n" + "-" * 100 + "\n")
            utils.mkdir(os.path.dirname(cache_path))
            # Strip the transform before pickling and always restore it.
            dataset.transform = None
            try:
                torch.save((dataset, traindir), cache_path)
            finally:
                dataset.transform = transform_train

    if hasattr(dataset, 'video_clips'):
        dataset.video_clips.compute_clips(args.clip_len,
                                          1,
                                          frame_rate=args.frame_skip)

    print("Took", time.time() - st)

    # Data Loader
    def make_data_sampler(is_train, dataset):
        """Return a clip sampler for video datasets, else a RandomSampler."""
        # Re-seed so the sampler's construction is reproducible.
        torch.manual_seed(0)
        if hasattr(dataset, 'video_clips'):
            _sampler = RandomClipSampler  # alternative: UniformClipSampler
            return _sampler(dataset.video_clips, args.clips_per_video)
        return torch.utils.data.sampler.RandomSampler(
            dataset) if is_train else None

    print("Creating data loaders", end="\n" + "-" * 100 + "\n")
    train_sampler = make_data_sampler(True, dataset)

    # Shuffling is handled by the sampler (DataLoader forbids passing both).
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers // 2,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    # Requires the dataset to implement set_compactness. torchvision's
    # ImageFolder does not. TODO confirm for VideoList.
    print("Set Compactness at:", args.compactness)
    data_loader.dataset.set_compactness(args.compactness)

    # Visualisation
    vis = utils.visualize.Visualize(args) if args.visualize else None

    # Model
    print("Creating model", end="\n" + "-" * 100 + "\n")
    if not args.teacher_student:
        model = CRW(args, vis=vis).to(device)
    else:
        model = CRWTeacherStudent(args, vis=None).to(
            device)  # NOTE Disabled vis during prototyping

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Learning rate schedule: milestones are given in epochs and converted to
    # iteration counts.
    lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=lr_milestones, gamma=args.lr_gamma)

    model_without_ddp = model

    # Parallelise model over GPUs
    if args.data_parallel:
        model = torch.nn.parallel.DataParallel(model)
        model_without_ddp = model.module

    # Partially load weights from model checkpoint. args.start_epoch is not
    # restored from this checkpoint.
    if args.partial_reload:
        checkpoint = torch.load(args.partial_reload, map_location='cpu')
        utils.partial_load(checkpoint['model'], model_without_ddp)
        optimizer.param_groups[0]["lr"] = args.lr

    # Resume from checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    def save_model_checkpoint():
        """Save the current training state under args.output_dir (if set)."""
        # ``epoch`` is the training-loop variable below, read at call time.
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            torch.save(checkpoint,
                       os.path.join(args.output_dir, f'model_{epoch}.pth'))
            torch.save(checkpoint,
                       os.path.join(args.output_dir, 'checkpoint.pth'))

    # Start Training
    print("Start training", end="\n" + "-" * 100 + "\n")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        train_one_epoch(model,
                        optimizer,
                        lr_scheduler,
                        data_loader,
                        device,
                        epoch,
                        args.print_freq,
                        vis=vis,
                        checkpoint_fn=save_model_checkpoint,
                        prob=args.prob)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f'Training time {total_time_str}')