def build_model_continue():
    """Rebuild the model and optimizer from the best checkpoint to resume training.

    Reads the module-level ``args`` (``dataset``, ``arch``, ``split``,
    ``num_seg``, ``lr``, ``weight_decay``) and loads
    ``./checkpoint/<dataset>_<arch>_split<split>/model_best.pth.tar``.

    Returns:
        tuple: ``(model, startEpoch, optimizer, best_prec)`` where
            * ``model`` is the network with the checkpointed weights, moved to
              CUDA (wrapped in ``DataParallel`` when more than one CUDA device
              is visible),
            * ``startEpoch`` is the checkpoint's ``'epoch'`` entry,
            * ``optimizer`` is an ``AdamW`` instance with its checkpointed
              state restored,
            * ``best_prec`` is the checkpoint's ``'best_prec1'`` entry.

    Raises:
        ValueError: if ``args.dataset`` has no known class count. Previously
            this case fell through both branches and failed later with an
            ``UnboundLocalError`` on ``model``.
    """
    num_classes_by_dataset = {'ucf101': 101, 'hmdb51': 51}
    if args.dataset not in num_classes_by_dataset:
        # Fail before the (potentially slow) checkpoint load.
        raise ValueError(
            "Cannot resume training for dataset %r; supported datasets: %s" %
            (args.dataset, ", ".join(sorted(num_classes_by_dataset))))

    modelLocation = "./checkpoint/" + args.dataset + "_" + args.arch + "_split" + str(
        args.split)
    model_path = os.path.join(modelLocation, 'model_best.pth.tar')
    params = torch.load(model_path)
    print(modelLocation)

    # modelPath='' : the weights come from the checkpoint loaded above.
    model = models.__dict__[args.arch](
        modelPath='',
        num_classes=num_classes_by_dataset[args.dataset],
        length=args.num_seg)

    # NOTE(review): the wrap happens *before* load_state_dict, so the
    # checkpoint was presumably saved from a model wrapped the same way --
    # confirm against the code that writes 'state_dict'.
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    model.load_state_dict(params['state_dict'])
    model = model.cuda()

    # The optimizer is built on the CUDA parameters, then its saved state
    # is restored on top of the freshly constructed instance.
    optimizer = AdamW(model.parameters(),
                      lr=args.lr,
                      weight_decay=args.weight_decay)
    optimizer.load_state_dict(params['optimizer'])

    startEpoch = params['epoch']
    best_prec = params['best_prec1']
    return model, startEpoch, optimizer, best_prec
def _input_scale(arch):
    """Return the spatial scale factor (1 or 0.5) selected by the architecture name."""
    if '3D' in arch:
        if 'I3D' in arch or 'MFNET3D' in arch:
            # Full size unless the name carries the '112' tag.
            return 0.5 if '112' in arch else 1
        # Other 3D architectures: half size unless the name carries '224'.
        return 1 if '224' in arch else 0.5
    if 'r2plus1d' in arch:
        return 0.5
    return 1


def _clip_length(arch):
    """Return the number of frames per clip selected by the architecture name."""
    if any(tag in arch for tag in ('3D', 'tsm', 'slowfast', 'r2plus1d')):
        if '64f' in arch:
            return 64
        if '32f' in arch:
            return 32
        return 16
    return 1


def _normalization_stats(modality, arch, repeat):
    """Return ``(is_color, clip_mean, clip_std)`` for a modality/architecture pair.

    The per-channel mean/std lists are tiled ``repeat`` times. Returns
    ``None`` when ``modality`` is not one of ``rgb``, ``pose``, ``flow`` or
    ``both``.
    """
    if modality in ("rgb", "pose"):
        if 'I3D' in arch:
            if 'resnet' in arch:
                mean, std = [0.45, 0.45, 0.45], [0.225, 0.225, 0.225]
            else:
                mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        elif 'MFNET3D' in arch:
            mean = [0.48627451, 0.45882353, 0.40784314]
            std = [0.234, 0.234, 0.234]
        elif "3D" in arch:
            mean, std = [114.7748, 107.7354, 99.4750], [1, 1, 1]
        elif "r2plus1d" in arch:
            mean = [0.43216, 0.394666, 0.37645]
            std = [0.22803, 0.22145, 0.216989]
        elif "rep_flow" in arch:
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        elif "slowfast" in arch:
            mean, std = [0.45, 0.45, 0.45], [0.225, 0.225, 0.225]
        else:
            mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        return True, mean * repeat, std * repeat

    if modality == "flow":
        if 'I3D' in arch:
            mean, std = [0.5, 0.5], [0.5, 0.5]
        elif "3D" in arch:
            mean, std = [127.5, 127.5], [1, 1]
        else:
            mean, std = [0.5, 0.5], [0.226, 0.226]
        return False, mean * repeat, std * repeat

    if modality == "both":
        mean = [0.485, 0.456, 0.406, 0.5, 0.5]
        std = [0.229, 0.224, 0.225, 0.226, 0.226]
        return True, mean * repeat, std * repeat

    return None


def main():
    """Parse CLI arguments, then evaluate or train a video classification model.

    The flow is:
      1. derive input geometry from the architecture name,
      2. build the model/optimizer (evaluation, resumed, or fresh training),
      3. build normalisation, transforms, datasets and loaders,
      4. either run a single validation pass (``args.evaluate``) or train
         for ``args.epochs`` epochs, validating and checkpointing every
         ``args.save_freq`` epochs.

    Several values are published as module globals (see the ``global``
    statement). NOTE(review): presumably ``train``/``validate`` read them --
    confirm before narrowing their scope.

    Returns:
        ``0`` when the dataset or modality is not supported, otherwise ``None``.
    """
    global args, best_prec1, model, writer, best_loss, length, width, height, input_size, scheduler
    args = parser.parse_args()
    # Attribute name is spelled as in the original; it presumably matches the
    # option registered on `parser` (defined elsewhere), so it is kept as is.
    training_continue = args.contine

    scale = _input_scale(args.arch)
    print('scale: %.1f' % (scale))

    input_size = int(224 * scale)
    width = int(340 * scale)
    height = int(256 * scale)

    saveLocation = "./checkpoint/" + args.dataset + "_" + args.arch + "_split" + str(
        args.split)
    os.makedirs(saveLocation, exist_ok=True)
    writer = SummaryWriter(saveLocation)

    # Create the model and its optimizer.
    if args.evaluate:
        print("Building validation model ... ")
        model = build_model_validate()
        optimizer = AdamW(model.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay)
    elif training_continue:
        model, startEpoch, optimizer, best_prec1 = build_model_continue()
        # Report the learning rate of the last parameter group.
        lr = optimizer.param_groups[-1]['lr']
        print(
            "Continuing with best precision: %.3f and start epoch %d and lr: %f"
            % (best_prec1, startEpoch, lr))
    else:
        print("Building model with ADAMW... ")
        model = build_model()
        optimizer = AdamW(model.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay)
        startEpoch = 0

    if HALF:
        model.half()  # convert to half precision
        # BatchNorm2d layers are switched back to float after the conversion.
        for layer in model.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    print("Model %s is loaded. " % (args.arch))

    # Loss functions and learning-rate schedule.
    criterion = nn.CrossEntropyLoss().cuda()
    criterion2 = nn.MSELoss().cuda()

    # 'min' mode: the scheduler is stepped with the validation loss below.
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               patience=5,
                                               verbose=True)

    print("Saving everything to directory %s." % (saveLocation))
    supported_datasets = ('ucf101', 'hmdb51', 'smtV2', 'window',
                          'haa500_basketball')
    if args.dataset not in supported_datasets:
        print("No convenient dataset entered, exiting....")
        return 0
    dataset = './datasets/%s_frames' % args.dataset

    cudnn.benchmark = True
    # The modality is the architecture-name prefix before the first '_'.
    modality = args.arch.split('_')[0]
    length = _clip_length(args.arch)

    # Data transforming
    stats = _normalization_stats(modality, args.arch, args.num_seg * length)
    if stats is None:
        # Previously execution continued and failed with a NameError on
        # `clip_mean`; exit the same way as for an unknown dataset instead.
        print("No such modality. Only rgb and flow supported.")
        return 0
    is_color, clip_mean, clip_std = stats
    scale_ratios = [1.0, 0.875, 0.75, 0.66]

    normalize = video_transforms.Normalize(mean=clip_mean, std=clip_std)

    # Non-I3D 3D architectures use ToTensor2; everything else uses ToTensor.
    if "3D" in args.arch and not ('I3D' in args.arch):
        to_tensor = video_transforms.ToTensor2
    else:
        to_tensor = video_transforms.ToTensor

    train_transform = video_transforms.Compose([
        video_transforms.MultiScaleCrop((input_size, input_size),
                                        scale_ratios),
        video_transforms.RandomHorizontalFlip(),
        to_tensor(),
        normalize,
    ])

    val_transform = video_transforms.Compose([
        video_transforms.CenterCrop((input_size)),
        to_tensor(),
        normalize,
    ])

    # Data loading
    train_setting_file = "train_%s_split%d.txt" % (modality, args.split)
    train_split_file = os.path.join(args.settings, args.dataset,
                                    train_setting_file)
    val_setting_file = "val_%s_split%d.txt" % (modality, args.split)
    val_split_file = os.path.join(args.settings, args.dataset,
                                  val_setting_file)
    if not os.path.exists(train_split_file) or not os.path.exists(
            val_split_file):
        # Warning only: execution continues, as in the original.
        print(
            "No split file exists in %s directory. Preprocess the dataset first"
            % (args.settings))

    train_dataset = datasets.__dict__[args.dataset](
        root=dataset,
        source=train_split_file,
        phase="train",
        modality=modality,
        is_color=is_color,
        new_length=length,
        new_width=width,
        new_height=height,
        video_transform=train_transform,
        num_segments=args.num_seg)

    val_dataset = datasets.__dict__[args.dataset](
        root=dataset,
        source=val_split_file,
        phase="val",
        modality=modality,
        is_color=is_color,
        new_length=length,
        new_width=width,
        new_height=height,
        video_transform=val_transform,
        num_segments=args.num_seg)

    print('{} samples found, {} train samples and {} test samples.'.format(
        len(val_dataset) + len(train_dataset), len(train_dataset),
        len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        prec1, prec3, _ = validate(val_loader, model, criterion, criterion2,
                                   modality)
        return

    # `epoch` stays None and `is_best` stays False when the range below is
    # empty (e.g. resuming at or past args.epochs); the final save is then
    # skipped instead of raising NameError.
    epoch = None
    is_best = False
    for epoch in range(startEpoch, args.epochs):
        train(train_loader, model, criterion, criterion2, optimizer, epoch,
              modality)

        # Validation only runs every `save_freq` epochs; on the other epochs
        # prec1 stays 0.0 for the best-precision bookkeeping below.
        validated = (epoch + 1) % args.save_freq == 0
        prec1 = 0.0
        if validated:
            prec1, prec3, lossClassification = validate(
                val_loader, model, criterion, criterion2, modality)
            writer.add_scalar('data/top1_validation', prec1, epoch)
            writer.add_scalar('data/top3_validation', prec3, epoch)
            writer.add_scalar('data/classification_loss_validation',
                              lossClassification, epoch)
            scheduler.step(lossClassification)

        # Remember best prec@1; `best_prec1` is a module global.
        is_best = prec1 >= best_prec1
        best_prec1 = max(prec1, best_prec1)

        # Periodic checkpoints are written only when the model improved.
        if validated and is_best:
            checkpoint_name = "%03d_%s" % (epoch + 1, "checkpoint.pth.tar")
            print("Model works well")
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                }, is_best, checkpoint_name, saveLocation)

    if epoch is None:
        print("No epochs to run (start epoch %d, total epochs %d); "
              "skipping final checkpoint." % (startEpoch, args.epochs))
    else:
        # Always persist the state reached after the last epoch.
        checkpoint_name = "%03d_%s" % (epoch + 1, "checkpoint.pth.tar")
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best, checkpoint_name, saveLocation)
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
# Example #3
# 0
def main(args):
    """Train an RGB video classifier and checkpoint it periodically.

    Args:
        args: parsed options. The attributes read here are ``scale``,
            ``savelocation``, ``arch``, ``pre``, ``num_seg``, ``resume``,
            ``lr``, ``weight_decay``, ``dataset``, ``datasetpath``,
            ``trainlist``, ``vallist``, ``batch_size``, ``workers``,
            ``epochs`` and ``save_freq``.

    Each run writes its log file and checkpoints into a fresh
    ``<args.savelocation>/<timestamp>`` directory. Validation, scheduler
    stepping and checkpointing happen every ``args.save_freq`` epochs.

    Raises:
        FileExistsError: if the timestamped run directory already exists.
    """
    global best_prec1, best_loss

    input_size = int(224 * args.scale)
    width = int(340 * args.scale)
    height = int(256 * args.scale)

    # One makedirs call creates args.savelocation (if missing) and the
    # per-run sub-directory, without a separate exists() check.
    now = time.time()
    savelocation = os.path.join(args.savelocation, str(now))
    os.makedirs(savelocation)

    logging.basicConfig(filename=os.path.join(savelocation, "log.log"),
                        level=logging.INFO)

    model = build_model(args.arch, args.pre, args.num_seg, args.resume)
    optimizer = AdamW(model.parameters(),
                      lr=args.lr,
                      weight_decay=args.weight_decay)

    criterion = nn.CrossEntropyLoss().cuda()
    criterion2 = nn.MSELoss().cuda()

    # 'min' mode: the scheduler is stepped with the validation loss below.
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               patience=5,
                                               verbose=True)

    cudnn.benchmark = True
    length = 64  # frames per clip, fixed for this script

    # Per-channel statistics tiled once per frame of every segment.
    clip_mean = [0.43216, 0.394666, 0.37645] * args.num_seg * length
    clip_std = [0.22803, 0.22145, 0.216989] * args.num_seg * length

    normalize = video_transforms.Normalize(mean=clip_mean, std=clip_std)

    # Training and validation use the same deterministic pipeline
    # (center crop, no random augmentation).
    train_transform = video_transforms.Compose([
        video_transforms.CenterCrop(input_size),
        video_transforms.ToTensor2(),
        normalize,
    ])

    val_transform = video_transforms.Compose([
        video_transforms.CenterCrop((input_size)),
        video_transforms.ToTensor2(),
        normalize,
    ])

    if not os.path.exists(args.trainlist) or not os.path.exists(args.vallist):
        # Warning only: execution continues, as in the original.
        print(
            "No split file exists in %s directory. Preprocess the dataset first"
            % (args.datasetpath))

    train_dataset = datasets.__dict__[args.dataset](
        root=args.datasetpath,
        source=args.trainlist,
        phase="train",
        modality="rgb",
        is_color=True,
        new_length=length,
        new_width=width,
        new_height=height,
        video_transform=train_transform,
        num_segments=args.num_seg)

    val_dataset = datasets.__dict__[args.dataset](
        root=args.datasetpath,
        source=args.vallist,
        phase="val",
        modality="rgb",
        is_color=True,
        new_length=length,
        new_width=width,
        new_height=height,
        video_transform=val_transform,
        num_segments=args.num_seg)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    best_prec1 = 0
    for epoch in range(0, args.epochs):
        train(length, input_size, train_loader, model, criterion, criterion2,
              optimizer, epoch)

        if (epoch + 1) % args.save_freq == 0:
            is_best = False
            prec1, prec3, lossClassification = validate(
                length, input_size, val_loader, model, criterion, criterion2)
            scheduler.step(lossClassification)

            # Ties count as "best", so the newest equally good model wins.
            if prec1 >= best_prec1:
                is_best = True
                best_prec1 = prec1

            # A checkpoint is written at every validation, best or not;
            # `is_best` is forwarded to save_checkpoint.
            checkpoint_name = "%03d_%s" % (epoch + 1, "checkpoint.pth.tar")
            text = "save checkpoint {}".format(checkpoint_name)
            print(text)
            logging.info(text)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "prec1": prec1,
                    "optimizer": optimizer.state_dict()
                }, is_best, checkpoint_name, savelocation)