# ===== Example 1 =====
            log_dir = os.path.join(args.log, exp_name)
        # TensorBoard writer rooted at the per-experiment log directory.
        writer = SummaryWriter(log_dir)

        # Clip preprocessing: resize frames to 128x171, take a random
        # 112x112 crop (C3D-style input size), then convert to tensor.
        train_transforms = transforms.Compose([
            transforms.Resize((128, 171)),
            transforms.RandomCrop(112),
            transforms.ToTensor()
        ])

        if args.dataset == 'ucf101':
            train_dataset = UCF101Dataset('data/ucf101', args.cl, args.split, True, train_transforms)
            # Hold out a fixed-size validation split from the training set
            # (size chosen roughly in proportion to the dataset).
            val_size = 800
            train_dataset, val_dataset = random_split(train_dataset, (len(train_dataset) - val_size, val_size))
        elif args.dataset == 'hmdb51':
            train_dataset = HMDB51Dataset('data/hmdb51', args.cl, args.split, True, train_transforms)
            val_size = 400
            train_dataset, val_dataset = random_split(train_dataset, (len(train_dataset) - val_size, val_size))
        elif args.dataset == "monkey2":
            train_dataset = MONKEY2Dataset('data/monkey2', args.cl, args.split, True, train_transforms)
            val_size = 20
            train_dataset, val_dataset = random_split(train_dataset, (len(train_dataset) - val_size, val_size))

        # NOTE(review): if args.dataset matches none of the branches above,
        # val_dataset is never assigned and the print below raises NameError.
        print('TRAIN video number: {}, VAL video number: {}.'.format(len(train_dataset), len(val_dataset)))
        # Validation loader keeps the natural order (shuffle=False) so
        # evaluation is deterministic across epochs.
        train_dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True,
                                      num_workers=args.workers, pin_memory=True)
        val_dataloader = DataLoader(val_dataset, batch_size=args.bs, shuffle=False,
                                    num_workers=args.workers, pin_memory=True)

        if args.ckpt:
# ===== Example 2 =====
    else:
        # Experiment name encodes model, clip length and a timestamp so
        # repeated runs land in distinct log directories.
        exp_name = 'K400_TCG_split1_finetuned_loss_{}_cl{}_{}'.format(args.model, args.cl, time.strftime('%m%d%H%M'))
    log_dir = os.path.join(args.log, exp_name)
    writer = SummaryWriter(log_dir)

    # Clip preprocessing: resize frames to 128x171, take a random
    # 112x112 crop, then convert to tensor.
    train_transforms = transforms.Compose([
        transforms.Resize((128, 171)),
        transforms.RandomCrop(112),
        transforms.ToTensor()
    ])

    if args.dataset == 'ucf101':
        train_dataset = UCF101Dataset('data/ucf101', args.cl, args.split, True, train_transforms)
        val_size = 800
    elif args.dataset == 'hmdb51':
        train_dataset = HMDB51Dataset('data/hmdb51', args.cl, args.split, True, train_transforms)
        val_size = 400
    elif args.dataset == 'K400':
        # K400 ships with its own train/val lists, so no random split here.
        # NOTE(review): the val set uses the *training* transforms (random
        # crop) — confirm that is intended for validation.
        train_dataset = K400Dataset_train('data/K400', args.cl, args.split, True, train_transforms)
        val_dataset = K400Dataset_val('data/K400', args.cl, args.split, True, train_transforms)

    # split val for 800 videos
    # NOTE(review): with the random_split below commented out, val_size is
    # unused and val_dataset is never assigned on the ucf101/hmdb51 paths,
    # so the print below raises NameError for those datasets.
    #train_dataset, val_dataset = random_split(train_dataset, (len(train_dataset)-val_size, val_size))
    print('TRAIN video number: {}, VAL video number: {}.'.format(len(train_dataset), len(val_dataset)))
    train_dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True,
                                num_workers=args.workers, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=args.bs, shuffle=False,
                                num_workers=args.workers, pin_memory=True)

    # save graph and clips_order samples
    for data in train_dataloader:
# ===== Example 3 =====
def main():
    """Contrastive pre-training entry point (GPU-only).

    Parses CLI options, seeds all RNGs, builds the RGB clip dataloader,
    constructs the model / contrast memory / criteria, optionally resumes
    from a checkpoint, then runs the train -> log -> save loop for
    ``args.epochs`` epochs.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    if not torch.cuda.is_available():
        # `raise '<str>'` is a TypeError in Python 3 — exceptions must
        # derive from BaseException, so raise a real exception type.
        raise RuntimeError('Only support GPU mode')
    # parse the args
    args = parse_option()
    print(vars(args))

    # Seed every RNG in play (python, numpy, torch CPU + all GPUs) so runs
    # are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    print('[Warning] The training modalities are RGB and [{}]'.format(
        args.modality))

    # Data: resize so the frame is 128x171, random 112x112 crop, to tensor.
    train_transforms = transforms.Compose([
        transforms.Resize((128, 171)),  # smaller edge to 128
        transforms.RandomCrop(112),
        transforms.ToTensor()
    ])
    if args.dataset == 'ucf101':
        trainset = UCF101Dataset('./data/ucf101/',
                                 transforms_=train_transforms)
    else:
        trainset = HMDB51Dataset('./data/hmdb51/',
                                 transforms_=train_transforms)

    # drop_last keeps every batch full, which the contrastive memory needs
    # a fixed batch size for.
    train_loader = DataLoader(trainset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True,
                              drop_last=True)

    # len() is the idiomatic way to invoke __len__.
    n_data = len(trainset)

    # set the model (contrast is the NCE memory bank; two criteria, one per view)
    model, contrast, criterion_1, criterion_2 = set_model(args, n_data)

    # set the optimizer
    optimizer = set_optimizer(args, model)

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location='cpu' avoids spiking GPU memory during the load.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            # Release the checkpoint dict immediately to keep peak memory low.
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # Step-wise LR decay: multiply LR by 0.2 at each milestone epoch.
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[45, 90, 125, 160],
                                         gamma=0.2)
    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):
        time1 = time.time()
        view1_loss, view1_prob, view2_loss, view2_prob = train(
            epoch, train_loader, model, contrast, criterion_1, criterion_2,
            optimizer, args)
        time2 = time.time()
        print('\nepoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('view1_loss', view1_loss, epoch)
        logger.log_value('view1_prob', view1_prob, epoch)
        logger.log_value('view2_loss', view2_loss, epoch)
        logger.log_value('view2_prob', view2_prob, epoch)

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'model': model.state_dict(),
                'contrast': contrast.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
            # help release GPU memory
            del state

        torch.cuda.empty_cache()
        scheduler.step()

    print(args.model_name)