Example #1
0
                    logger.info('-' * 80)
                    logger.info(
                        'Buffer # {:3d} | {:5d} examples processed | data {} | Valid CE {:5.3f} |'
                        ' Valid PPL {:5.3f} | Improved {:3d} steps ago | patience {:3d}'
                        .format(epoch,
                                total_examples_processed, datasets, dev_ce,
                                np.exp(dev_ce), not_improved, patience))
                    logger.info('-' * 80)

                    logger.info('Saving current model ...')
                    save_curr_model(model, optimizer, datasets, patience)

                    model.train()

                    if patience < 0:
                        return patience, datasets

        return patience, datasets
    except KeyboardInterrupt:
        return patience, datasets


if __name__ == '__main__':
    # Script entry point: parse CLI arguments, set up the experiment
    # logger, record the GPU count, then hand off to main().
    params = parser.parse_args()
    logger = initialize_exp(params)

    # Use all available GPUs. To use fewer, use CUDA_VISIBLE_DEVICES at launch
    params.n_gpus = torch.cuda.device_count()

    # NOTE(review): `main` defined just below takes (args, writer) but is
    # called here with no arguments, and this guard runs before that def is
    # executed — confirm which main() this entry point is meant to invoke.
    main()
def main(args, writer):
    """Fine-tune or linear-probe a video backbone on a classification dataset.

    Builds the (optionally pretrained) audio-video model, attaches a
    classification head to its video branch, trains for ``args.epochs``
    epochs and evaluates after each one.

    Args:
        args: parsed argparse namespace with model, data and optimizer
            settings (see the script's argument parser).
        writer: tensorboard ``SummaryWriter`` (or ``None``) forwarded to
            ``train``/``evaluate`` for logging.

    Returns:
        ``(vid_acc1, vid_acc5, args.start_epoch)`` when ``args.test_only``
        is set, otherwise ``(best_vid_acc_1, best_vid_acc_5, best_epoch)``
        over the whole run.
    """

    # Create Logger
    logger, training_stats = initialize_exp(args, "epoch", "loss", "prec1",
                                            "prec5", "loss_val", "prec1_val",
                                            "prec5_val")

    # Set CudNN benchmark (inputs have fixed shapes, so let cuDNN autotune)
    torch.backends.cudnn.benchmark = True

    # Load model
    logger.info("Loading model")
    model = load_model(
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        pretrained=args.pretrained,
        num_classes=args.num_clusters,
        norm_feat=False,
        use_mlp=args.use_mlp,
        headcount=args.headcount,
    )

    # Load self-supervised checkpoint weights (skipped when the backbone is
    # torchvision-pretrained or no path was given). The literal string
    # 'None' is treated as "no path" since argparse may pass it verbatim.
    if isinstance(args.weights_path, str):
        weight_path_not_none = args.weights_path != 'None'
    else:
        weight_path_not_none = args.weights_path is not None
    if not args.pretrained and weight_path_not_none:
        logger.info("Loading model weights")
        if os.path.exists(args.weights_path):
            ckpt_dict = torch.load(args.weights_path)
            model_weights = ckpt_dict["model"]
            logger.info(f"Epoch checkpoint: {args.ckpt_epoch}")
            load_model_parameters(model, model_weights)
    logger.info("Loading model done")

    # Add FC layer to model for fine-tuning or feature extracting
    model = Finetune_Model(model.video_network.base,
                           get_video_dim(vid_base_arch=args.vid_base_arch),
                           NUM_CLASSES[args.dataset],
                           use_dropout=args.use_dropout,
                           use_bn=args.use_bn,
                           use_l2_norm=args.use_l2_norm,
                           dropout=0.7)

    # Create DataParallel model
    model = model.cuda()
    model = torch.nn.DataParallel(model)
    model_without_ddp = model.module

    # Get params for optimization. The classifier head is always trained;
    # when fine-tuning (not feature-extracting) the backbone is trained too,
    # with its own learning rate and weight decay.
    params = []
    for name, param in model_without_ddp.classifier.named_parameters():
        logger.info((name, param.shape))
        params.append({
            'params': param,
            'lr': args.head_lr,
            'weight_decay': args.weight_decay
        })
    if not args.feature_extract:  # finetune: also optimize the backbone
        for name, param in model_without_ddp.base.named_parameters():
            logger.info((name, param.shape))
            params.append({
                'params': param,
                'lr': args.base_lr,
                'weight_decay': args.wd_base
            })

    logger.info("Creating AV Datasets")
    dataset = AVideoDataset(
        ds_name=args.dataset,
        root_dir=args.root_dir,
        mode='train',
        num_frames=args.clip_len,
        sample_rate=args.steps_bet_clips,
        num_train_clips=args.train_clips_per_video,
        train_crop_size=128 if args.augtype == 1 else 224,
        seed=None,
        fold=args.fold,
        colorjitter=args.colorjitter,
        temp_jitter=True,
        center_crop=False,
        target_fps=30,
        decode_audio=False,
    )
    dataset_test = AVideoDataset(
        ds_name=args.dataset,
        root_dir=args.root_dir,
        mode='test',
        num_frames=args.clip_len,
        sample_rate=args.steps_bet_clips,
        test_crop_size=128 if args.augtype == 1 else 224,
        num_spatial_crops=args.num_spatial_crops,
        num_ensemble_views=args.val_clips_per_video,
        seed=None,
        fold=args.fold,
        colorjitter=args.test_time_cj,
        temp_jitter=True,
        target_fps=30,
        decode_audio=False,
    )

    # Creating dataloaders
    logger.info("Creating data loaders")
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=None,
                                              num_workers=args.workers,
                                              pin_memory=True,
                                              drop_last=True,
                                              shuffle=True)
    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=args.batch_size,
                                                   sampler=None,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=False)

    # linearly scale LR and set up optimizer
    if args.optim_name == 'sgd':
        optimizer = torch.optim.SGD(params,
                                    lr=args.head_lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optim_name == 'adam':
        optimizer = torch.optim.Adam(params,
                                     lr=args.head_lr,
                                     weight_decay=args.weight_decay)
    else:
        # Fail fast instead of hitting a NameError on `optimizer` below.
        raise ValueError(f"Unsupported optimizer: {args.optim_name}")

    # Multi-step LR scheduler, optionally wrapped in a gradual warmup phase.
    # Milestones are shifted left by the warmup length so decay epochs land
    # where the user specified them.
    if args.use_scheduler:
        lr_milestones = args.lr_milestones.split(',')
        milestones = [int(lr) - args.lr_warmup_epochs for lr in lr_milestones]
        if args.lr_warmup_epochs > 0:
            scheduler_step = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=milestones, gamma=args.lr_gamma)
            multiplier = 8
            lr_scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=multiplier,
                total_epoch=args.lr_warmup_epochs,
                after_scheduler=scheduler_step)
        else:  # no warmup, just multi-step
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=milestones, gamma=args.lr_gamma)
    else:
        lr_scheduler = None

    # Checkpointing: restore model/optimizer/scheduler state and the epoch
    # to resume from.
    if args.resume:
        ckpt_path = os.path.join(args.output_dir, 'checkpoints',
                                 'checkpoint.pth')
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch']
        logger.info(f"Resuming from epoch: {args.start_epoch}")

    # Only perform evaluation
    if args.test_only:
        scores_val = evaluate(
            model,
            data_loader_test,
            epoch=args.start_epoch,
            writer=writer,
            ds=args.dataset,
        )
        _, vid_acc1, vid_acc5 = scores_val
        return vid_acc1, vid_acc5, args.start_epoch

    # Train/eval loop, tracking the best top-1/top-5 video accuracies.
    start_time = time.time()
    best_vid_acc_1 = -1
    best_vid_acc_5 = -1
    best_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        logger.info(f'Start training epoch: {epoch}')
        scores = train(
            model,
            optimizer,
            data_loader,
            epoch,
            writer=writer,
            ds=args.dataset,
        )
        logger.info(f'Start evaluating epoch: {epoch}')
        # Guard: lr_scheduler is None when args.use_scheduler is False; the
        # unconditional step() here used to crash in that configuration.
        if lr_scheduler is not None:
            lr_scheduler.step()
        scores_val = evaluate(
            model,
            data_loader_test,
            epoch=epoch,
            writer=writer,
            ds=args.dataset,
        )
        _, vid_acc1, vid_acc5 = scores_val
        training_stats.update(scores + scores_val)
        if vid_acc1 > best_vid_acc_1:
            best_vid_acc_1 = vid_acc1
            best_vid_acc_5 = vid_acc5
            best_epoch = epoch
        if args.output_dir:
            logger.info(f'Saving checkpoint to: {args.output_dir}')
            save_checkpoint(args,
                            epoch,
                            model,
                            optimizer,
                            lr_scheduler,
                            ckpt_freq=1)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f'Training time {total_time_str}')
    return best_vid_acc_1, best_vid_acc_5, best_epoch
Example #3
0
def main():
    """Distributed self-supervised audio-video pretraining entry point.

    Parses CLI arguments, initializes distributed training, builds the
    dataset/dataloader, model, optimizer and optional warmup LR scheduler,
    optionally resumes from a checkpoint, then runs the epoch loop with
    self-labelling (Sinkhorn-Knopp schedule) and periodic checkpointing.
    """

    # parse arguments; stored in a module-level global so other helpers
    # in this script can read them
    global args
    parser = parse_arguments()
    args = parser.parse_args()

    # exp setup: logger, distributed mode and seeds
    init_distributed_mode(args)
    init_signal_handler()
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")
    # only rank 0 writes tensorboard summaries
    if args.rank == 0:
        writer = SummaryWriter(args.dump_path)
    else:
        writer = None

    # build data (video + audio decoding, with the configured augmentations)
    train_dataset = AVideoDataset(
        ds_name=args.ds_name,
        root_dir=args.root_dir,
        mode='train',
        path_to_data_dir=args.data_path,
        num_frames=args.num_frames,
        target_fps=args.target_fps,
        sample_rate=args.sample_rate,
        num_train_clips=args.num_train_clips,
        train_crop_size=args.train_crop_size,
        test_crop_size=args.test_crop_size,
        num_data_samples=args.num_data_samples,
        colorjitter=args.colorjitter,
        use_grayscale=args.use_grayscale,
        use_gaussian=args.use_gaussian,
        temp_jitter=True,
        decode_audio=True,
        aug_audio=None,
        num_sec=args.num_sec_aud,
        aud_sample_rate=args.aud_sample_rate,
        aud_spec_type=args.aud_spec_type,
        use_volume_jittering=args.use_volume_jittering,
        use_temporal_jittering=args.use_audio_temp_jittering,
        z_normalize=args.z_normalize,
        dual_data=args.dual_data
    )
    # each rank sees a disjoint shard of the dataset
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info("Loaded data with {} videos.".format(len(train_dataset)))

    # Load model (joint video + audio backbones with an MLP projection head)
    model = load_model(
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        use_mlp=args.use_mlp,
        num_classes=args.mlp_dim,
        pretrained=False,
        norm_feat=False,
        use_max_pool=False,
        headcount=args.headcount,
    )

    # synchronize batch norm layers across processes (pytorch or apex impl);
    # with apex, BN stats are synced within groups of 8 GPUs when possible
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model, process_group=process_group)

    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    # optional warmup: LR is ramped up to base_lr * world_size over
    # warmup_epochs (no decay scheduler afterwards)
    if args.use_warmup_scheduler:
        lr_scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=args.world_size,
            total_epoch=args.warmup_epochs,
            after_scheduler=None
        )
    else:
        lr_scheduler = None

    logger.info("Building optimizer done.")

    # init mixed precision (apex AMP O1: patched casts, FP32 master weights)
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # SK-Init: schedule the global iteration indices at which the
    # Sinkhorn-Knopp self-labelling optimization is run
    N_dl = len(train_loader)          # batches per epoch (per rank)
    N = len(train_loader.dataset)     # total number of videos
    # samples actually consumed per epoch given drop_last=True
    # NOTE(review): N_distr is not used below in this view — confirm it is
    # needed (possibly leftover debugging).
    N_distr = N_dl * train_loader.batch_size
    # one pseudo-label per sample per head, kept on GPU
    selflabels = torch.zeros((N, args.headcount), dtype=torch.long, device='cuda')
    global sk_schedule
    # nopts opt times spread over training, denser later in training for
    # schedulepower > 1 (reversed so the list is consumed back-to-front)
    sk_schedule = (args.epochs * N_dl * (np.linspace(0, 1, args.nopts) ** args.schedulepower)[::-1]).tolist()
    # to make sure we don't make it empty
    sk_schedule = [(args.epochs + 2) * N_dl] + sk_schedule
    logger.info(f'remaining SK opts @ epochs {[np.round(1.0 * t / N_dl, 2) for t in sk_schedule]}')

    # optionally resume from a checkpoint (restores epoch, pseudo-labels
    # and the label distribution parameter alongside model/optimizer/amp)
    to_restore = {"epoch": 0, 'selflabels': selflabels, 'dist':args.dist}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        model=model,
        optimizer=optimizer,
        amp=apex.amp if args.use_fp16 else None,
    )
    start_epoch = to_restore["epoch"]
    selflabels = to_restore["selflabels"]
    args.dist = to_restore["dist"]

    # Set CuDNN benchmark
    cudnn.benchmark = True

    # Restart schedule correctly: drop SK opts that already happened before
    # the restart epoch and advance the warmup scheduler to match
    if start_epoch != 0:
        include = [(qq / N_dl > start_epoch) for qq in sk_schedule]
        # (total number of sk-opts) - (number of sk-opts outstanding)
        global sk_counter
        sk_counter = len(sk_schedule) - sum(include)
        sk_schedule = (np.array(sk_schedule)[include]).tolist()
        if lr_scheduler:
            [lr_scheduler.step() for _ in range(to_restore['epoch'])]

    # on a fresh run, warm up batchnorm statistics before training
    if start_epoch == 0:
        train_loader.sampler.set_epoch(999)
        # NOTE(review): `group` is not defined in this function — presumably
        # a module-level process group set up elsewhere; confirm.
        warmup_batchnorm(args, model, train_loader, batches=20, group=group)

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)
        if writer:
            writer.add_scalar('train/epoch', epoch, epoch)

        # set sampler (reshuffles shards deterministically per epoch)
        train_loader.sampler.set_epoch(epoch)

        # train the network; returns updated pseudo-labels
        scores, selflabels = train(
            train_loader, model, optimizer, epoch, writer, selflabels)
        training_stats.update(scores)

        # Update LR scheduler
        if lr_scheduler:
            lr_scheduler.step()

        # save checkpoints (rank 0 only; includes pseudo-labels so SK state
        # survives restarts)
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "dist": args.dist,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "selflabels": selflabels
            }

            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth")
                )