def _get_sampler(self, epoch) -> "DistributedSampler":
    if self.split == "train":
        # For video model training, we don't necessarily want to use every
        # possible clip of a video in one training epoch. More often, we
        # randomly sample at most N clips per training video; in practice,
        # N is often 1.
        clip_sampler = RandomClipSampler(self.video_clips,
                                         self.clips_per_video)
    else:
        # For video model testing, we sample N evenly spaced clips per test
        # video and simply average the predictions over them.
        clip_sampler = UniformClipSampler(self.video_clips,
                                          self.clips_per_video)
    # Cap the total number of clips drawn in one epoch.
    clip_sampler = MaxLengthClipSampler(clip_sampler,
                                        num_samples=self.num_samples)
    world_size = get_world_size()
    rank = get_rank()
    # Shard the clips across distributed processes; group_size keeps the
    # clips of a single video within the same process.
    sampler = DistributedSampler(
        clip_sampler,
        num_replicas=world_size,
        rank=rank,
        shuffle=self.shuffle,
        group_size=self.clips_per_video,
    )
    sampler.set_epoch(epoch)
    return sampler
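
A minimal usage sketch for the sampler above (a hypothetical driver loop; `dataset` stands for an instance of whichever class defines `_get_sampler`, and the batch size and epoch count are made-up values):

import torch

dataset = ...       # instance of the dataset class that defines _get_sampler above
batch_size = 8      # hypothetical value
num_epochs = 30     # hypothetical value

for epoch in range(num_epochs):
    # Rebuild the sampler each epoch; _get_sampler already calls set_epoch(epoch).
    sampler = dataset._get_sampler(epoch)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,    # ordering/sharding is handled entirely by the sampler
        num_workers=4,
        pin_memory=True,
    )
    for clips, targets in loader:
        ...  # forward/backward pass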
Example #2
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:

        def print_pass(*args):
            pass

        builtins.print = print_pass
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
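            # Example with assumed numbers: with ngpus_per_node = 4, the process
            # for local gpu 2 on the node whose initial rank was 1 gets global
            # rank 1 * 4 + 2 = 6.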
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    print("=> creating model '{}'".format(args.arch))
    netG = moco.builder.MaskGenerator()
    netD = moco.builder.MoCo(models.__dict__[args.arch], args.moco_dim,
                             args.moco_k, args.moco_m, args.moco_t, args.mlp)
    print(netG)
    print(netD)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            netG.cuda(args.gpu)
            netD.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
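            # Example with assumed numbers: a global batch size of 256 on a node
            # with 8 GPUs becomes 256 / 8 = 32 samples per process, and 16 loader
            # workers become ceil(16 / 8) = 2 workers per process.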
            netG = torch.nn.parallel.DistributedDataParallel(
                netG, device_ids=[args.gpu], find_unused_parameters=True)
            netD = torch.nn.parallel.DistributedDataParallel(
                netD, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            netG.cuda()
            netD.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            netG = torch.nn.parallel.DistributedDataParallel(netG)
            netD = torch.nn.parallel.DistributedDataParallel(netD)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        netG = netG.cuda(args.gpu)
        netD = netD.cuda(args.gpu)
        # The original check below is left commented out so that single-GPU
        # runs can be used for debugging:
        # raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        # The raise below is disabled so the code can also be debugged on CPU.
        pass  # raise NotImplementedError("Only DistributedDataParallel is supported.")
    # torch.cuda.synchronize()
    optimizer_g = torch.optim.SGD(netG.parameters(),
                                  args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    optimizer_d = torch.optim.SGD(netD.parameters(),
                                  args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
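    # criterion drives the MoCo contrastive objective (cross-entropy over the
    # InfoNCE-style logits) for netD; G_criterion is an L1 loss that is passed
    # to train() for the generator (netG) side of the objective.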
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    G_criterion = nn.L1Loss().cuda(args.gpu)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            netD.load_state_dict(checkpoint['state_dict'])
            #optimizer_d.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

        if os.path.isfile(args.resumeG):
            print("=> loading checkpoint '{}'".format(args.resumeG))
            if args.gpu is None:
                checkpoint = torch.load(args.resumeG)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resumeG, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            netG.load_state_dict(checkpoint['state_dict'])
            #optimizer_g.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resumeG, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resumeG))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    video_augmentation = transforms.Compose([
        transforms_video.ToTensorVideo(),
        transforms_video.RandomResizedCropVideo(args.crop_size, (0.2, 1)),
    ])
    audio_augmentation = moco.loader.DummyAudioTransform()
    augmentation = {'video': video_augmentation, 'audio': audio_augmentation}

    if args.aug_plus:
        augmentation_gpu = moco.loader.MoCoAugmentV2(args.crop_size)
    else:
        augmentation_gpu = moco.loader.MoCoAugment(args.crop_size)
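    # The dataset applies only tensor conversion and random resized crops on the
    # CPU side (the `augmentation` dict above); the MoCo-style augmentations in
    # augmentation_gpu are applied later inside train().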

    train_dataset = Kinetics400(traindir,
                                args.frame_per_clip,
                                args.step_between_clips,
                                extensions='mp4',
                                transform=augmentation,
                                num_workers=4)

    # Draw at most one random clip per video in each epoch.
    train_sampler = RandomClipSampler(train_dataset.video_clips, 1)

    if args.distributed:
        # train_sampler = torch.utils.data.distributed.DistributedSampler(train_sampler)
        train_sampler = DistributedSampler(train_sampler)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True,
                                               multiprocessing_context="fork")
    if args.multiprocessing_distributed and args.gpu == 0:
        log_dir = "{}_bs={}_lr={}_cs={}_fpc={}".format(args.log_dir,
                                                       args.batch_size,
                                                       args.lr, args.crop_size,
                                                       args.frame_per_clip)
        writer = SummaryWriter(log_dir)
    else:
        writer = None
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
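            # Keep the sampler's epoch counter in sync across processes so that
            # any epoch-dependent shuffling changes from epoch to epoch.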
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer_d, epoch, args)
        adjust_learning_rate(optimizer_g, epoch, args)

        # train for one epoch
        train(train_loader, augmentation_gpu, criterion, G_criterion, netG,
              netD, optimizer_g, optimizer_d, epoch, args, writer)

        # Checkpoint every 10 epochs, and only from the first process on each
        # node (global rank divisible by ngpus_per_node), so distributed runs
        # do not write duplicate checkpoint files.
        if (epoch + 1) % 10 == 0 and (not args.multiprocessing_distributed or
                                      (args.multiprocessing_distributed
                                       and args.rank % ngpus_per_node == 0)):
            ckp_dir = "{}_bs={}_lr={}_cs={}_fpc={}".format(
                args.ckp_dir, args.batch_size, args.lr, args.crop_size,
                args.frame_per_clip)
            save_checkpoint(
                epoch,
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': netG.state_dict(),
                },
                ckp_dir + '/netG',
                max_save=20,
                is_best=False)

            save_checkpoint(
                epoch,
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': netD.state_dict(),
                },
                ckp_dir + '/netD',
                max_save=20,
                is_best=False)
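
A minimal launch sketch for a worker function like this one (assumed: a `main()` entry point and an argparse `args` namespace; this follows the standard torch.multiprocessing spawn idiom rather than reproducing the repository's actual main()):

import torch
import torch.multiprocessing as mp


def main(args):
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # One process per GPU; world_size becomes the total number of processes
        # across all nodes.
        args.world_size = ngpus_per_node * args.world_size
        # mp.spawn passes the process index as the first argument, which fills
        # the `gpu` parameter of main_worker.
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Single-process run on the GPU selected by args.gpu (or CPU if None).
        main_worker(args.gpu, ngpus_per_node, args)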