Example #1
0
def main():
    """Parse CLI args, build the Keras MFF/TSN model, and launch training.

    Relies on module-level names defined elsewhere in this file:
    `parser`, `check_rootfolders`, `datasets_video`,
    `architecture_name_parser`, `TSN`, `modules`, `optimizers`,
    `TensorBoard`, `ModelCheckpoint`, and `train`.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    categories, train_list, val_list, data_root_path, prefix = datasets_video.return_dataset(
        args.root_path, args.dataset, args.modality)

    num_class = len(categories)

    architecture = architecture_name_parser(args.architecture)

    # Run identifier used for the checkpoint filename below.
    store_name = '_'.join([
        'MFF', args.dataset, args.modality, architecture,
        'segment%d' % args.num_segments,
        '%df1c' % args.num_motion
    ])
    print('storing name: ' + store_name)

    print("Using " + architecture + " architecture")

    tsn = TSN(num_class,
              args.num_segments,
              args.modality,
              architecture=architecture,
              consensus_type=args.consensus_type,
              dropout=args.dropout,
              num_motion=args.num_motion,
              img_feature_dim=args.img_feature_dim,
              partial_bn=args.partialbn,
              dataset=args.dataset,
              group_norm=args.group_norm)

    model = tsn.total_model
    model.summary()

    # TODO: group normalize for non RGBDiff or RGBFlow

    # Architecture-specific input preprocessing (e.g. mean subtraction).
    transform_fn = modules[architecture].preprocess_input

    # Define the loss function (criterion).
    if args.loss_type == 'cce':
        criterion = 'categorical_crossentropy'
    else:
        raise ValueError("Unknown loss type")

    # Select exactly one optimizer instead of building SGD and then
    # discarding it. 'rms' is checked before 'adam' so that a name
    # containing both substrings resolves the same way as the original
    # sequential overrides did (SGD -> Adam -> RMSprop, last match wins).
    opt_name = args.optimizer.lower()
    if 'rms' in opt_name:
        optimizer = optimizers.RMSprop(lr=args.lr, decay=args.weight_decay)
    elif 'adam' in opt_name:
        optimizer = optimizers.Adam(lr=args.lr,
                                    decay=args.weight_decay,
                                    clipnorm=args.clip_gradient)
    else:
        # Default optimizer: SGD with momentum.
        optimizer = optimizers.SGD(lr=args.lr,
                                   momentum=args.momentum,
                                   decay=args.weight_decay)

    model.compile(optimizer, loss=criterion, metrics=['accuracy'])

    if args.experiment_name:
        log_dir = './logs/' + args.experiment_name
    else:
        log_dir = './logs'

    tb = TensorBoard(log_dir=log_dir, batch_size=args.batch_size)
    # Keeps only the best model (by validation loss) under model/<store_name>.
    checkpoint = ModelCheckpoint("model/" + store_name, save_best_only=True)

    train(model, args, optimizer, num_class, data_root_path, tsn.image_dim,
          train_list, val_list, [tb, checkpoint], transform_fn)
Example #2
0
def main():
    """Parse CLI args, build the PyTorch TRN model, and run train/validate.

    This variant predicts a single continuous value (num_class == 1, with
    mse/mae losses supported), tracking the best validation loss.

    Relies on module-level names defined elsewhere in this file: `parser`,
    `check_rootfolders`, `datasets_video`, `TSN`, `TSNDataSet`, the group
    transforms, `train`, `validate`, `adjust_learning_rate`,
    `save_checkpoint`, and the global `best_loss` (assumed initialised at
    module scope before the first comparison below — TODO confirm).
    """
    global args, best_loss
    args = parser.parse_args()
    check_rootfolders()

    # NOTE(review): train_dict/val_dict returned here are never used; the
    # data loaders below read args.train_dict / args.val_dict instead.
    # Confirm with the argument parser that this is intended.
    root_data, train_dict, val_dict = datasets_video.return_dataset(
        args.dataset, args.modality, args.view)
    # Single regression output.
    num_class = 1

    args.store_name = '_'.join([
        'TRN', args.label_name, args.dataset, args.modality, args.view,
        args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)
    # Frame filename template differs between body view and face-detection crops.
    img_tmpl = '{:06d}.jpg' if args.view == 'body' else 'frame_det_00_{:06d}.bmp'
    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                fix_all_weights=args.fix_all_weights)

    # Capture model-specific preprocessing parameters before DataParallel
    # wrapping hides them behind `.module`.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()  # augmentation increase

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            # Fixed: previously printed args.evaluate instead of the
            # checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code. RGBDiff inputs are already mean-centred by
    # differencing, so they skip normalisation.
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Number of consecutive frames stacked per segment sample.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(TSNDataSet(
        root_data,
        args.train_dict,
        args.label_name,
        num_segments=args.num_segments,
        phase='Train',
        new_length=data_length,
        modality=args.modality,
        image_tmpl=img_tmpl,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            # BNInception/InceptionV3 expect BGR channel order (roll) and
            # un-divided 0-255 pixel values.
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        root_data,
        args.val_dict,
        args.label_name,
        num_segments=args.num_segments,
        phase='Validation',
        new_length=data_length,
        modality=args.modality,
        image_tmpl=img_tmpl,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Define loss function (criterion) and optimizer.
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    elif args.loss_type == 'mse':
        criterion = torch.nn.MSELoss().cuda()
    elif args.loss_type == 'mae':
        # NOTE(review): 'mae' maps to SmoothL1 (Huber), not true L1 — confirm
        # this naming is intentional.
        criterion = torch.nn.SmoothL1Loss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Evaluation-only mode: run a single validation pass and exit.
    if args.evaluate:
        log_val = open(
            os.path.join(args.root_log, '%s_val.csv' % args.store_name), 'w')
        validate(val_loader, model, criterion, 0, log_val)
        return

    log_training = open(
        os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # Train for one epoch.
        train(train_loader, model, criterion, optimizer, epoch, log_training)

        # Evaluate on validation set at the configured frequency and
        # always on the final epoch.
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            loss = validate(val_loader, model, criterion,
                            (epoch + 1) * len(train_loader), log_training)

            # Remember the best (lowest) validation loss and save checkpoint.
            is_best = loss < best_loss
            best_loss = min(loss, best_loss)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                }, is_best)
Example #3
0
def main():
    """Parse CLI args, transfer-learn a pretrained TRN, and run train/validate.

    Loads a fixed pretrained checkpoint (339-way head), freezes the CNN
    backbone, replaces the final classifier layer with one sized for the
    target dataset, then fine-tunes with Adam + StepLR.

    Relies on module-level names defined elsewhere in this file: `parser`,
    `check_rootfolders`, `datasets_video`, `TSN`, `TSNDataSet`, the group
    transforms, `train`, `validate`, `save_checkpoint`, `summary_writer`,
    and the global `best_prec1`.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        args.dataset, args.modality)
    num_class = len(categories)

    args.store_name = '_'.join([
        'TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    # 339 outputs match the pretrained checkpoint loaded below (presumably
    # the MIT Moments label set — TODO confirm); the head is swapped for
    # num_class outputs after loading.
    model = TSN(339,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)
    # Freeze the CNN backbone (first child module) for transfer learning.
    _, cnn = list(model.named_children())[0]
    for p in cnn.parameters():
        p.requires_grad = False

    # Capture model-specific preprocessing parameters before DataParallel
    # wrapping hides them behind `.module`.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # Remove if not transfer learning.
    # NOTE(review): hard-coded absolute path — consider making this a CLI arg.
    checkpoint = torch.load('/home/ec2-user/mit_weights.pth.tar')
    model.load_state_dict(checkpoint['state_dict'])

    # Replace the last linear layer of each consensus sub-module with a
    # freshly initialised head sized for the target dataset.
    for module in list(
            list(model._modules['module'].children())
        [-1].children())[-1].children():
        module[-1] = nn.Linear(256, num_class)

    # Optionally resume from a fine-tuning checkpoint (overrides the
    # pretrained weights loaded above).
    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Fixed: previously printed args.evaluate instead of the
            # checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True
    model.cuda()

    # Data loading code. RGBDiff inputs are already mean-centred by
    # differencing, so they skip normalisation.
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Number of consecutive frames stacked per segment sample.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            # BNInception/InceptionV3 expect BGR channel order (roll) and
            # un-divided 0-255 pixel values.
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Define loss function (criterion) and optimizer.
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    # NOTE(review): `policies` (per-group lr/decay multipliers) is printed
    # above but not passed to the optimizer — confirm that is intended.
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    # Decay the LR by 10x every 8 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                8,
                                                gamma=0.1,
                                                last_epoch=-1)

    # Evaluation-only mode: run a single validation pass and exit.
    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
    for epoch in range(args.start_epoch, args.epochs):
        # NOTE(review): modern PyTorch expects scheduler.step() AFTER the
        # optimizer steps of the epoch; kept as-is to preserve the original
        # LR schedule.
        scheduler.step()
        # Train for one epoch.
        train(train_loader, model, criterion, optimizer, epoch, log_training)

        # Evaluate on validation set at the configured frequency and
        # always on the final epoch.
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader), log_training,
                             epoch)

            # Remember best prec@1 and save checkpoint.
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
    # NOTE(review): summary_writer is not defined in this function —
    # presumably created at module scope; verify before relying on it.
    summary_writer.close()
Example #4
0
    for folder in folders_util:
        if not os.path.exists(folder):
            print('creating folder ' + folder)
            os.mkdir(folder)


if __name__ == '__main__':

    best_prec1 = 0

    global args
    args = parser.parse_args()
    print(args)
    check_rootfolders()

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        args.dataset, args.modality)
    num_class = len(categories)

    args.store_name = '_'.join([
        'MFF', args.dataset, args.modality, args.arch,
        'segment%d' % args.num_segments,
        '%df1c' % args.num_motion
    ])

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                num_motion=args.num_motion,
Example #5
0
def main():
    """Parse CLI args, build the memory-network TSN model, and run train/validate.

    Writes the resolved options to <result_path>/opts.json, builds paired
    train/val TSNDataSet loaders, and logs scalars via a TensorBoard
    SummaryWriter.

    Relies on module-level names defined elsewhere in this file: `parser`,
    `_fill_in_None_args`, `_join_result_path`, `check_rootfolders`,
    `datasets_video`, `TSN`, `TSNDataSet`, the group transforms, `train`,
    `validate`, `adjust_learning_rate`, `save_checkpoint`, `SummaryWriter`,
    and the global `best_prec1`.
    """
    global args, best_prec1, num_train_dataset, num_val_dataset, writer
    args = parser.parse_args()
    _fill_in_None_args()
    _join_result_path()
    check_rootfolders()
    # Persist the resolved options alongside the results for reproducibility.
    with open(os.path.join(args.result_path, 'opts.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file)

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        args.dataset, args.modality, args.root_path, args.file_type)
    num_class = len(categories)


    args.store_name = '_'.join([args.consensus_type, args.dataset, args.modality, args.arch, args.consensus_type, 'segment%d'% args.num_segments, \
        'key%d'%args.key_dim, 'value%d'%args.value_dim, 'query%d'%args.query_dim, 'queryUpdateby%s'%args.query_update_method,\
        'NoSoftmax%s'%args.no_softmax_on_p, 'hopMethod%s'%args.hop_method])
    print('storing name: ' + args.store_name)

    model = TSN(
        num_class,
        args.num_segments,
        args.modality,
        base_model=args.arch,
        consensus_type=args.consensus_type,
        dropout=args.dropout,
        key_dim=args.key_dim,
        value_dim=args.value_dim,
        query_dim=args.query_dim,
        query_update_method=args.query_update_method,
        partial_bn=not args.no_partialbn,
        freezeBN_Eval=args.freezeBN_Eval,
        freezeBN_Require_Grad_True=args.freezeBN_Require_Grad_True,
        num_hop=args.hop,
        hop_method=args.hop_method,
        num_CNNs=args.num_CNNs,
        no_softmax_on_p=args.no_softmax_on_p,
        freezeBackbone=args.freezeBackbone,
        CustomPolicy=args.CustomPolicy,
        sorting=args.sorting,
        MultiStageLoss=args.MultiStageLoss,
        MultiStageLoss_MLP=args.MultiStageLoss_MLP,
        how_to_get_query=args.how_to_get_query,
        only_query=args.only_query,
        CC=args.CC,
        channel=args.channel,
        memory_dim=args.memory_dim,
        image_resolution=args.image_resolution,
        how_many_objects=args.how_many_objects,
        Each_Embedding=args.Each_Embedding,
        Curriculum=args.Curriculum,
        Curriculum_dim=args.Curriculum_dim,
        lr_steps=args.lr_steps,
    )

    # Capture model-specific preprocessing parameters before DataParallel
    # wrapping hides them behind `.module`.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Fixed: previously printed args.evaluate instead of the
            # checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code. RGBDiff inputs are already mean-centred by
    # differencing, so they skip normalisation.
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Number of consecutive frames stacked per segment sample.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_data = TSNDataSet(
        args.root_path,
        args.train_list,
        args.file_type,
        num_segments=args.num_segments,
        MoreAug_Rotation=args.MoreAug_Rotation,
        MoreAug_ColorJitter=args.MoreAug_ColorJitter,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        phase='train',
        transform1=torchvision.transforms.Compose([
            train_augmentation,  # GroupMultiScaleCrop[1, .875, .75, .66] AND GroupRandomHorizontalFlip
        ]),
        transform2=torchvision.transforms.Compose([
            # BNInception/InceptionV3 expect BGR channel order (roll) and
            # un-divided 0-255 pixel values.
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,  # GroupNormalize
        ]),
        image_resolution=args.image_resolution)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=False,
                                               drop_last=True)

    val_data = TSNDataSet(
        args.root_path,
        args.val_list,
        args.file_type,
        num_segments=args.num_segments,
        MoreAug_Rotation=args.MoreAug_Rotation,
        MoreAug_ColorJitter=args.MoreAug_ColorJitter,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        phase='test',
        transform1=torchvision.transforms.Compose(
            [GroupScale(int(scale_size)),
             GroupCenterCrop(crop_size)]),
        transform2=torchvision.transforms.Compose([
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        image_resolution=args.image_resolution)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=False,
                                             drop_last=True)
    num_train_dataset = len(train_data)
    num_val_dataset = len(val_data)

    # Define loss function (criterion) and optimizer.
    if args.loss_type == 'nll':
        # NOTE(review): `reduce=False` is the deprecated spelling of
        # reduction='none' (per-sample losses); kept for compatibility with
        # the torch version this project pins.
        criterion = torch.nn.CrossEntropyLoss(reduce=False).cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(policies,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(policies,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    else:
        # Previously an unknown optimizer left `optimizer` undefined and
        # surfaced later as a confusing NameError; fail fast instead.
        raise ValueError("Unknown optimizer: %s" % args.optimizer)

    # Evaluation-only mode: run a single validation pass, dumping
    # per-sample results to JSON, then exit.
    if args.evaluate:
        json_file_path = os.path.join(
            args.result_path, 'results_epoch%d.json' % args.evaluation_epoch)
        validate(val_loader,
                 model,
                 criterion,
                 0,
                 json_file=json_file_path,
                 idx2class=categories,
                 epoch=args.evaluation_epoch)
        return

    writer = SummaryWriter(args.result_path)
    # Append ('a') so resumed runs extend the existing training log.
    log_training = open(
        os.path.join(args.root_log, '%s.csv' % args.store_name), 'a')

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # Train for one epoch.
        train(train_loader, model, criterion, optimizer, epoch, log_training)

        # Evaluate on validation set at the configured frequency and
        # always on the final epoch.
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            json_file_path = os.path.join(args.result_path,
                                          'results_epoch%d.json' % (epoch + 1))
            prec1 = validate(val_loader,
                             model,
                             criterion, (epoch + 1) * num_train_dataset,
                             log=log_training,
                             json_file=json_file_path,
                             idx2class=categories,
                             epoch=epoch)

            # Remember best prec@1 and save checkpoint.
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
    log_training.close()
    writer.close()
Example #6
0
def main():
    global args, best_prec1
    args = parser.parse_args()

    if args.dataset == 'something-v1':
        num_class = 174
    elif args.dataset == 'diving48':
        num_class = 48
    elif args.dataset == 'ucf101':
        num_class = 101
    elif args.dataset == 'skating2':
        num_class = 63
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    model_dir = os.path.join('experiments', args.dataset, args.arch,
                             args.consensus_type + '-' + args.modality,
                             str(args.run_iter))
    args.train_list, args.val_list, args.root_path, args.rgb_prefix = datasets_video.return_dataset(
        args.dataset)
    if 'something' in args.dataset:
        # label transformation for left/right categories
        target_transforms = {
            86: 87,
            87: 86,
            93: 94,
            94: 93,
            166: 167,
            167: 166
        }
        print('Target transformation is enabled....')
    else:
        target_transforms = None

    if not args.resume_rgb:
        if os.path.exists(model_dir):
            print('Dir {} exists!!!  it will be removed'.format(model_dir))
            shutil.rmtree(model_dir)
        os.makedirs(model_dir)
        os.makedirs(os.path.join(model_dir, args.root_log))

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['flow', 'RGBDiff']:
        data_length = 5
        # data_length = 1

    if args.resume_rgb:
        if args.modality == 'RGB':
            if 'gst' in args.arch:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='GST',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)
            elif 'stm' in args.arch:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='STM',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)
            elif 'tmp' in args.arch:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='TMP',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)
            elif 'tsm' in args.arch:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='TSM',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)
            elif 'ori' in args.arch:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='ORI',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)
            elif 'I3D' in args.arch:
                print("!!!!!!!!!!!!!!!!!!!!!!!\n\n")
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='I3D',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)

            else:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='ORI',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)
            if os.path.isfile(args.resume_rgb):
                print(("=> loading checkpoint '{}'".format(args.resume_rgb)))
                checkpoint = torch.load(args.resume_rgb)
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                original_checkpoint = checkpoint['state_dict']
                print(("(epoch {} ) best_prec1 : {} ".format(
                    checkpoint['epoch'], best_prec1)))
                # k[7:] strips a 7-character prefix from every state-dict key
                # -- presumably the 'module.' prefix that nn.DataParallel adds
                # when saving; TODO(review): confirm how the checkpoint was saved.
                original_checkpoint = {
                    k[7:]: v
                    for k, v in original_checkpoint.items()
                }
                #model_dict =  i3d_model.state_dict()
                #model_dict.update(pretrained_dict)
                model.load_state_dict(original_checkpoint)
                print(
                    ("=> loaded checkpoint '{}' (epoch {} ) best_prec1 : {} ".
                     format(args.resume_rgb, checkpoint['epoch'], best_prec1)))
            else:
                raise ValueError("=> no checkpoint found at '{}'".format(
                    args.resume_rgb))
    else:
        # Fresh run (no checkpoint to resume): choose the model variant by
        # substring-matching on args.arch.
        if args.modality == 'flow':
            if 'I3D' in args.arch:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='I3D',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi,
                                      modality='flow',
                                      new_length=data_length)
        elif args.modality == 'RGB':
            if 'gst' in args.arch:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='GST',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)
            elif 'stm' in args.arch:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='STM',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)
            elif 'tmp' in args.arch:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='TMP',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)
            elif 'tsm' in args.arch:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='TSM',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)
            elif 'ori' in args.arch:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='ORI',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)
            elif 'I3D' in args.arch:
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='I3D',
                                      backbone=args.arch,
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)
            else:
                # Fallback: treat any unrecognised arch name as the ORI
                # variant with an '_ori' suffix appended to the backbone name.
                model = TemporalModel(num_class,
                                      args.num_segments,
                                      model='ORI',
                                      backbone=args.arch + '_ori',
                                      alpha=args.alpha,
                                      beta=args.beta,
                                      dropout=args.dropout,
                                      target_transforms=target_transforms,
                                      resi=args.resi)

    # Let cuDNN benchmark conv algorithms (helps when input sizes are fixed).
    cudnn.benchmark = True
    writer = SummaryWriter(model_dir)
    # Data loading code
    args.store_name = '_'.join([
        args.dataset, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = get_optim_policies(model)
    train_augmentation = get_augmentation(mode='train')
    val_trans = get_augmentation(mode='val')
    normalize = GroupNormalize(input_mean, input_std)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()
    # NOTE(review): diving48 paths are derived by string surgery on
    # args.root_path ('/train' appended here, stripped via [:-6] below);
    # this silently breaks if root_path already ends with a separator.
    if args.dataset == 'diving48':
        args.root_path = args.root_path + '/train'

    train_loader = torch.utils.data.DataLoader(VideoDataset(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=args.rgb_prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dataset=args.dataset),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    print("trainloader.type = {}".format(type(train_loader)))
    if args.dataset == 'diving48':
        args.root_path = args.root_path[:-6] + '/test'
    # Validation uses deterministic sampling (random_shift=False) and
    # batch_size=1.
    val_loader = torch.utils.data.DataLoader(VideoDataset(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=args.rgb_prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ]),
        dataset=args.dataset),
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        log_test = open('test_not.csv', 'w')
        validate(val_loader, model, criterion, log_test)
        # NOTE(review): os.remove expects a path string, not a file object --
        # this line raises TypeError, and the file is never closed.
        os.remove(log_test)
        return

    if args.lr_scheduler == 'cos_warmup':
        lr_scheduler_clr = CosineAnnealingLR.WarmupCosineLR(
            optimizer=optimizer,
            milestones=[args.warmup, args.epochs],
            warmup_iters=args.warmup,
            min_ratio=1e-7)
    elif args.lr_scheduler == 'lr_step_warmup':
        lr_scheduler_clr = CosineAnnealingLR.WarmupStepLR(
            optimizer=optimizer,
            milestones=[args.warmup] +
            [args.epochs - 30, args.epochs - 10, args.epochs],
            warmup_iters=args.warmup)
    elif args.lr_scheduler == 'lr_step':
        lr_scheduler_clr = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.lr_steps, 0.1)
    # After resuming, fast-forward the scheduler to start_epoch so the
    # learning rate matches where the previous run left off.
    if args.resume_rgb:
        for epoch in range(0, args.start_epoch):
            optimizer.step()
            lr_scheduler_clr.step()

    log_training = open(
        os.path.join(model_dir, args.root_log, '%s.csv' % args.store_name),
        'a')
    for epoch in range(args.start_epoch, args.epochs):
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch + 1)
        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              log_training,
              writer=writer)
        lr_scheduler_clr.step()
        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader,
                             model,
                             criterion,
                             log_training,
                             writer=writer,
                             epoch=epoch)
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'lr': optimizer.param_groups[-1]['lr'],
                }, is_best, model_dir)
            print('best_prec1: {}'.format(best_prec1))
        else:
            # Non-eval epochs still checkpoint (never flagged as best).
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'lr': optimizer.param_groups[-1]['lr'],
                }, False, model_dir)
            # NOTE(review): stray line below references `self` inside a plain
            # function -- it looks like a fragment of AverageMeter.update()
            # pasted here by mistake; it will raise NameError if reached.
            self.avg = self.sum / self.count

    def accuracy(output, target, topk=(1, )):
        """Compute precision@k for each k in *topk*.

        Args:
            output: (batch, num_classes) score tensor.
            target: (batch,) ground-truth class indices.
            topk: tuple of k values to evaluate.

        Returns:
            List of 0-dim tensors, each the percentage (0-100) of samples
            whose true label appears in the top-k predictions.
        """
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # BUG FIX: a slice of the transposed `correct` tensor is
            # non-contiguous, so .view(-1) raises in modern PyTorch;
            # .reshape(-1) handles both cases.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    # Resolve dataset file lists/paths for the RGBFlow modality.
    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        "emmanuelle", "RGBFlow")
    num_class = len(categories)

    # NOTE(review): num_class is computed above but a hard-coded 27 is passed
    # to TSN -- confirm this mismatch is intentional (e.g. jester's 27 classes).
    originalNet = TSN(
        27,
        args.test_segments if args.consensus_type in ['MLP'] else 1,
        "RGBFlow",
        base_model=args.arch,
        consensus_type=args.consensus_type,
        img_feature_dim=args.img_feature_dim,
    )

    # Persist the whole module (not just the state_dict) and keep a handle to
    # the backbone CNN.
    torch.save(originalNet, "emmanuelle.pth")
    emmanuelleNet = originalNet.base_model

    # (truncated in this chunk -- statement continues past the visible source)
    print(
Beispiel #8
0
    def __init__(self, 
                num_segments,
                modality,

                lr =0.001, 
                loss_type = 'nll', # cross entropy
                weight_decay=5e-4, #weight_decay: L2 penalty #default
                lr_steps =[30, 60], # epochs to decay learning rate by 10 -- NOTE(review): mutable default argument
                momentum= 0.9, 
                gpus= None, 
                clip_gradient =20, 

                new_length=None,
                base_model="resnet50",
                dropout=0.7,
                img_feature_dim=256, #The dimensionality of the features used for relational reasoning. 
                partial_bn=True,
                consensus_type= 'TRN', # MTRN
                
                dataset = 'epic',
                batch_size= 1,
                workers= 2,
                
                resume = None,  # path to a pretrained model checkpoint
                epochs= None,
                start_epoch = None, #
                
                
                ifprintmodel= 0, # print the model structure
                print_freq =1,
                eval_freq =1,
                ):
        """Build a TRN trainer: TSN model, data config, loss and policies.

        Resolves dataset paths via datasets_video.return_dataset, constructs
        the TSN model, wraps it in DataParallel, optionally restores a
        checkpoint from *resume*, and prepares the normalization transform,
        loss criterion and per-group optimizer policies. Prints a summary
        banner and parameter counts as a side effect.
        """


        self.num_segments= num_segments
        self.modality= modality
        self.base_model= base_model
        self.new_length= new_length
        self.img_feature_dim= img_feature_dim
        self.consensus_type= consensus_type
        self.dataset= dataset

        self.resume = resume 
        self.epochs = epochs
        self.start_epoch= start_epoch

        self.lr= lr  
        self.loss_type= loss_type
        self.weight_decay = weight_decay
        self.lr_steps= lr_steps
        self.momentum= momentum
        self.partial_bn= partial_bn
        self.dropout= dropout

        self.batch_size = batch_size
        self.workers= workers
        self.gpus=  gpus
        self.eval_freq= eval_freq
        self.print_freq= print_freq

        
        # NOTE(review): 'K%d' % self.new_length raises TypeError when
        # new_length is left at its default None -- confirm callers always
        # pass an int.
        self.num_class, self.train_list, self.val_list, self.root_path, self.prefix = datasets_video.return_dataset(self.dataset, self.modality)
        self.store_name = '_'.join(['TRN', self.dataset, self.modality, self.base_model, self.consensus_type, 'segment%d'% self.num_segments, 'K%d'% self.new_length])      
        self.best_prec1= 0 
        self.clip_gradient= clip_gradient
        
        
        
        self.model = TSN(self.num_class, self.num_segments, self.modality,
                new_length= self.new_length,
                base_model= self.base_model,
                consensus_type= self.consensus_type,
                dropout=self.dropout,
                img_feature_dim= self.img_feature_dim,
                partial_bn= self.partial_bn)

        # Preprocessing metadata and optimizer policies exposed by the model.
        self.crop_size =  self.model.crop_size 
        self.scale_size = self.model.scale_size
        self.input_mean = self.model.input_mean
        self.input_std = self.model.input_std
        self.model_policies = self.model.get_optim_policies()
        self.augmentation= self.model.get_augmentation()
        
        print('we have {} GPUs found'.format(torch.cuda.device_count()))
        self.model = torch.nn.DataParallel(self.model #, device_ids=self.gpus
                                           ).cuda()

        print(f'''  
+-------------------------------------------------------+
               num_class : {self.num_class}
                modality : {self.modality}
              base_model : {self.base_model}
              new_length : {self.new_length}
          consensus_type : {self.consensus_type}
         img_feature_dim : {self.img_feature_dim}

                  resume : {self.resume}
                  epochs : {self.epochs }
             start_epoch : {self.start_epoch }
                      lr : {self.lr }
               loss_type : {self.loss_type }
            weight_decay : {self.weight_decay }
                lr_steps : {self.lr_steps }
                momentum : {self.momentum }
              partial_bn : {self.partial_bn}
           clip_gradient : {self.clip_gradient }
                 dropout : {self.dropout}

              batch_size : {self.batch_size}
                 workers : {self.workers}
                    gpus : {self.gpus } ( no use now)
               eval_freq : {self.eval_freq }
              print_freq : {self.print_freq }
              
               crop_size : {self.crop_size}
              scale_size : {self.scale_size}
+-------------------------------------------------------+
construct a network named : {self.store_name}''')
        
        #---- checkpoint------load model ---- 
        if self.resume:
            if os.path.isfile(self.resume):
                print(("=> loading checkpoint '{}'".format(self.resume)))
                checkpoint = torch.load(self.resume)
                self.start_epoch = checkpoint['epoch']
                self.best_prec1 = checkpoint['best_prec1']
                self.model.load_state_dict(checkpoint['state_dict'])
                print(("=> loaded checkpoint '{}' (epoch {}) (epochs={})"
                      .format(self.resume, checkpoint['epoch'], self.epochs)))
            else:
                print(("=> no checkpoint found at '{}'".format(self.resume)))

        # Let cuDNN benchmark conv algorithms for fixed input sizes.
        cudnn.benchmark = True

        # Data loading code
        if self.modality != 'RGBDiff':
            self.normalize = GroupNormalize(self.input_mean, self.input_std)
        else:
            self.normalize = IdentityTransform()
        

        #------- define loss function (criterion) and optimizer-------
        if self.loss_type == 'nll':
            self.criterion = torch.nn.CrossEntropyLoss().cuda()
        else:
            raise ValueError("Unknown loss type")

        
        #---------=========describe parameters:========= ----------------------
        
        print('*'*20,'TSN parameters:')
        Tools.parameter_desc(self.model, ifprint= ifprintmodel)
        #------parameter  way2-----
        print('-'*30)
        for group in self.model_policies:
            print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
                group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))
        
        # Count parameters from the optimizer policy groups as a cross-check.
        toal_params=0 
        for p in self.model_policies:
    #        print('-'*10,'{} ( num: {})'.format(p['name'],len(p['params'])))
            for i, param in enumerate(p['params']):
                toal_params+= param.size().numel()
    #            if i< 5 :
    #                print(param.size(), param.size().numel())
    #            elif i==5 : 
    #                print('...')
        print('*'*20, 'count from policies, total parameters: {:,}'.format(toal_params))
        print('TRN initialised \n')
Beispiel #9
0
    args.second_model_path = os.path.join('model', args.second_model_path)
    if args.dataset == 'ucf101':
        # ucf101 paths are hard-coded; classInd.txt is 1-indexed, hence the
        # -1 when building the id -> name map.
        num_class = 101
        args.train_list = '../temporal-segment-networks/data/ucf101_rgb_train_split_1.txt'
        args.val_list = '../temporal-segment-networks/data/ucf101_rgb_val_split_1.txt'
        args.root_path = '/'
        with open('video_datasets/ucf101/classInd.txt', 'r') as f:
            content = f.readlines()
        class_to_name = {
            int(line.strip().split(' ')[0]) - 1: line.strip().split(' ')[1]
            for line in content
        }
        prefix = 'image_{:05d}.jpg'
    else:
        # Other datasets resolve their lists/paths through the helper module.
        categories, args.train_list, args.val_list, args.root_path, prefix = \
            datasets_video.return_dataset(args.dataset, args.modality)
        class_to_name = {
            i: name.replace(' ', '-')
            for i, name in enumerate(categories)
        }
        num_class = len(categories)

    print(class_to_name)
    # input('...')
    first_model = TSN(num_class,
                      args.num_segments,
                      args.modality,
                      base_model=args.arch,
                      consensus_type=args.consensus_type,
                      dropout=0.8,
                      img_feature_dim=args.img_feature_dim)
Beispiel #10
0
def main():
    """CLI entry point: train a TRN model (TSN with TRN consensus).

    Parses command-line options, builds the TSN model, optionally resumes
    from a checkpoint, constructs the train/validation loaders, then runs
    the epoch loop with periodic validation and best-model checkpointing.

    Side effects: sets CUDA visibility env vars, writes a CSV training log
    and checkpoint files under the logger directory.
    """
    logger.auto_set_dir()

    global args, best_prec1

    import argparse
    parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks")
    parser.add_argument('--dataset', type=str,default="something", choices=['something', 'jester', 'moments'])
    parser.add_argument('--modality', type=str, default="RGB", choices=['RGB', 'Flow'])
    parser.add_argument('--train_list', type=str, default="")
    parser.add_argument('--val_list', type=str, default="")
    parser.add_argument('--root_path', type=str, default="")
    parser.add_argument('--store_name', type=str, default="")
    # ========================= Model Configs ==========================
    parser.add_argument('--arch', type=str, default="BNInception")
    parser.add_argument('--num_segments', type=int, default=3)
    parser.add_argument('--consensus_type', type=str, default='avg')
    parser.add_argument('--k', type=int, default=3)

    parser.add_argument('--dropout', '--do', default=0.8, type=float,
                        metavar='DO', help='dropout ratio (default: 0.5)')
    parser.add_argument('--loss_type', type=str, default="nll",
                        choices=['nll'])
    parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame")

    # ========================= Learning Configs ==========================
    parser.add_argument('--epochs', default=120, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=128, type=int,
                        metavar='N', help='mini-batch size (default: 256)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--lr_steps', default=[50, 100], type=float, nargs="+",
                        metavar='LRSteps', help='epochs to decay learning rate by 10')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('--clip-gradient', '--gd', default=20, type=float,
                        metavar='W', help='gradient norm clipping (default: disabled)')
    parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")

    # ========================= Monitor Configs ==========================
    parser.add_argument('--print-freq', '-p', default=20, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--eval-freq', '-ef', default=5, type=int,
                        metavar='N', help='evaluation frequency (default: 5)')

    # ========================= Runtime Configs ==========================
    parser.add_argument('-j', '--workers', default=30, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--snapshot_pref', type=str, default="")
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--gpu', type=str, default='4')
    parser.add_argument('--flow_prefix', default="", type=str)
    parser.add_argument('--root_log', type=str, default='log')
    parser.add_argument('--root_model', type=str, default='model')
    parser.add_argument('--root_output', type=str, default='output')

    args = parser.parse_args()

    # TRN consensus is forced regardless of the CLI value.
    args.consensus_type = "TRN"
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    device_ids = [int(gpu_id) for gpu_id in args.gpu.split(',')]
    assert len(device_ids) >1, "TRN must run with GPU_num > 1"

    # Redirect log/model/output roots into the logger's directory.
    args.root_log = logger.get_logger_dir()
    args.root_model = logger.get_logger_dir()
    args.root_output = logger.get_logger_dir()

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(args.dataset, args.modality)
    num_class = len(categories)


    args.store_name = '_'.join(['TRN', args.dataset, args.modality, args.arch, args.consensus_type, 'segment%d'% args.num_segments])
    print('storing name: ' + args.store_name)

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)

    # Preprocessing metadata and optimizer policy groups from the model.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)#TODO, , device_ids=[int(id) for id in args.gpu.split(',')]

    if torch.cuda.is_available():
       model.cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # BUG FIX: report the path that was actually loaded
            # (previously printed args.evaluate, a boolean flag).
            print(("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    # Let cuDNN benchmark conv algorithms for fixed input sizes.
    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Frames consumed per segment: 1 for RGB, 5 stacked for Flow/RGBDiff.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=(args.arch in ['BNInception','InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception','InceptionV3'])),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # Validation uses deterministic sampling (random_shift=False) and
    # center-crop preprocessing.
    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception','InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception','InceptionV3'])),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        logger.info('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'], group['decay_mult']))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    # BUG FIX: close the training log on every exit path (was leaked).
    log_training = open(os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
    try:
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr_steps)

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch, log_training)

            # evaluate on validation set
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                prec1 = validate(val_loader, model, criterion, (epoch + 1) * len(train_loader), log_training)

                # remember best prec@1 and save checkpoint
                # NOTE(review): best_prec1 must be initialised at module level
                # when no checkpoint is resumed, or this raises NameError.
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
    finally:
        log_training.close()
Beispiel #11
0
def main():
    """Train or test an STModeling TSN model, selected by args.run_for."""
    check_rootfolders()
    global best_prec1
    if args.run_for == 'train':
        categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
            args.dataset, args.modality)
    elif args.run_for == 'test':
        categories, args.test_list, args.root_path, prefix = datasets_video.return_data(
            args.dataset, args.modality)

    num_class = len(categories)

    args.store_name = '_'.join([
        'STModeling', args.dataset, args.modality, args.arch,
        args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(num_class, args)

    # Preprocessing metadata and optimizer policy groups from the model.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    train_augmentation = model.get_augmentation()

    policies = model.get_optim_policies()
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            # best_prec1 = 0
            # NOTE(review): message below prints args.evaluate (a flag)
            # instead of args.resume (the loaded path) -- likely a bug.
            model.load_state_dict(checkpoint['state_dict'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    #print(model)
    # Let cuDNN benchmark conv algorithms for fixed input sizes.
    cudnn.benchmark = True

    # Data loading code
    # NOTE(review): (a != x) | (a != y) is always True for x != y, so the
    # GroupNormalize branch is taken for every modality; the author probably
    # meant a conjunction ('and') to special-case RGBDiff/RGBFlow.
    if ((args.modality != 'RGBDiff') | (args.modality != 'RGBFlow')):
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Frames consumed per segment for each modality.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    elif args.modality == 'RGBFlow':
        data_length = args.num_motion

    if args.run_for == 'train':
        train_loader = torch.utils.data.DataLoader(TSNDataSet(
            "/home/machine/PROJECTS/OTHER/DATASETS/kussaster/data",
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            dataset=args.dataset,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      isRGBFlow=(args.modality == 'RGBFlow')),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=False)

        # Validation uses deterministic sampling and center-crop transforms.
        val_loader = torch.utils.data.DataLoader(TSNDataSet(
            "/home/machine/PROJECTS/OTHER/DATASETS/kussaster/data",
            args.val_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            dataset=args.dataset,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      isRGBFlow=(args.modality == 'RGBFlow')),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ])),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=False)

        # define loss function (criterion) and optimizer
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        else:
            raise ValueError("Unknown loss type")

        for group in policies:
            print(
                ('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
                    group['name'], len(group['params']), group['lr_mult'],
                    group['decay_mult'])))

        optimizer = torch.optim.SGD(policies,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        # DNDF consensus swaps the SGD optimizer for Adam on the trainable
        # parameters only.
        if args.consensus_type == 'DNDF':
            params = [p for p in model.parameters() if p.requires_grad]
            optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=1e-5)

        if args.evaluate:
            validate(val_loader, model, criterion, 0)
            return

        # NOTE(review): log_training is never closed on this path.
        log_training = open(
            os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
        for epoch in range(args.start_epoch, args.epochs):
            if not args.consensus_type == 'DNDF':
                adjust_learning_rate(optimizer, epoch, args.lr_steps)

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch,
                  log_training)

            # evaluate on validation set
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                prec1 = validate(val_loader, model, criterion,
                                 (epoch + 1) * len(train_loader), log_training)

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                    }, is_best)

    elif args.run_for == 'test':
        print("=> loading checkpoint '{}'".format(args.root_weights))
        checkpoint = torch.load(args.root_weights)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        model.cuda().eval()
        print("=> loaded checkpoint ")

        test_loader = torch.utils.data.DataLoader(TSNDataSet(
            "/home/machine/PROJECTS/OTHER/DATASETS/kussaster/data",
            args.test_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            dataset=args.dataset,
            random_shift=False,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                      isRGBFlow=(args.modality == 'RGBFlow')),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ])),
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=False)

        # cam = cv2.VideoCapture(0)
        # cam.set(cv2.CAP_PROP_FPS, 48)

        # for i, (input, _) in enumerate(test_loader):
        #     with torch.no_grad():
        #         input_var = torch.autograd.Variable(input)
        #
        # ret, frame = cam.read()
        # frame_map = np.full((280, 640, 3), 0, np.uint8)
        # frame_map = frame
        # print(frame_map)
        # while (True):
        #     bg = np.full((480, 1200, 3), 15, np.uint8)
        #     bg[:480, :640] = frame
        #
        #     font = cv2.FONT_HERSHEY_SIMPLEX
        #     # cv2.rectangle(bg, (128, 48), (640 - 128, 480 - 48), (0, 255, 0), 3)
        #
        #     cv2.imshow('preview', bg)
        #
        #     if cv2.waitKey(1) & 0xFF == ord('q'):
        #         break

        test(test_loader, model, categories)
def main():
    """Train a TRN model on THUMOS (RGB), then fine-tune it for few-shot
    classification by replacing the consensus classifier head.

    Stage 1 trains on the full 'thumos' dataset and checkpoints the best
    top-1 model.  Stage 2 reloads the split lists for 'thumos-fs', swaps
    the classifier for a fresh 14-way head, and repeats the train/validate
    loop, saving to 'fs_checkpoint.pth.tar'.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    # This entry point is hard-wired to THUMOS / RGB regardless of the
    # values passed on the command line.
    args.dataset = "thumos"
    args.modality = "RGB"

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        'thumos', args.modality)
    num_class = len(categories)

    args.store_name = '_'.join([
        'TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)

    # Capture preprocessing statistics and per-group optimizer policies
    # before the model is wrapped in DataParallel (which hides attributes
    # behind .module).
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # BUGFIX: this message previously printed args.evaluate instead
            # of the checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        # RGBDiff is already a difference signal; skip normalization.
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            # BNInception/InceptionV3 expect BGR channel order (roll) and
            # un-divided 0-255 inputs; other backbones take 0-1 tensors.
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader), log_training)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)

    ###############################################################################
    # ALL LINES AFTER THIS REPRESENT NEW CODE WRITTEN TO TRAIN THE FEW-SHOT MODEL #
    ###############################################################################

    for i in range(10):
        print("TRAINING FEW-SHOT MODEL")

    num_fs_class = 14  # number of few shot classes
    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        'thumos-fs', args.modality)  # load few-shot dataset

    # modify the fully connected layers to fit our new task with 14 classes
    # (fs_model is an alias of model, not a copy — the pretrained trunk is
    # fine-tuned in place)
    fs_model = model
    fs_model.module.consensus.classifier = nn.Sequential(
        nn.ReLU(), nn.Linear(in_features=768, out_features=512, bias=True),
        nn.ReLU(),
        nn.Linear(in_features=512, out_features=num_fs_class,
                  bias=True)).cuda()

    train_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, fs_model, criterion, 0)
        return

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    # NOTE(review): `policies` was collected BEFORE the classifier head was
    # replaced above, so the new few-shot head's parameters are likely not
    # in this optimizer and would never be updated — verify against
    # get_optim_policies() and rebuild the policies from fs_model if needed.
    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    args.store_name = '_'.join([
        'fs_TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    # Reset the best-accuracy tracker for the new task.
    best_prec1 = 0
    log_fs_training = open(
        os.path.join(args.root_log, '%s.csv' % "fs-logging"), 'w')
    for epoch in range(args.start_epoch, args.epochs):
        torch.cuda.empty_cache()

        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, fs_model, criterion, optimizer, epoch,
              log_fs_training)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, fs_model, criterion,
                             (epoch + 1) * len(train_loader), log_fs_training)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': fs_model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename='fs_checkpoint.pth.tar')
Beispiel #13
0
def main():
    """Train and validate a TRN model, recording each run under a per-
    experiment, timestamped log directory together with the current git
    commit and working-tree diff for reproducibility.

    The experiment name comes from --exp_name, or (as a fallback) from an
    'experiment: <name>' line in the subject/body of the current commit.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        args.dataset, args.modality)
    num_class = len(categories)
    print("num_class: " + str(num_class))

    args.store_name = '_'.join([
        'TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)

    # Capture preprocessing statistics and optimizer policies before the
    # model is wrapped in DataParallel.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # BUGFIX: this message previously printed args.evaluate instead
            # of the checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        # RGBDiff is already a difference signal; skip normalization.
        normalize = IdentityTransform()

    # NOTE: unlike the sibling scripts, Flow uses a single-frame length here.
    if args.modality == 'Flow' or args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            # BNInception/InceptionV3 expect BGR order (roll) and 0-255
            # inputs; other backbones take 0-1 tensors.
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    print("Creating val_loader:")
    print("args.root_path: " + str(args.root_path))
    print("args.val_list: " + str(args.val_list))
    print("args.num_segments: " + str(args.num_segments))
    print("data_length: " + str(data_length))
    print("modality: " + str(args.modality))
    print("prefix: " + str(prefix))
    print("scale_size: " + str(int(scale_size)))
    print("crop_size: " + str(crop_size))
    print("args.arch: " + str(args.arch))
    print("args.batch_size: " + str(args.batch_size))
    print("args.workers: " + str(args.workers))
    print("")

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    # Snapshot git state so every run's logs record exactly what code ran.
    git_log_output = subprocess.run(
        [
            'git', 'log', '-n1',
            '--pretty=format:commit: %h%nauthor: %an%n%s%n%b'
        ],
        stdout=subprocess.PIPE).stdout.decode('utf-8').split('\n')
    git_diff_output = subprocess.run(
        ['git', 'diff'], stdout=subprocess.PIPE).stdout.decode('utf-8')

    if args.exp_name == '':
        # Fall back to an 'experiment: <name>' tag in the commit message.
        exp_name_match = re.match(r'experiment: *(.+)', git_log_output[2])
        if exp_name_match is None:
            print(
                'Experiment name required:\n'
                '  current commit subject does not specify an experiment, and\n'
                '  --experiment_name was not specified')
            # BUGFIX: exit nonzero — this is an error, not a clean shutdown.
            sys.exit(1)
        args.exp_name = exp_name_match.group(1)
    print(f'experiment name: {args.exp_name}')

    time = str(datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
    exp_dir_path = os.path.join(args.root_log, args.exp_name, time)
    log_file_path = os.path.join(exp_dir_path, f'{args.store_name}.csv')
    print("log_file_path:")
    print(log_file_path)
    os.makedirs(exp_dir_path)
    log_training = open(log_file_path, 'w')
    # store information about git status
    git_info_path = os.path.join(exp_dir_path, 'experiment_info.txt')
    with open(git_info_path, 'w') as f:
        f.write('\n'.join(git_log_output))
        f.write('\n\n' + ('=' * 80) + '\n')
        f.write(git_diff_output)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader), log_training)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, time)
Beispiel #14
0
def main():
    """Train and validate a CSN model.

    Handles Something-Something left/right label symmetrization, optional
    multi-GPU DataParallel training, checkpoint resume, and per-epoch
    logging through the module-level `log`/`log_stream` helpers.
    """

    #*************************Processing Data**************************
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    # Preprocess the dataset: read the split .txt files into memory.
    categories, train_list, val_list, root_path, prefix = datasets_video.return_dataset(
        args.dataset, args.root_path)
    num_class = len(categories)

    if args.dataset == 'somethingv1' or args.dataset == 'somethingv2':
        # label transformation for left/right categories
        # please refer to labels.json file in sometingv2 for detail.
        target_transforms = {
            86: 87,
            87: 86,
            93: 94,
            94: 93,
            166: 167,
            167: 166
        }
    else:
        target_transforms = None

    #****************************Create Model***************************
    model = getattr(CSN, args.arch)(num_class,
                                    target_transforms=target_transforms,
                                    mode=args.mode)

    # Capture preprocessing statistics and optimizer policies before the
    # model is (possibly) wrapped in DataParallel.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = get_optim_policies(model)
    train_augmentation = model.get_augmentation()

    # ***************************Data loading code****************************
    normalize = GroupNormalize(input_mean, input_std)

    train_loader = torch.utils.data.DataLoader(VideoDataSet(
        root_path,
        train_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            # BNInception/InceptionV3 expect BGR order (roll) and 0-255
            # inputs; other backbones take 0-1 tensors.
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(VideoDataSet(
        root_path,
        val_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    #**************************Training config**************************
    device = 'cuda'
    if torch.cuda.is_available():
        # 'id' renamed to 'gpu_id' to avoid shadowing the builtin.
        devices = ['cuda:' + gpu_id for gpu_id in args.gpus.split(',')]
        if len(devices) > 1:
            # Single-machine multi-GPU training.
            model = torch.nn.DataParallel(model, device_ids=devices)
    else:
        device = 'cpu'
    model = model.to(device)

    if args.resume:  # continue training after an interruption
        if os.path.isfile(args.resume):  # train from the given checkpoint
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)

            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # BUGFIX: this message previously printed args.evaluate instead
            # of the checkpoint path that was actually loaded.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()  # cross-entropy loss

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    #******************************Training**********************************
    if args.evaluate:
        prec1 = validate(val_loader, model, criterion, 0)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': args.start_epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename='evaluate')

        return

    # Name under which checkpoints for this run are stored.
    global store_name
    store_name = '_'.join([
        args.type, args.dataset, args.arch,
        'segment%d' % args.num_segments, args.store_name
    ])
    log('storing name: ' + store_name, file=log_stream)

    for epoch in range(args.start_epoch, args.epochs):
        log("********************************\n", file=log_stream)
        log("EPOCH:" + str(epoch + 1) + "\n", file=log_stream)
        # adjust learning rate
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=str(epoch + 1))

        log("********************************\n", file=log_stream)
Beispiel #15
0
def get_pred(video_path, caption_path, opt):
    """Run a pretrained TRN over the videos under `video_path` and return,
    per video, the top-k class probabilities and predicted class indices.

    Args:
        video_path: root directory of the frames to evaluate.
        caption_path: list file consumed by TSNDataSet.
        opt: project options object (provides project_root and is forwarded
            to return_dataset / TSN).

    Returns:
        (prob_all, pred_all): parallel lists with one entry per evaluated
        video — top-k probabilities and predicted class indices.
    """
    # options
    parser = argparse.ArgumentParser(
        description="TRN testing on the full validation set")
    # parser.add_argument('dataset', type=str, choices=['something','jester','moments','charades'])
    # parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff'])

    parser.add_argument('--dataset', type=str, default='somethingv2')
    parser.add_argument('--modality', type=str, default='RGB')

    parser.add_argument(
        '--weights',
        type=str,
        default=
        'model/TRN_somethingv2_RGB_BNInception_TRNmultiscale_segment8_best.pth.tar'
    )
    parser.add_argument('--arch', type=str, default="BNInception")
    parser.add_argument('--save_scores', type=str, default=None)
    parser.add_argument('--test_segments', type=int, default=8)
    parser.add_argument('--max_num', type=int, default=-1)
    parser.add_argument('--test_crops', type=int, default=10)
    parser.add_argument('--input_size', type=int, default=224)
    parser.add_argument('--crop_fusion_type',
                        type=str,
                        default='TRNmultiscale',
                        choices=['avg', 'TRN', 'TRNmultiscale'])
    parser.add_argument('-j',
                        '--workers',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--gpus', nargs='+', type=int, default=None)
    parser.add_argument('--img_feature_dim', type=int, default=256)
    parser.add_argument(
        '--num_set_segments',
        type=int,
        default=1,
        help='TODO: select multiply set of n-frames from a video')
    parser.add_argument('--softmax', type=int, default=0)

    args = parser.parse_args()

    def accuracy(output, target, topk=(1, )):
        """Return the top-k probabilities and predicted class indices.

        BUGFIX: the old docstring claimed this computes precision@k; it
        never compares against `target` — it only ranks the outputs.
        """
        maxk = max(topk)
        prob, pred = output.topk(maxk, 1, True, True)
        prob = prob.t().data.numpy().squeeze()
        pred = pred.t().data.numpy().squeeze()
        return prob, pred

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        args.dataset, args.modality, opt)
    num_class = len(categories)

    net = TSN(num_class,
              args.test_segments
              if args.crop_fusion_type in ['TRN', 'TRNmultiscale'] else 1,
              args.modality,
              base_model=args.arch,
              consensus_type=args.crop_fusion_type,
              img_feature_dim=args.img_feature_dim,
              opt=opt)

    # Try the weights path as-is, then relative to the project eval dir.
    try:
        checkpoint = torch.load(args.weights)
    except:
        args.weights = os.path.join(opt.project_root, 'scripts/Eval/',
                                    args.weights)
        checkpoint = torch.load(args.weights)

    print("model epoch {} best prec@1: {}".format(checkpoint['epoch'],
                                                  checkpoint['best_prec1']))

    # Strip the leading 'module.' prefix left by DataParallel checkpoints.
    base_dict = {
        '.'.join(k.split('.')[1:]): v
        for k, v in list(checkpoint['state_dict'].items())
    }
    net.load_state_dict(base_dict)

    if args.test_crops == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(net.input_size),
        ])
    elif args.test_crops == 10:
        cropping = torchvision.transforms.Compose(
            [GroupOverSample(net.input_size, net.scale_size)])
    else:
        raise ValueError(
            "Only 1 and 10 crops are supported while we got {}".format(
                args.test_crops))

    data_loader = torch.utils.data.DataLoader(TSNDataSet(
        video_path,
        caption_path,
        num_segments=args.test_segments,
        new_length=1 if args.modality == "RGB" else 5,
        modality=args.modality,
        image_tmpl=prefix,
        test_mode=True,
        transform=torchvision.transforms.Compose([
            cropping,
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            GroupNormalize(net.input_mean, net.input_std),
        ])),
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers * 2,
                                              pin_memory=True)

    # `devices` is only consumed by the commented-out DataParallel call below.
    if args.gpus is not None:
        devices = [args.gpus[i] for i in range(args.workers)]
    else:
        devices = list(range(args.workers))

    #net = torch.nn.DataParallel(net.cuda(devices[0]), device_ids=devices)
    net = torch.nn.DataParallel(net.cuda())
    net.eval()

    data_gen = enumerate(data_loader)

    output = []

    def eval_video(video_data):
        """Forward one video's crops through the net; return (index, scores,
        label)."""
        i, data, label = video_data
        num_crop = args.test_crops

        # Channels per sampled frame for each modality.
        if args.modality == 'RGB':
            length = 3
        elif args.modality == 'Flow':
            length = 10
        elif args.modality == 'RGBDiff':
            length = 18
        else:
            raise ValueError("Unknown modality " + args.modality)

        # BUGFIX: `volatile=True` is ignored on PyTorch >= 0.4, so gradients
        # were being tracked during inference; no_grad() is the supported way.
        with torch.no_grad():
            input_var = data.view(-1, length, data.size(2), data.size(3))
            rst = net(input_var)
            if args.softmax == 1:
                # take the softmax to normalize the output to probability
                # (dim=1 made explicit; matches the legacy 2-D default)
                rst = F.softmax(rst, dim=1)

        rst = rst.data.cpu().numpy().copy()

        if args.crop_fusion_type in ['TRN', 'TRNmultiscale']:
            rst = rst.reshape(-1, 1, num_class)
        else:
            # Average scores over the spatial crops.
            rst = rst.reshape((num_crop, args.test_segments,
                               num_class)).mean(axis=0).reshape(
                                   (args.test_segments, 1, num_class))

        return i, rst, label[0]

    max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset)

    prob_all, pred_all = [], []
    for i, (data, label) in data_gen:
        if i >= max_num:
            break
        rst = eval_video((i, data, label))
        output.append(rst[1:])
        # NOTE(review): topk=(1, 174) hard-codes the somethingv2 class count;
        # verify against num_class for other datasets.
        prob, pred = accuracy(torch.from_numpy(np.mean(rst[1], axis=0)),
                              label,
                              topk=(1, 174))
        prob_all.append(prob)
        pred_all.append(pred)
    return prob_all, pred_all
Beispiel #16
0
def main():
    """Train (or evaluate) a TemporalModel on a video dataset.

    Parses command-line arguments into the global ``args``, builds the
    model and train/val data loaders, then either runs a single
    validation pass (``--evaluate``) or the full training loop,
    checkpointing the best top-1 precision in the global ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    categories, train_list, val_list, root_path, prefix = datasets_video.return_dataset(
        args.dataset, args.root_path)
    num_class = len(categories)

    global store_name
    store_name = '_'.join([
        args.type, args.dataset, args.arch,
        'segment%d' % args.num_segments, args.store_name
    ])
    print(('storing name: ' + store_name))

    if args.dataset == 'somethingv1' or args.dataset == 'somethingv2':
        # label transformation for left/right categories
        # please refer to labels.json file in somethingv2 for detail.
        target_transforms = {
            86: 87,
            87: 86,
            93: 94,
            94: 93,
            166: 167,
            167: 166
        }
    else:
        target_transforms = None

    model = TemporalModel(num_class,
                          args.num_segments,
                          model=args.type,
                          backbone=args.arch,
                          alpha=args.alpha,
                          beta=args.beta,
                          dropout=args.dropout,
                          target_transforms=target_transforms)

    # Preprocessing constants and optimizer policies come from the model.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = get_optim_policies(model)
    train_augmentation = model.get_augmentation()

    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()

    def _unwrapped(net):
        # DataParallel prefixes parameter names with 'module.'; unwrap so
        # checkpoints are read and written with plain parameter names
        # regardless of whether CUDA (and hence DataParallel) is in use.
        return net.module if hasattr(net, 'module') else net

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)

            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            # BUGFIX: unwrap instead of unconditionally using .module,
            # which would crash when the model is not wrapped (no CUDA).
            _unwrapped(model).load_state_dict(checkpoint['state_dict'])
            # BUGFIX: report the checkpoint actually loaded (args.resume),
            # not args.evaluate.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)

    train_loader = torch.utils.data.DataLoader(VideoDataSet(
        root_path,
        train_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            # BNInception/InceptionV3 expect BGR input without /255 scaling.
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(VideoDataSet(
        root_path,
        val_list,
        num_segments=args.num_segments,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        prec1 = validate(val_loader, model, criterion, 0)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        # BUGFIX: save the unwrapped state_dict so the checkpoint keys
        # match what the resume path above expects.
        save_checkpoint(
            {
                'epoch': args.start_epoch,
                'arch': args.arch,
                'state_dict': _unwrapped(model).state_dict(),
                'best_prec1': best_prec1,
            }, is_best)

        return

    log_training = open(
        os.path.join(args.checkpoint_dir, 'log', '%s.csv' % store_name), 'w')
    for epoch in range(args.start_epoch, args.epochs):
        # adjust learning rate
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader), log_training)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': _unwrapped(model).state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Beispiel #17
0
def main_charades():
    """Configure and launch GCN-I3D multi-label training on Charades.

    Resolves dataset file lists from ``args``, builds the train/val
    datasets, the GCN-I3D model, loss and optimizer, then hands
    everything to the multi-label mAP engine.
    """
    global args, best_prec1, use_gpu

    use_gpu = torch.cuda.is_available()

    # Resolve file lists and paths for the requested dataset/modality.
    (categories, args.train_list, args.val_list, args.train_num_list,
     args.val_num_list, args.root_path,
     prefix) = datasets_video.return_dataset(args.dataset, args.modality,
                                             args.root_path)
    num_class = len(categories)
    side = args.crop_size
    resize_to = args.scale_size
    # ImageNet channel statistics used for normalization.
    rgb_mean = [0.485, 0.456, 0.406]
    rgb_std = [0.229, 0.224, 0.225]

    # Training pipeline: multi-scale crop + random flip, then tensorise,
    # normalise and reorder to (C, T, H, W).
    train_transform = torchvision.transforms.Compose([
        GroupMultiScaleCrop(side, [1.0, 0.875, 0.75, 0.66, 0.5],
                            max_distort=2),
        GroupRandomHorizontalFlip(is_flow=False),
        Stack(roll=False),
        ToTorchFormatTensor(div=True),
        GroupNormalize(rgb_mean, rgb_std),
        ChangeToCTHW(modality=args.modality)
    ])
    train_dataset = TSNDataSet(args.root_path,
                               args.train_list,
                               args.train_num_list,
                               num_class=num_class,
                               num_segments=args.num_segments,
                               new_length=args.data_length,
                               modality=args.modality,
                               image_tmpl=prefix,
                               transform=train_transform)

    # Validation pipeline: deterministic scale + center crop.
    val_transform = torchvision.transforms.Compose([
        GroupScale(int(resize_to)),
        GroupCenterCrop(side),
        Stack(roll=False),
        ToTorchFormatTensor(div=True),
        GroupNormalize(rgb_mean, rgb_std),
        ChangeToCTHW(modality=args.modality)
    ])
    val_dataset = TSNDataSet(args.root_path,
                             args.val_list,
                             args.val_num_list,
                             num_class=num_class,
                             num_segments=args.num_segments,
                             new_length=args.data_length,
                             modality=args.modality,
                             image_tmpl=prefix,
                             random_shift=False,
                             transform=val_transform)

    # GCN-I3D backbone with a ConceptNet-derived label graph.
    net = modelfile.gcn_i3d(
        num_class=num_class,
        t=0.4,
        adj_file=
        './data/Charades_v1/gcn_info/class_graph_conceptnet_context_0.8.pkl',
        word_file='./data/Charades_v1/gcn_info/class_word.pkl')

    # Multi-label objective.
    criterion = nn.MultiLabelSoftMarginLoss()

    # Parameter groups carrying per-group lr / weight decay.
    param_groups = get_config_optim(net,
                                    lr=args.lr,
                                    weight_decay=args.weight_decay)
    optimizer = torch.optim.Adam(param_groups, eps=1e-8)

    # Engine configuration, forwarded from the CLI arguments.
    state = {
        'batch_size': args.batch_size,
        'val_batch_size': args.val_batch_size,
        'image_size': args.image_size,
        'max_epochs': args.epochs,
        'evaluate': args.evaluate,
        'resume': args.resume,
        'num_classes': num_class,
        'difficult_examples': False,
        'print_freq': args.print_freq,
        'save_model_path': args.save_model_path,
        'log_path': args.log_path,
        'logname': args.logname,
        'workers': args.workers,
        'epoch_step': args.epoch_step,
        'lr': args.lr,
        'device_ids': list(range(torch.cuda.device_count())),
    }
    if args.evaluate:
        state['evaluate'] = True
    map_engine = engine.GCNMultiLabelMAPEngine(
        state, inp_file='./data/Charades_v1/gcn_info/class_word.pkl')
    map_engine.learning(net, criterion, train_dataset, val_dataset, optimizer)
Beispiel #18
0
def accuracy(output, target, topk=(1, )):
    """Computes the precision@k for the specified values of k.

    Args:
        output: ``(batch, num_classes)`` tensor of class scores.
        target: ``(batch,)`` tensor of ground-truth class indices.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        List of 0-dim tensors, one per ``k``, holding precision@k in
        percent of the batch.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # correct: (maxk, batch) bool mask of hits at each rank.
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # BUGFIX: use reshape, not view. The slice of the transposed
        # tensor is non-contiguous, and Tensor.view raises a
        # RuntimeError on non-contiguous input in modern PyTorch.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


# Resolve the train/val list files and the frame root path from the
# dataset name; ``prefix`` is the per-frame filename template prefix.
args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
    args.dataset)

# Dataset-specific constants: class count and frame filename format.
# NOTE(review): num_class is hard-coded here rather than derived from the
# returned lists — keep in sync with the label files.
if args.dataset == 'something-v1':
    num_class = 174
    # Frames are named directly, e.g. 00001.jpg (no prefix).
    args.rgb_prefix = ''
    rgb_read_format = "{:05d}.jpg"
elif args.dataset == 'diving48':
    num_class = 48
    # Diving48 frame files carry a 'frames' name prefix.
    args.rgb_prefix = 'frames'
    rgb_read_format = "{:05d}.jpg"
else:
    raise ValueError('Unknown dataset ' + args.dataset)

net = VideoModel(num_class=num_class,
                 num_segments=args.test_segments,
                 modality=args.modality,
Beispiel #19
0
def main():
    """Train (or evaluate) an STSNN model.

    Parses CLI arguments into the global ``args``, builds the model and
    the modality-specific data loaders, then either validates once
    (``--evaluate``) or runs the full training loop, tracking the best
    top-1 precision in the global ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        args.dataset, args.modality)
    num_class = len(categories)

    args.store_name = '_'.join([
        'STSNN', args.dataset, args.modality, args.arch,
        'group%d' % args.num_segments,
        '%df1c' % args.num_motion
    ])
    print('storing name: ' + args.store_name)

    model = STSNN(num_class,
                  args.num_segments,
                  args.modality,
                  base_model=args.arch,
                  consensus_type=args.consensus_type,
                  dropout=args.dropout,
                  num_motion=args.num_motion,
                  img_feature_dim=args.img_feature_dim,
                  partial_bn=not args.no_partialbn,
                  dataset=args.dataset)

    # Preprocessing constants and optimizer policies come from the model;
    # policies must be collected BEFORE wrapping in DataParallel.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    train_augmentation = model.get_augmentation()

    policies = model.get_optim_policies()
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # BUGFIX: report the checkpoint actually loaded (args.resume),
            # not args.evaluate.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    print(model)
    cudnn.benchmark = True

    # Data loading code
    # BUGFIX: the original condition used bitwise '|' between the two
    # inequalities, which is always True (any modality differs from at
    # least one of the two names), so IdentityTransform was unreachable.
    # RGBDiff/RGBFlow inputs are meant to skip mean/std normalization.
    if args.modality not in ('RGBDiff', 'RGBFlow'):
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Number of consecutive frames stacked per segment sample.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    elif args.modality == 'RGBFlow':
        data_length = args.num_motion
    else:
        # Fail fast instead of hitting an undefined ``data_length`` below.
        raise ValueError('Unknown modality ' + args.modality)

    train_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        dataset=args.dataset,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            # BNInception/InceptionV3 expect BGR input without /255 scaling.
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                  isRGBFlow=(args.modality == 'RGBFlow')),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=False)

    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        dataset=args.dataset,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3']),
                  isRGBFlow=(args.modality == 'RGBFlow')),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=False)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader), log_training)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Beispiel #20
0
def main():
    """Train (or evaluate) a TSN/TRN model.

    Parses CLI arguments into the global ``args``, builds the model and
    the train/val data loaders, then either validates once
    (``--evaluate``) or runs the full training loop, tracking the best
    top-1 precision in the global ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    categories, args.train_list, args.val_list, args.test_list, args.root_path, prefix = datasets_video.return_dataset(
        args.dataset, args.modality)
    num_class = len(categories)

    args.store_name = '_'.join([
        'TRN', args.dataset, args.modality, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = TSN(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)

    # Preprocessing constants and optimizer policies come from the model;
    # policies must be collected BEFORE wrapping in DataParallel.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # BUGFIX: report the checkpoint actually loaded (args.resume),
            # not args.evaluate.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code.
    # Input modalities: a single RGB image encodes static appearance at one
    # time point; stacked RGB difference captures appearance change between
    # consecutive frames; stacked optical flow captures motion (warped flow
    # additionally suppresses background/camera motion).
    # RGBDiff is already a difference signal, so it skips normalization.
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Number of consecutive frames stacked per segment sample.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail fast instead of hitting an undefined ``data_length`` below.
        raise ValueError('Unknown modality ' + args.modality)

    # Train loader: random augmentation over the train list.
    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(
            args.root_path,
            args.train_list,
            num_segments=args.num_segments,
            new_length=data_length,
            modality=args.modality,
            image_tmpl=prefix,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                # BNInception/InceptionV3 expect BGR input without /255
                # scaling; other backbones take RGB scaled to [0, 1].
                Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                ToTorchFormatTensor(
                    div=(args.arch not in ['BNInception', 'InceptionV3'])),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)

    # Val loader: deterministic scale + center crop, no shuffling.
    val_loader = torch.utils.data.DataLoader(TSNDataSet(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=prefix,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    log_training = open(
        os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader), log_training)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Beispiel #21
0
def main():
    """Train (or evaluate) a VideoModel (optionally with GSM).

    Parses CLI arguments into the global ``args``, creates the
    experiment directory and TensorBoard writer, builds the model and
    data loaders, then runs the warmup-cosine training loop,
    checkpointing every epoch and tracking the best top-1 precision in
    the global ``best_prec1``.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    # Dataset-specific constants: class count and frame filename format.
    if args.dataset == 'something-v1':
        num_class = 174
        args.rgb_prefix = ''
        rgb_read_format = "{:05d}.jpg"
    elif args.dataset == 'diving48':
        num_class = 48
        args.rgb_prefix = 'frames'
        rgb_read_format = "{:05d}.jpg"
    else:
        raise ValueError('Unknown dataset ' + args.dataset)

    model_dir = os.path.join('experiments', args.dataset, args.arch,
                             args.consensus_type + '-' + args.modality,
                             str(args.run_iter))
    # A fresh run must not clobber an existing experiment directory.
    if not args.resume:
        if os.path.exists(model_dir):
            print('Dir {} exists!!!'.format(model_dir))
            sys.exit()
        else:
            os.makedirs(model_dir)
            os.makedirs(os.path.join(model_dir, args.root_log))

    writer = SummaryWriter(model_dir)

    args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
        args.dataset)

    if 'something' in args.dataset:
        # label transformation for left/right categories
        target_transforms = {
            86: 87,
            87: 86,
            93: 94,
            94: 93,
            166: 167,
            167: 166
        }
        print('Target transformation is enabled....')
    else:
        target_transforms = None

    args.store_name = '_'.join([
        args.dataset, args.arch, args.consensus_type,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = VideoModel(num_class=num_class,
                       modality=args.modality,
                       num_segments=args.num_segments,
                       base_model=args.arch,
                       consensus_type=args.consensus_type,
                       dropout=args.dropout,
                       partial_bn=not args.no_partialbn,
                       gsm=args.gsm,
                       target_transform=target_transforms)

    # Preprocessing constants and optimizer policies come from the model;
    # policies must be collected BEFORE wrapping in DataParallel.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # BUGFIX: report the checkpoint actually loaded (args.resume),
            # not args.evaluate.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code.
    # RGBDiff is already a difference signal, so it skips normalization.
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Number of consecutive frames stacked per segment sample.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # Fail fast instead of hitting an undefined ``data_length`` below.
        raise ValueError('Unknown modality ' + args.modality)

    train_loader = torch.utils.data.DataLoader(VideoDataset(
        args.root_path,
        args.train_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=args.rgb_prefix + rgb_read_format,
        transform=torchvision.transforms.Compose([
            train_augmentation,
            # BNInception/InceptionV3 expect BGR input without /255 scaling.
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(VideoDataset(
        args.root_path,
        args.val_list,
        num_segments=args.num_segments,
        new_length=data_length,
        modality=args.modality,
        image_tmpl=args.rgb_prefix + rgb_read_format,
        random_shift=False,
        transform=torchvision.transforms.Compose([
            GroupScale(int(scale_size)),
            GroupCenterCrop(crop_size),
            Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
            ToTorchFormatTensor(
                div=(args.arch not in ['BNInception', 'InceptionV3'])),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler_clr = CosineAnnealingLR.WarmupCosineLR(
        optimizer=optimizer,
        milestones=[args.warmup, args.epochs],
        warmup_iters=args.warmup,
        min_ratio=1e-7)
    if args.resume:
        # Fast-forward the scheduler to the resumed epoch.
        for epoch in range(0, args.start_epoch):
            lr_scheduler_clr.step()

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    # Append so a resumed run keeps its earlier log lines.
    log_training = open(
        os.path.join(model_dir, args.root_log, '%s.csv' % args.store_name),
        'a')
    for epoch in range(args.start_epoch, args.epochs):

        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch + 1)

        train_prec1 = train(train_loader,
                            model,
                            criterion,
                            optimizer,
                            epoch,
                            log_training,
                            writer=writer)

        lr_scheduler_clr.step()

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader,
                             model,
                             criterion, (epoch + 1) * len(train_loader),
                             log_training,
                             writer=writer,
                             epoch=epoch)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'current_prec1': prec1,
                    'lr': optimizer.param_groups[-1]['lr'],
                }, is_best, model_dir)
        else:
            # Non-eval epochs still checkpoint, using the train precision.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'current_prec1': train_prec1,
                    'lr': optimizer.param_groups[-1]['lr'],
                }, False, model_dir)
Beispiel #22
0
def main():
    """Fine-tune a pretrained TRN model as a binary (2-class) classifier.

    Loads the TRN-somethingv2 checkpoint, strips its per-scale consensus
    fusion classifier layers (their output size does not match the new
    2-class head), and trains with a class-weighted BCEWithLogitsLoss,
    snapshotting the full model state at the start of every epoch.

    Side effects: mutates global ``args``/``best_prec1``, creates log and
    checkpoint files, and moves the model to CUDA.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(args.dataset, args.modality)
    # NOTE(review): num_class is computed but the model below is built with a
    # hard-coded 2 classes (binary task) -- confirm this is intentional.
    num_class = len(categories)

    args.store_name = '_'.join(['TRN', args.dataset, args.modality, args.arch, args.consensus_type, 'segment%d' % args.num_segments])
    print('storing name: ' + args.store_name)

    model = TSN(2, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn)

    # Load the pretrained checkpoint; strip the leading 'module.' that
    # DataParallel prepends so the keys match the bare model.
    checkpoint = torch.load('pretrain/TRN_somethingv2_RGB_BNInception_TRNmultiscale_segment8_best.pth.tar', map_location='cpu')
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint['state_dict'].items())}
    # Drop the final fusion-classifier weights of every consensus scale:
    # their output dimension belongs to the pretraining task, so the new
    # 2-class head is trained from scratch (strict=False tolerates the gaps).
    for key in ['consensus.fc_fusion_scales.6.3.bias', 'consensus.fc_fusion_scales.5.3.bias',
                'consensus.fc_fusion_scales.4.3.bias',
                'consensus.fc_fusion_scales.3.3.bias', 'consensus.fc_fusion_scales.2.3.bias',
                'consensus.fc_fusion_scales.1.3.bias',
                'consensus.fc_fusion_scales.0.3.bias', 'consensus.fc_fusion_scales.6.3.weight',
                'consensus.fc_fusion_scales.5.3.weight',
                'consensus.fc_fusion_scales.4.3.weight', 'consensus.fc_fusion_scales.3.3.weight',
                'consensus.fc_fusion_scales.2.3.weight',
                'consensus.fc_fusion_scales.1.3.weight', 'consensus.fc_fusion_scales.0.3.weight']:
        del base_dict[key]
    model.load_state_dict(base_dict, strict=False)

    # Capture preprocessing parameters before DataParallel hides them
    # behind the .module attribute.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation()

    model = torch.nn.DataParallel(model).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # BUGFIX: the message previously printed args.evaluate, but the
            # checkpoint actually loaded is args.resume.
            print(("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    cudnn.benchmark = True

    # Data loading code. RGBDiff inputs are already differences, so they are
    # fed through unnormalized.
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    # Frames consumed per segment sample.
    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5
    else:
        # BUGFIX: previously fell through and crashed later with a NameError
        # on data_length; fail fast with a clear message instead.
        raise ValueError("Unsupported modality: {}".format(args.modality))

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       # BNInception/InceptionV3 expect BGR in [0, 255].
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        # Per-class weight: slightly up-weight class 0 to counter imbalance.
        weight = torch.ones([2]).cuda()
        weight[0] = 1.2
        pos_weight = torch.ones([2]).cuda()
        criterion = torch.nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight).cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

    # NOTE(review): learning rate is hard-coded to 1e-4 for fine-tuning and
    # intentionally ignores args.lr -- confirm.
    optimizer = torch.optim.SGD(policies,
                                0.0001,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    log_training = open(os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)
        # Snapshot BEFORE training, so the epoch-0 files hold the pretrained
        # initialization. Both extensions are written for loader compatibility.
        torch.save(model.state_dict(), 'checkpoint_bce_20_w12_{}.pth.tar'.format(epoch))
        torch.save(model.state_dict(), 'checkpoint_bce_20_w12_{}.pth'.format(epoch))
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training)
Beispiel #23
0
def main():
    """Train or test an MFF gesture-recognition model on SHGD/Jester.

    In training mode, builds the model, optionally resumes from a checkpoint
    and/or adapts an RGB-pretrained first conv to IR/D/IRD modalities, then
    runs the epoch loop with periodic validation and checkpointing. In test
    mode (``args.test``), evaluates a resumed checkpoint on the tuple test
    set and returns.

    Side effects: mutates global ``args``/``best_prec1``, writes logs and
    checkpoints, and moves the model to CUDA.
    """
    global args, best_prec1
    args = parser.parse_args()
    check_rootfolders()

    if not args.test:
        categories, args.train_list, args.val_list, args.root_path, prefix = datasets_video.return_dataset(
            args.dataset, args.modality)
    else:
        # Test mode always evaluates on the SHGD tuples split.
        categories, args.test_list, args.root_path, prefix = datasets_video.return_dataset(
            'SHGDTuples', args.modality)
    num_class = len(categories)

    args.store_name = '_'.join([
        args.dataset, args.modality, args.arch,
        'segment%d' % args.num_segments
    ])
    print('storing name: ' + args.store_name)

    model = MFF(num_class,
                args.num_segments,
                args.modality,
                base_model=args.arch,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                dataset=args.dataset)

    # Capture preprocessing parameters before DataParallel hides them
    # behind the .module attribute.
    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    train_augmentation = model.get_augmentation()

    policies = model.get_optim_policies()
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # BUGFIX: the message previously printed args.test, but the
            # checkpoint actually loaded is args.resume.
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.pretrained:
        # Name of the first conv weight differs between base architectures.
        if args.arch == 'squeezenet1_1':
            name = 'module.base_model.0.weight'
        else:
            name = 'module.base_model.features.0.0.weight'

        if os.path.isfile(args.pretrained):
            pretrained_dict = torch.load(args.pretrained)
            pretrained_state_dict = pretrained_dict['state_dict']
            # Drop the final classifier: its output size belongs to the
            # pretraining task, not this dataset.
            pretrained_state_dict = {
                k: v
                for k, v in pretrained_state_dict.items()
                if 'module.consensus.classifier.3.' not in k
            }
            model_dict = model.state_dict()
            weight_conv_t = pretrained_state_dict[name]

            # Adapt the 3-channel RGB first conv to the channel count of
            # the chosen modality by collapsing the input-channel dimension.
            # NOTE(review): sum(1).unsqueeze(1).mean(1) reduces to a plain
            # channel sum, while the original comments said "average" --
            # numerics kept as-is, confirm which was intended.
            if args.modality == 'IRD':
                # 3 input channels -> 2 (IR + depth).
                weight_conv_t = weight_conv_t.sum(1)
                weight_conv_t = weight_conv_t.unsqueeze(1)
                weight_conv_t = weight_conv_t.mean(1)
                weight_conv_t = torch.stack((weight_conv_t, weight_conv_t), 1)
                pretrained_state_dict[name] = weight_conv_t
                model_dict.update(pretrained_state_dict)
                print("Converted the first conv layer to 2 channels.")

            if args.modality == 'IR' or args.modality == 'D':
                # 3 input channels -> 1 (single IR or depth channel).
                weight_conv_t = weight_conv_t.sum(1)
                weight_conv_t = weight_conv_t.unsqueeze(1)
                weight_conv_t = weight_conv_t.mean(1)
                weight_conv_t = weight_conv_t.unsqueeze(1)
                pretrained_state_dict[name] = weight_conv_t
                model_dict.update(pretrained_state_dict)
                print("Converted the first conv layer to 1 channel.")

            model.load_state_dict(model_dict)
            print("=> loaded pretrained model checkpoint '{}'".format(
                args.pretrained))
        else:
            # BUGFIX: previously fell through and crashed with a NameError on
            # weight_conv_t/model_dict when the file was missing; now the
            # conversion and load only run when the checkpoint was loaded.
            print(("=> no pretrained model checkpoint found at '{}'".format(
                args.pretrained)))
    print(model)

    ## to print the number of trainable paramters in the network
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params_num = sum([np.prod(p.size()) for p in model_parameters])
    print('Total number of parameters:' + str(params_num))
    cudnn.benchmark = True

    # Data loading code
    normalize = GroupNormalize(input_mean, input_std)
    if args.dataset == 'SHGD':
        from SHGD import DataSet
    if args.dataset == 'jester':
        from Jester import DataSet
    if not args.test:
        train_loader = torch.utils.data.DataLoader(DataSet(
            args.root_path,
            args.train_list,
            num_segments=args.num_segments,
            modality=args.modality,
            image_tmpl=prefix,
            dataset=args.dataset,
            transform=torchvision.transforms.Compose([
                train_augmentation,
                Stack(),
                ToTorchFormatTensor(),
                normalize,
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=False)

        val_loader = torch.utils.data.DataLoader(DataSet(
            args.root_path,
            args.val_list,
            num_segments=args.num_segments,
            modality=args.modality,
            image_tmpl=prefix,
            dataset=args.dataset,
            random_shift=False,  # deterministic frame sampling for validation
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(),
                ToTorchFormatTensor(),
                normalize,
            ])),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=False)

    else:
        test_loader = torch.utils.data.DataLoader(DataSet(
            args.root_path,
            args.test_list,
            num_segments=args.num_segments,
            modality=args.modality,
            image_tmpl=prefix,
            dataset=args.dataset,
            random_shift=False,
            test_mode=True,
            transform=torchvision.transforms.Compose([
                GroupScale(int(scale_size)),
                GroupCenterCrop(crop_size),
                Stack(),
                ToTorchFormatTensor(),
                normalize,
            ])),
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=False)

    if args.test:
        if not args.resume:
            print('Please give a path to a trained model for testing.')
            sys.exit()
        else:
            test(args.start_epoch, test_loader, model, args)
            return

    if args.loss_type == 'nll':
        if num_class == 13:
            # BUGFIX: this weighted branch was previously an unreachable
            # `elif args.loss_type == 'nll' and num_class == 13` after the
            # plain 'nll' branch, so it never ran. Give the "No gesture" /
            # "Hand up" / "Hand down" classes less weight than the others.
            # No: 4420, Hand up: 2280, Hand down: 2190, Others: 228.
            weights = [1, 1, 1, 1 / 10, 1 / 10, 1 / 20, 1, 1, 1, 1, 1, 1, 1]
            class_weights = torch.Tensor(weights).cuda()
            criterion = torch.nn.CrossEntropyLoss(weight=class_weights).cuda()
        else:
            criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'],
            group['decay_mult'])))

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    log_training = open(
        os.path.join(args.root_log, '%s.csv' % args.store_name), 'w')
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion,
                             (epoch + 1) * len(train_loader), log_training,
                             num_class)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
Beispiel #24
0
                images.extend(seg_imgs)
                if p < record.num_frames:
                    p += 1

        process_data = self.transform(images) # transformation changes 8, 224, 244 to 24, 224, 224,

        return process_data, record.label

    def __len__(self):
        """Return the number of utterances (samples) in the dataset."""
        total = len(self.utterance_list)
        return total

if __name__ == "__main__":
    import datasets_video
    from transforms import *
    import torchvision
    root_path, train_dict, val_dict = datasets_video.return_dataset('omg', 'RGB','face')
    label_name = 'arousal'
    num_segments = 8 
     
    dataset = TSNDataSet(root_path, train_dict, label_name, num_segments=num_segments, phase='Train',
                   new_length=1,
                   modality='RGB',
                   image_tmpl='frame_det_00_{:06d}.bmp',
                   transform=torchvision.transforms.Compose([
                                    GroupScale(256),
                                    GroupRandomCrop(224),
                                    Stack(True),
                                    ToTorchFormatTensor(),
                                    GroupNormalize(
                                        mean=[.485, .456, .406],
                                        std=[.229, .224, .225]),