예제 #1
0
def main():
    """Fine-tune a pretrained video backbone as a 101-class classifier.

    CLI options come from ``parse_args()``; the rest of the configuration is
    read from the module-level ``params`` dict, ``pretrain_path_list`` and
    ``multi_gpu``. Pretrained weights are loaded with ``strict=False``, so
    keys missing from / extra to the checkpoint are tolerated. A fixed-size
    validation split is carved out of the training set with ``random_split``.

    Checkpoints go to a timestamped directory under
    ``params['save_path_base']``:

    * ``best_acc_model_<epoch>.pth.tar`` whenever validation top-1 ties or
      beats the best seen so far,
    * ``best_loss_model_<epoch>.pth.tar`` whenever validation loss improves,
    * ``<epoch>.pth.tar`` every 20th epoch.

    Raises:
        ValueError: if ``args.model_name`` or ``params['data']`` is not a
            supported value (previously this surfaced later as an unrelated
            ``NameError`` / ``UnboundLocalError``).
    """
    args = parse_args()
    print(vars(args))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    if args.model_name == 'c3d':
        model = c3d.C3D(with_classifier=True, num_classes=101)
    elif args.model_name == 'r3d':
        model = r3d.R3DNet((1, 1, 1, 1), with_classifier=True, num_classes=101)
    elif args.model_name == 'r21d':
        model = r21d.R2Plus1DNet((1, 1, 1, 1),
                                 with_classifier=True,
                                 num_classes=101)
    else:
        raise ValueError(
            "unsupported model_name: {!r}".format(args.model_name))
    print(args.model_name)

    start_epoch = 1
    pretrain_path = pretrain_path_list[args.pre_path]
    print(pretrain_path)
    pretrain_weight = load_pretrained_weights(pretrain_path)
    print(pretrain_weight.keys())
    model.load_state_dict(pretrain_weight, strict=False)

    # Train / validation split: the validation set is a fixed-size random
    # subset of the training set.
    train_dataset = ClassifyDataSet(params['dataset'], mode="train")
    if params['data'] == 'UCF-101':
        val_size = 800
    elif params['data'] == 'hmdb':
        val_size = 400
    else:
        raise ValueError("unsupported dataset: {!r}".format(params['data']))
    train_dataset, val_dataset = random_split(
        train_dataset, (len(train_dataset) - val_size, val_size))

    print("num_works:{:d}".format(params['num_workers']))
    print("batch_size:{:d}".format(params['batch_size']))
    train_loader = DataLoader(train_dataset,
                              batch_size=params['batch_size'],
                              shuffle=True,
                              num_workers=params['num_workers'])
    val_loader = DataLoader(val_dataset,
                            batch_size=params['batch_size'],
                            shuffle=True,
                            num_workers=params['num_workers'])
    if multi_gpu == 1:
        model = nn.DataParallel(model)
    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=params['learning_rate'],
                          momentum=params['momentum'],
                          weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    save_path = params['save_path_base'] + "ft_classify_{}_".format(
        args.exp_name) + params['data']
    model_save_dir = os.path.join(save_path, time.strftime('%m-%d-%H-%M'))
    # Create the directory before anything (the writer included) uses it.
    os.makedirs(model_save_dir, exist_ok=True)
    writer = SummaryWriter(model_save_dir)

    prev_best_val_loss = float('inf')
    best_acc = 0
    best_epoch = 0
    for epoch in tqdm(range(start_epoch, start_epoch + params['epoch_num'])):
        train(train_loader, model, criterion, optimizer, epoch, writer)
        val_loss, top1_avg = validation(val_loader, model, criterion,
                                        optimizer, epoch)
        # Since PyTorch 1.1 the LR scheduler must be stepped *after* the
        # optimizer has stepped; stepping it first skips the initial LR.
        scheduler.step()

        # ">=" so a later epoch that ties the best accuracy is also saved.
        if top1_avg >= best_acc:
            best_acc = top1_avg
            print("i am best :", best_acc)
            best_epoch = epoch
            model_path = os.path.join(
                model_save_dir, 'best_acc_model_{}.pth.tar'.format(epoch))
            torch.save(model.state_dict(), model_path)
        if val_loss < prev_best_val_loss:
            model_path = os.path.join(
                model_save_dir, 'best_loss_model_{}.pth.tar'.format(epoch))
            torch.save(model.state_dict(), model_path)
            prev_best_val_loss = val_loss

        # Periodic checkpoint, independent of validation performance.
        if epoch % 20 == 0:
            checkpoints = os.path.join(model_save_dir, str(epoch) + ".pth.tar")
            torch.save(model.state_dict(), checkpoints)
            print("save_to:", checkpoints)
    print("best is :", best_acc, best_epoch)
    writer.close()
예제 #2
0
def train(args):
    """Self-supervised pretraining of a video backbone on ``ucf101_pace_pretrain``.

    Builds the dataset from ``args.data_list`` / ``args.rgb_prefix`` with
    resize + random-crop + horizontal-flip clip transforms and random colour
    jitter. Trains the backbone selected by ``args.model`` with SGD and
    cross-entropy for ``args.epoch`` epochs; the LR is multiplied by 0.1
    every 6 epochs.

    Every ``args.pf`` iterations, loss and batch accuracy are logged to
    TensorBoard under ``visual_logs/<exp_name>``. ``model_saver`` is called
    after each epoch with ``pretrain_cks/<exp_name>``. Relies on the
    module-level ``device``.

    Raises:
        ValueError: if ``args.model`` is not one of r21d / r3d / c3d / s3d
            (previously this surfaced later as an unbound ``model``).
    """
    torch.backends.cudnn.benchmark = True

    exp_name = '{}_sr_{}_{}_lr_{}_len_{}_sz_{}'.format(args.dataset,
                                                       args.max_sr, args.model,
                                                       args.lr, args.clip_len,
                                                       args.crop_sz)
    print(exp_name)

    pretrain_cks_path = os.path.join('pretrain_cks', exp_name)
    log_path = os.path.join('visual_logs', exp_name)
    os.makedirs(pretrain_cks_path, exist_ok=True)
    os.makedirs(log_path, exist_ok=True)

    # 1. dataset
    transforms_ = transforms.Compose([
        ClipResize((args.height, args.width)),  # h x w
        RandomCrop(args.crop_sz),
        RandomHorizontalFlip(0.5)
    ])

    color_jitter = transforms.ColorJitter(brightness=0.8,
                                          contrast=0.8,
                                          saturation=0.8,
                                          hue=0.2)
    color_jitter = transforms.RandomApply([color_jitter], p=0.8)

    train_dataset = ucf101_pace_pretrain(args.data_list,
                                         args.rgb_prefix,
                                         clip_len=args.clip_len,
                                         max_sr=args.max_sr,
                                         transforms_=transforms_,
                                         color_jitter_=color_jitter)

    print("len of training data:", len(train_dataset))
    dataloader = DataLoader(train_dataset,
                            batch_size=args.bs,
                            shuffle=True,
                            num_workers=args.num_workers,
                            pin_memory=True)

    # 2. init model
    if args.model == 'r21d':
        model = r21d.R2Plus1DNet(num_classes=args.num_classes)
    elif args.model == 'r3d':
        model = r3d.R3DNet(num_classes=args.num_classes)
    elif args.model == 'c3d':
        model = c3d.C3D(num_classes=args.num_classes)
    elif args.model == 's3d':
        model = s3d_g.S3D(num_classes=args.num_classes, space_to_depth=False)
    else:
        raise ValueError("unsupported model: {!r}".format(args.model))

    # 3. define loss and lr
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=0.005)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

    # 4. multi gpu
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model.to(device)
    criterion.to(device)

    writer = SummaryWriter(log_dir=log_path)
    iterations = 1

    model.train()

    for epoch in range(args.epoch):
        start_time = time.time()

        for i, (rgb_clip, labels) in enumerate(dataloader):
            rgb_clip = rgb_clip.to(device, dtype=torch.float)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(rgb_clip)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Softmax is monotonic, so the argmax of the logits is the
            # prediction. Divide by the real batch size: the final batch
            # can be smaller than args.bs. (np.float is gone in NumPy>=1.24.)
            preds = outputs.argmax(dim=1)
            accuracy = (preds == labels).sum().item() / labels.size(0)

            iterations += 1

            if i % args.pf == 0:
                loss_value = loss.item()
                writer.add_scalar('data/train_loss', loss_value, iterations)
                writer.add_scalar('data/Acc', accuracy, iterations)

                print("[Epoch{}/{}] Loss: {} Acc: {} Time {} ".format(
                    epoch + 1, i, loss_value, accuracy,
                    time.time() - start_time))

            # Reset per iteration: the printed time covers a single step.
            start_time = time.time()

        scheduler.step()
        model_saver(model, optimizer, epoch, args.max_save, pretrain_cks_path)

    writer.close()
예제 #3
0
def generate_model(opt):
    """Build a 3D CNN described by ``opt``, optionally for fine-tuning.

    Args:
        opt: options namespace. Reads ``model`` and ``n_classes``. For the
            ResNet-family / DenseNet models it also reads ``model_depth``,
            ``sample_size``, ``sample_duration``, and where applicable
            ``resnet_shortcut``, ``wide_resnet_k`` and
            ``resnext_cardinality``. Always reads ``no_cuda`` and
            ``pretrain_path``. When ``pretrain_path`` is set it also reads
            ``arch``, ``n_finetune_classes`` and ``ft_begin_index``.

    Returns:
        ``(model, parameters)``. Unless ``opt.no_cuda``, the model is moved to
        the GPU and wrapped in ``nn.DataParallel``.

        With ``opt.pretrain_path``, the checkpoint's ``state_dict`` is loaded
        and the classification head (``classifier`` for DenseNet, ``fc``
        otherwise) is replaced by a fresh ``nn.Linear`` with
        ``opt.n_finetune_classes`` outputs. ``parameters`` then comes from the
        architecture's ``get_fine_tuning_parameters``. Without a pretrain
        path, ``parameters`` is ``model.parameters()``.

    Raises:
        AssertionError: unsupported model or depth, or the checkpoint's
            ``arch`` does not match ``opt.arch``.
        NotImplementedError: ``pretrain_path`` given for ``'c3d'``, which has
            no fine-tuning helper (this used to crash with an
            ``UnboundLocalError`` after loading the checkpoint).
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet', 'c3d'
    ]

    # Bound by the architecture-specific import below; stays None for 'c3d'.
    get_fine_tuning_parameters = None

    if opt.model == 'c3d':
        model = c3d.C3D(num_classes=opt.n_classes)
    elif opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        # Depth is whitelisted above, so resnet.resnet<depth> is one of the
        # constructors this function has always called.
        model = getattr(resnet, 'resnet{}'.format(opt.model_depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]
        from models.wide_resnet import get_fine_tuning_parameters
        model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     k=opt.wide_resnet_k,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        model = getattr(resnext, 'resnet{}'.format(opt.model_depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]
        from models.pre_act_resnet import get_fine_tuning_parameters
        model = getattr(pre_act_resnet, 'resnet{}'.format(opt.model_depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    else:  # 'densenet'
        assert opt.model_depth in [121, 169, 201, 264]
        from models.densenet import get_fine_tuning_parameters
        model = getattr(densenet, 'densenet{}'.format(opt.model_depth))(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)

    # Fail before touching the GPU or the checkpoint if fine-tuning is
    # impossible for this architecture.
    if opt.pretrain_path and get_fine_tuning_parameters is None:
        raise NotImplementedError(
            "fine-tuning from a pretrained checkpoint is not supported for "
            "model {!r}".format(opt.model))

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    # On CPU-only runs, remap GPU-saved tensors so loading does not need CUDA.
    pretrain = torch.load(opt.pretrain_path,
                          map_location='cpu' if opt.no_cuda else None)
    assert opt.arch == pretrain['arch']

    model.load_state_dict(pretrain['state_dict'])

    # Replace the classification head with a fresh one sized for the
    # fine-tuning label set. Under DataParallel the network is model.module.
    net = model if opt.no_cuda else model.module
    head_name = 'classifier' if opt.model == 'densenet' else 'fc'
    new_head = nn.Linear(getattr(net, head_name).in_features,
                         opt.n_finetune_classes)
    if not opt.no_cuda:
        new_head = new_head.cuda()
    setattr(net, head_name, new_head)

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
예제 #4
0
def main():
    """Train an ``sscn.SSCN_OneClip`` model with a 4-way classifier head.

    The backbone (c3d / r3d / r21d, built without its own classifier) is
    wrapped in ``SSCN_OneClip``. If the module-level ``ckpt`` is set, weights
    are loaded strictly. A fixed-size validation split is carved out of the
    training set with ``random_split``.

    Training uses a motion-MSE loss plus cross-entropy and SGD. Each
    trainable tensor gets its own param group; names containing ``'fc8'`` get
    10x the base ``learning_rate``. ``ReduceLROnPlateau`` is driven by the
    validation loss. Also reads the module-level ``params``, ``multi_gpu``,
    ``learning_rate``, ``start_epoch`` and ``train_epoch``.

    Only the latest best-validation-loss checkpoint is kept on disk. A
    periodic checkpoint is written every 20th epoch.

    Raises:
        ValueError: if ``args.model_name`` or ``params['data']`` is not a
            supported value (previously ``model`` / ``val_dataset`` were left
            unbound and the run crashed later with a ``NameError``).
    """
    args = parse_args()
    print(vars(args))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    if args.model_name == 'c3d':
        model = c3d.C3D(with_classifier=False)
    elif args.model_name == 'r3d':
        model = r3d.R3DNet((1, 1, 1, 1), with_classifier=False)
    elif args.model_name == 'r21d':
        model = r21d.R2Plus1DNet((1, 1, 1, 1), with_classifier=False)
    else:
        raise ValueError(
            "unsupported model_name: {!r}".format(args.model_name))
    print(args.model_name)
    model = sscn.SSCN_OneClip(base_network=model, with_classifier=True,
                              num_classes=4)
    if ckpt:
        weight = load_pretrained_weights(ckpt)
        model.load_state_dict(weight, strict=True)

    # train / validation split
    train_dataset = PredictDataset(params['dataset'], mode="train", args=args)
    if params['data'] == 'UCF-101':
        val_size = 800
    elif params['data'] == 'hmdb':
        val_size = 400
    else:
        # Other datasets (e.g. kinetics-400 with its own 'val' split) are not
        # wired up; fail now rather than on an unbound val_dataset below.
        raise ValueError("unsupported dataset: {!r}".format(params['data']))
    train_dataset, val_dataset = random_split(
        train_dataset, (len(train_dataset) - val_size, val_size))

    train_loader = DataLoader(train_dataset,
                              batch_size=params['batch_size'],
                              shuffle=True,
                              num_workers=params['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=params['batch_size'],
                            shuffle=True,
                            num_workers=params['num_workers'],
                            drop_last=True)
    if multi_gpu == 1:
        model = nn.DataParallel(model)
    model = model.cuda()
    criterion_CE = nn.CrossEntropyLoss().cuda()
    criterion_MSE = Motion_MSEloss_NoFakeGt().cuda()

    # One param group per trainable tensor. Any parameter whose name contains
    # 'fc8' is trained with 10x the base learning rate.
    model_params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        if 'fc8' in key:
            print(key)
            lr = 10 * learning_rate
        else:
            lr = learning_rate
        model_params.append({'params': [value], 'lr': lr})
    optimizer = optim.SGD(model_params, momentum=params['momentum'],
                          weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', min_lr=1e-7, patience=50, factor=0.1)

    save_path = params['save_path_base'] + "train_predict_{}_".format(
        args.exp_name) + params['data']
    model_save_dir = os.path.join(save_path, time.strftime('%m-%d-%H-%M'))
    os.makedirs(model_save_dir, exist_ok=True)
    writer = SummaryWriter(model_save_dir)

    # inf rather than a magic 100: a best model is saved even if the first
    # validation losses are large.
    prev_best_val_loss = float('inf')
    prev_best_loss_model_path = None
    for epoch in tqdm(range(start_epoch, start_epoch + train_epoch)):
        train(train_loader, model, criterion_MSE, criterion_CE, optimizer,
              epoch, writer, root_path=model_save_dir)
        val_loss = validation(val_loader, model, criterion_MSE, criterion_CE,
                              optimizer, epoch)
        if val_loss < prev_best_val_loss:
            model_path = os.path.join(model_save_dir,
                                      'best_model_{}.pth.tar'.format(epoch))
            torch.save(model.state_dict(), model_path)
            prev_best_val_loss = val_loss
            # Keep only the most recent best-loss checkpoint on disk.
            if prev_best_loss_model_path:
                os.remove(prev_best_loss_model_path)
            prev_best_loss_model_path = model_path
        scheduler.step(val_loss)

        if epoch % 20 == 0:
            checkpoints = os.path.join(model_save_dir,
                                       'model_{}.pth.tar'.format(epoch))
            torch.save(model.state_dict(), checkpoints)
            print("save_to:", checkpoints)
    writer.close()
예제 #5
0
    model.load_state_dict(pretrain_weight,strict= True)
#     model.load_state_dict(torch.load(pretrain_path, map_location='cpu'), strict=True)
    test_dataset = ClassifyDataSet(params['dataset'], mode="test");
    test_loader = DataLoader(test_dataset, batch_size=params['batch_size'], shuffle=False,
                             num_workers=params['num_workers'])

    if len(device_ids)>1:
        print(torch.cuda.device_count())
        model = nn.DataParallel(model)
    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    test(test_loader, model, criterion,pretrain_path)
if __name__ == '__main__':
    print(1)

    # Seed the Python, NumPy and torch (CPU + every CUDA device) RNGs, in
    # that order, with one fixed value.
    seed = 632
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Backbone under test. Alternatives used with other checkpoints:
    #     r3d.R3DNet((1, 1, 1, 1), with_classifier=True, num_classes=101)
    #     r21d.R2Plus1DNet((1, 1, 1, 1), with_classifier=True, num_classes=101)
    model = c3d.C3D(with_classifier=True, num_classes=101)

    pretrain_path = ('./outputs/ft_classify_Finsert_rate2_1248_part_patch_UCF-101/'
                     '11-10-23-37/best_loss_model_139.pth.tar')
    test_model(model, pretrain_path)



예제 #6
0
# pretrain="resnet50_ucf101_701.pth"
# state = torch.load(pretrain)
# sd = state["state_dict"]
# if torch.cuda.is_available():
#     mdl = torch.nn.DataParallel(mdl)  # Hara et. al. normalised by mean only
# mdl.load_state_dict(state_dict=sd, strict=False)
# first_lin = True
# if hasattr(mdl, "module"):
#     module = mdl.module
# else:
#     module = mdl

#         first_lin = False
from models import c3d
import collections

import torch

# Build a 101-class C3D and load weights from a local checkpoint. The
# checkpoint may come from a bare model or from nn.DataParallel, whose
# state_dict keys carry a 'module.' prefix.
# NOTE(review): ds_mean is not defined in this snippet; presumably a
# per-channel dataset mean defined earlier in the script -- confirm.
mdl = c3d.C3D(101, range=(-max(ds_mean), 255 - min(ds_mean)))
module = mdl.module if hasattr(mdl, "module") else mdl

# mdl is never moved to a GPU here, so remap the tensors to the CPU instead
# of requiring the device(s) the checkpoint was saved from.
state = torch.load('save_20.pth', map_location='cpu')

# Strip only a leading 'module.' and keep every other key unchanged.
# Previously, any key merely *containing* 'module' was sliced, and keys
# without the prefix were dropped whenever at least one key had it.
dp_prefix = 'module.'
n_state = collections.OrderedDict(
    (k[len(dp_prefix):] if k.startswith(dp_prefix) else k, v)
    for k, v in state['state_dict'].items())
module.load_state_dict(n_state)
예제 #7
0
def generate_model(opt):
    assert opt.model in [
        'resnet',
        'preresnet',
        'wideresnet',
        'resnext',
        'densenet',
        'c3d',
        'c2d',
        'c2d_exp',
        'c2d_coord',
        'c3d_color',
        'c2d_pt',
        'c2d_pt2',
        'c2d_pt5',
        'c2d_pt7',
        'c2d_pt_exp',
        'c2d_pt2_exp',
        'c2d_pt5_exp',
        'c2d_pt_exp_avg',
        'c2d_pt_exp_sep',
        'c3d_pt_exp',
        'c2d_pt_exp_init',
        'c2d_pt_expc',
        'resnet18_exp',
        'resnet34_exp',
        'resnet50_exp',
        'resnet101_exp',
        'resnet152_exp',
        'resnext50_32x4d_exp',
        'resnext101_32x8d_exp',
        'wide_resnet50_2_exp',
        'wide_resnet101_2_exp',
        'resnet18_pt_exp',
        'resnet34_pt_exp',
        'resnet50_pt_exp',
        'resnet101_pt_exp',
        'resnet152_pt_exp',
        'resnext50_32x4d_pt_exp',
        'resnext101_32x8d_pt_exp',
        'wide_resnet50_2_pt_exp',
        'wide_resnet101_2_pt_exp',
        # decoder
        'stsrresnetexp',
        'spc',
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnet.resnet10(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 18:
            model = resnet.resnet18(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 34:
            model = resnet.resnet34(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 50:
            model = resnet.resnet50(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = resnet.resnet101(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = resnet.resnet152(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
        elif opt.model_depth == 200:
            model = resnet.resnet200(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        if opt.model_depth == 50:
            model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                         shortcut_type=opt.resnet_shortcut,
                                         k=opt.wide_resnet_k,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        if opt.model_depth == 50:
            model = resnext.resnet50(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     cardinality=opt.resnext_cardinality,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = resnext.resnet101(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = resnext.resnet152(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        if opt.model_depth == 18:
            model = pre_act_resnet.resnet18(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 34:
            model = pre_act_resnet.resnet34(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 50:
            model = pre_act_resnet.resnet50(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = pre_act_resnet.resnet101(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = pre_act_resnet.resnet152(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 200:
            model = pre_act_resnet.resnet200(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        if opt.model_depth == 121:
            model = densenet.densenet121(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
        elif opt.model_depth == 169:
            model = densenet.densenet169(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
        elif opt.model_depth == 201:
            model = densenet.densenet201(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
        elif opt.model_depth == 264:
            model = densenet.densenet264(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)

    elif opt.model == 'c3d':
        model = c3d.C3D(num_classes=opt.n_classes,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'c3d_color':
        model = c3d_color.C3D(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'spc':
        model = spc.SPC(num_classes=opt.n_classes,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'c2d':
        model = c2d.C2D(num_classes=opt.n_classes,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt':
        model = c2d_pt.C2DPt(num_classes=opt.n_classes,
                             sample_size=opt.sample_size,
                             sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt2':
        model = c2d_pt2.C2DPt(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt5':
        model = c2d_pt5.C2DPt(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt7':
        model = c2d_pt7.C2DPt(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_exp':
        model = c2d_exp.C2DExp(num_classes=opt.n_classes,
                               sample_size=opt.sample_size,
                               sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt_exp':
        model = c2d_pt_exp.C2DPtExp(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt_expc':
        model = c2d_pt_expc.C2DPtExpC(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt_exp_init':
        model = c2d_pt_exp_init.C2DPtExp(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
    elif opt.model == 'c3d_pt_exp':
        model = c3d_pt_exp.C3DPtExp(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt_exp_avg':
        model = c2d_pt_exp_avg.C2DPtExpAvg(num_classes=opt.n_classes,
                                           sample_size=opt.sample_size,
                                           sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt_exp_sep':
        model = c2d_pt_exp_sep.C2DPtExpSep(num_classes=opt.n_classes,
                                           sample_size=opt.sample_size,
                                           sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt5_exp':
        model = c2d_pt5_exp.C2DPtExp(num_classes=opt.n_classes,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt2_exp':
        model = c2d_pt2_exp.C2DPtExp(num_classes=opt.n_classes,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_coord':
        model = c2d_coord.C2DCoord(num_classes=opt.n_classes,
                                   sample_size=opt.sample_size,
                                   sample_duration=opt.sample_duration)

    elif opt.model == 'resnet18_exp':
        model = resnet_exp.resnet18(pretrained=False,
                                    progress=True,
                                    num_classes=opt.n_classes,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'resnet34_exp':
        model = resnet_exp.resnet34(pretrained=False,
                                    progress=True,
                                    num_classes=opt.n_classes,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'resnet50_exp':
        model = resnet_exp.resnet50(pretrained=False,
                                    progress=True,
                                    num_classes=opt.n_classes,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'resnet101_exp':
        model = resnet_exp.resnet101(pretrained=False,
                                     progress=True,
                                     num_classes=opt.n_classes,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'resnet152_exp':
        model = resnet_exp.resnet152(pretrained=False,
                                     progress=True,
                                     num_classes=opt.n_classes,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'resnext50_32x4d_exp':
        model = resnet_exp.resnext50_32x4d(pretrained=False,
                                           progress=True,
                                           num_classes=opt.n_classes,
                                           sample_duration=opt.sample_duration)
    elif opt.model == 'resnext101_32x8d_exp':
        model = resnet_exp.resnext101_32x8d(
            pretrained=False,
            progress=True,
            num_classes=opt.n_classes,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wide_resnet50_2_exp':
        model = resnet_exp.wide_resnet50_2(pretrained=False,
                                           progress=True,
                                           num_classes=opt.n_classes,
                                           sample_duration=opt.sample_duration)
    elif opt.model == 'wide_resnet101_2_exp':
        model = resnet_exp.wide_resnet101_2(
            pretrained=False,
            progress=True,
            num_classes=opt.n_classes,
            sample_duration=opt.sample_duration)

    elif opt.model == 'resnet18_pt_exp':
        model = resnet_pt_exp.resnet18(pretrained=False,
                                       progress=True,
                                       num_classes=opt.n_classes,
                                       sample_duration=opt.sample_duration)
    elif opt.model == 'resnet34_pt_exp':
        model = resnet_pt_exp.resnet34(pretrained=False,
                                       progress=True,
                                       num_classes=opt.n_classes,
                                       sample_duration=opt.sample_duration)
    elif opt.model == 'resnet50_pt_exp':
        model = resnet_pt_exp.resnet50(pretrained=False,
                                       progress=True,
                                       num_classes=opt.n_classes,
                                       sample_duration=opt.sample_duration)
    elif opt.model == 'resnet101_pt_exp':
        model = resnet_pt_exp.resnet101(pretrained=False,
                                        progress=True,
                                        num_classes=opt.n_classes,
                                        sample_duration=opt.sample_duration)
    elif opt.model == 'resnet152_pt_exp':
        model = resnet_pt_exp.resnet152(pretrained=False,
                                        progress=True,
                                        num_classes=opt.n_classes,
                                        sample_duration=opt.sample_duration)
    elif opt.model == 'resnext50_32x4d_pt_exp':
        model = resnet_pt_exp.resnext50_32x4d(
            pretrained=False,
            progress=True,
            num_classes=opt.n_classes,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext101_32x8d_pt_exp':
        model = resnet_pt_exp.resnext101_32x8d(
            pretrained=False,
            progress=True,
            num_classes=opt.n_classes,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wide_resnet50_2_pt_exp':
        model = resnet_pt_exp.wide_resnet50_2(
            pretrained=False,
            progress=True,
            num_classes=opt.n_classes,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wide_resnet101_2_pt_exp':
        model = resnet_pt_exp.wide_resnet101_2(
            pretrained=False,
            progress=True,
            num_classes=opt.n_classes,
            sample_duration=opt.sample_duration)

    elif opt.model == 'stsrresnetexp':
        model = decoder.STSRResNetExp(sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
# Example #8
# 0
def main():
    """Fine-tune a pretrained backbone for action classification.

    Configuration comes from ``parse_args()`` (``gpu``, ``pre_path``,
    ``exp_name``, ``split``, ``model_name``) and the module-level ``params``
    dict. The weights at ``pretrain_path_list[args.pre_path]`` are loaded
    non-strictly into a freshly built classifier, a random validation subset
    is held out from the training split, and the model is trained with SGD
    and a StepLR schedule (lr x0.1 every 100 epochs).

    Outputs, all under a timestamped run directory:
      * ``log.txt`` -- target of the ``Logger`` that replaces ``sys.stdout``;
      * the ``SummaryWriter`` event files (``writer`` is passed to ``train``);
      * ``best_acc_model_<epoch>.pth.tar`` whenever validation top-1 matches
        or beats the best so far;
      * ``best_loss_model_<epoch>.pth.tar`` whenever validation loss improves;
      * ``<epoch>.pth.tar`` every 20 epochs.

    Raises:
        ValueError: if ``params['data']`` or ``args.model_name`` is not
            supported. Both are checked before any directory is created.
    """
    args = parse_args()
    # torch.backends.cudnn.benchmark = True
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # (number of classes, size of the held-out validation subset) per dataset.
    dataset_specs = {'UCF-101': (101, 800), 'HMDB-51': (51, 400)}
    if params['data'] not in dataset_specs:
        raise ValueError('Unsupported dataset: {}'.format(params['data']))
    class_num, val_size = dataset_specs[params['data']]

    # Backbone constructors, validated up front so a typo fails fast.
    backbones = {
        'c3d': lambda: c3d.C3D(with_classifier=True, num_classes=class_num),
        'r3d': lambda: r3d.R3DNet((1, 1, 1, 1),
                                  with_classifier=True,
                                  num_classes=class_num),
        'r21d': lambda: r21d.R2Plus1DNet((1, 1, 1, 1),
                                         with_classifier=True,
                                         num_classes=class_num),
    }
    if args.model_name not in backbones:
        raise ValueError('Unsupported model: {}'.format(args.model_name))

    # The run directory name embeds components of the pretrained checkpoint's
    # path plus a timestamp.
    pretrain_path = pretrain_path_list[args.pre_path]
    path_parts = pretrain_path.split('/')
    save_path = params['save_path_base'] + "ft3_classify_{}_{}_".format(
        path_parts[-3][14:],
        args.exp_name) + params['data'] + '_split{}'.format(args.split)
    sub_dir = 'pt-{}-e{}-ft-{}'.format(
        path_parts[-2],
        path_parts[-1].split('.')[0].split('_')[-1],
        time.strftime('%m-%d-%H-%M'))
    model_save_dir = os.path.join(save_path, sub_dir)
    # Create the directory before the writer / logger put files in it.
    os.makedirs(model_save_dir, exist_ok=True)
    writer = SummaryWriter(model_save_dir)
    log_file = os.path.join(model_save_dir, 'log.txt')
    sys.stdout = Logger(log_file)
    print(vars(args))
    print('{}: {}'.format(params['data'], class_num))

    model = backbones[args.model_name]()
    print('Backbone:{}'.format(args.model_name))

    start_epoch = 1
    print('Load model:' + pretrain_path)
    pretrain_weight = load_pretrained_weights(pretrain_path)
    print(pretrain_weight.keys())
    # strict=False: keys missing from either side are ignored rather than
    # raising, so the new classifier head can stay randomly initialised.
    model.load_state_dict(pretrain_weight, strict=False)

    # train
    image_augmentation = None
    video_augmentation = transforms.Compose([
        video_transforms.ToPILImage(),
        video_transforms.Resize((128, 171)),
        video_transforms.RandomCrop(112),
        video_transforms.ToTensor()
    ])

    train_dataset = ClassifyDataSet(params['dataset'],
                                    mode="train",
                                    split=args.split,
                                    dataset=params['data'],
                                    video_transforms=video_augmentation,
                                    image_transforms=image_augmentation)
    # Hold out a random subset of the training split for validation.
    # NOTE(review): both subsets share the same underlying dataset, so the
    # validation clips also go through the RandomCrop augmentation above.
    train_dataset, val_dataset = random_split(
        train_dataset, (len(train_dataset) - val_size, val_size))

    print("num_works:{:d}".format(params['num_workers']))
    print("batch_size:{:d}".format(params['batch_size']))
    train_loader = DataLoader(train_dataset,
                              batch_size=params['batch_size'],
                              shuffle=True,
                              num_workers=params['num_workers'])
    val_loader = DataLoader(val_dataset,
                            batch_size=params['batch_size'],
                            shuffle=True,
                            num_workers=params['num_workers'])
    if multi_gpu == 1:
        model = nn.DataParallel(model)
    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=params['learning_rate'],
                          momentum=params['momentum'],
                          weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    best_val_loss = float('inf')
    best_acc = 0
    best_epoch = 0
    for epoch in tqdm(range(start_epoch, start_epoch + params['epoch_num'])):
        train(train_loader, model, criterion, optimizer, epoch, writer)
        # Since PyTorch 1.1 the scheduler must step *after* the optimizer
        # updates of the epoch; stepping first skips the initial LR value.
        scheduler.step()
        val_loss, top1_avg = validation(val_loader, model, criterion,
                                        optimizer, epoch)
        if top1_avg >= best_acc:
            best_acc = top1_avg
            print("i am best :", best_acc)
            best_epoch = epoch
            model_path = os.path.join(
                model_save_dir, 'best_acc_model_{}.pth.tar'.format(epoch))
            torch.save(model.state_dict(), model_path)
        if val_loss < best_val_loss:
            model_path = os.path.join(
                model_save_dir, 'best_loss_model_{}.pth.tar'.format(epoch))
            torch.save(model.state_dict(), model_path)
            best_val_loss = val_loss
        # Periodic snapshot independent of validation metrics.
        if epoch % 20 == 0:
            checkpoints = os.path.join(model_save_dir, str(epoch) + ".pth.tar")
            torch.save(model.state_dict(), checkpoints)
            print("save_to:", checkpoints)
    print("best is :", best_acc, best_epoch)
    writer.close()
# Example #9
# 0
def main():
    """Train a C3D model on the clips listed by ``read_file()``.

    Uses the module-level ``args`` namespace (``num_frames``, ``model``,
    ``lr``, ``batch_size``, ``epochs``). Each epoch makes one shuffled pass
    over the dataset in mini-batches of ``args.batch_size`` clips (the last
    batch may be smaller). After each epoch it prints the mean per-batch
    loss, followed by the training accuracy in percent.

    Raises:
        ValueError: if ``args.model`` is not a supported architecture.
    """
    # get filenames and labels from text file
    filenames, clip_labels = read_file()

    # load data
    transform = transforms.Compose([transforms.ToTensor()])
    ucf101 = UCFDataset(filenames, clip_labels, args.num_frames, transform)

    # Let the DataLoader assemble the mini-batches; with the default
    # drop_last=False the final partial batch is kept, so every clip is
    # used exactly once per epoch.
    train_loader = torch.utils.data.DataLoader(ucf101,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)

    # get model (resnet_3d / densenet_3d / inception_3d / dense_resnet_3d
    # variants were sketched here but are not wired up)
    if args.model == 'c3d':
        model = c3d.C3D()
    else:
        raise ValueError('Unsupported model: {}'.format(args.model))
    # Move to the GPU once, before the optimizer is built over its parameters.
    model = model.cuda()

    # optimizer
    sgd = torch.optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          nesterov=True)

    # loss function
    cross_entropy = torch.nn.CrossEntropyLoss()

    num_samples = len(train_loader.dataset)

    for epoch in range(1, args.epochs + 1):
        model.train()
        losses = []
        correct = 0

        for data_point in train_loader:
            batch = data_point['frames'].cuda()
            labels = data_point['label'].cuda()

            # reset optimizer
            sgd.zero_grad()

            # run through model
            output = model(batch)

            # loss
            loss = cross_entropy(output, labels)
            loss.backward()
            losses.append(loss.item())

            # update weights
            sgd.step()

            # get accuracy
            prediction = output.argmax(dim=1, keepdim=True)
            correct += prediction.eq(labels.view_as(prediction)).sum().item()

        # calculate loss / accuracy (guard against an empty dataset)
        train_loss = float(np.mean(losses)) if losses else float('nan')
        train_acc = (float(100 * correct) / num_samples) if num_samples else 0.0

        print(train_loss, train_acc)