def main():
    global my_dict, keys, k_len, arr, xxx, args, log_file, best_prec1
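    # NOTE: str2bool is assumed to be a small helper imported from the repo's
    # utilities (not shown here); a hypothetical sketch of its behaviour:
    #     def str2bool(v):
    #         return str(v).lower() in ('yes', 'true', 't', '1')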

    parser = argparse.ArgumentParser(
        description='Single Shot MultiBox Detector Training')
    parser.add_argument('--version',
                        default='v2',
                        help='conv11_2(v2) or pool6(v1) as last layer')
    parser.add_argument('--basenet',
                        default='vgg16_reducedfc.pth',
                        help='pretrained base model')
    parser.add_argument('--dataset',
                        default='ucf24',
                        help='dataset name')
    parser.add_argument('--ssd_dim',
                        default=300,
                        type=int,
                        help='Input Size for SSD')  # only support 300 now
    parser.add_argument(
        '--modality',
        default='rgb',
        type=str,
        help='Input type; default rgb, options are [rgb, brox, fastOF]')
    parser.add_argument('--jaccard_threshold',
                        default=0.5,
                        type=float,
                        help='Min Jaccard index for matching')
    parser.add_argument('--batch_size',
                        default=40,
                        type=int,
                        help='Batch size for training')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='Number of workers used in dataloading')
    parser.add_argument('--max_iter',
                        default=120000,
                        type=int,
                        help='Number of training iterations')
    parser.add_argument('--man_seed',
                        default=123,
                        type=int,
                        help='manual seed for reproducibility')
    parser.add_argument('--cuda',
                        default=True,
                        type=str2bool,
                        help='Use cuda to train model')
    parser.add_argument('--ngpu',
                        default=1,
                        type=int,
                        help='Number of GPUs to use')
    parser.add_argument('--lr',
                        '--learning-rate',
                        default=0.0005,
                        type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--stepvalues',
                        default='70000,90000',
                        type=str,
                        help='iterations at which the learning rate is dropped')
    parser.add_argument('--weight_decay',
                        default=5e-4,
                        type=float,
                        help='Weight decay for SGD')
    parser.add_argument('--gamma',
                        default=0.2,
                        type=float,
                        help='Gamma update for SGD')
    parser.add_argument('--log_iters',
                        default=True,
                        type=str2bool,
                        help='Print the loss at each iteration')
    parser.add_argument('--visdom',
                        default=False,
                        type=str2bool,
                        help='Use visdom for loss visualization')
    parser.add_argument('--data_root',
                        default=relative_path + 'realtime/',
                        help='Location of the dataset root directory')
    parser.add_argument('--save_root',
                        default=relative_path + 'realtime/saveucf24/',
                        help='Location to save checkpoint models')

    parser.add_argument('--iou_thresh',
                        default=0.5,
                        type=float,
                        help='Evaluation threshold')
    parser.add_argument('--conf_thresh',
                        default=0.01,
                        type=float,
                        help='Confidence threshold for evaluation')
    parser.add_argument('--nms_thresh',
                        default=0.45,
                        type=float,
                        help='NMS threshold')
    parser.add_argument('--topk',
                        default=50,
                        type=int,
                        help='topk for evaluation')
    parser.add_argument('--clip_gradient',
                        default=40,
                        type=float,
                        help='gradient clipping threshold')
    parser.add_argument('--resume',
                        default=None,
                        type=str,
                        help='Resume from checkpoint')
    parser.add_argument('--start_epoch',
                        default=0,
                        type=int,
                        help='start epoch')
    parser.add_argument('--epochs',
                        default=35,
                        type=int,
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--eval_freq',
                        default=2,
                        type=int,
                        metavar='N',
                        help='evaluation frequency (default: 2)')
    parser.add_argument('--snapshot_pref',
                        type=str,
                        default="ucf101_vgg16_ssd300_")
    parser.add_argument('--lr_milestones',
                        default=[-2, -5],
                        type=float,
                        help='learning rate milestones')
    parser.add_argument('--arch', type=str, default="VGG16")
    parser.add_argument('--Finetune_SSD', default=False, type=str2bool)
    parser.add_argument('-e',
                        '--evaluate',
                        dest='evaluate',
                        action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--gpus', nargs='+', type=int, default=[0, 1, 2, 3])

    print(__file__)
    file_name = (__file__).split('/')[-1]
    file_name = file_name.split('.')[0]
    print(file_name)
    ## Parse arguments
    args = parser.parse_args()
    ## set random seeds
    np.random.seed(args.man_seed)
    torch.manual_seed(args.man_seed)
    if args.cuda:
        torch.cuda.manual_seed_all(args.man_seed)

    if args.cuda and torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    args.cfg = v2
    args.train_sets = 'train'
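    # Per-channel BGR pixel means subtracted during preprocessing; these are the
    # standard Caffe-VGG/SSD values.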
    args.means = (104, 117, 123)
    num_classes = len(CLASSES) + 1
    args.num_classes = num_classes
    args.stepvalues = [int(val) for val in args.stepvalues.split(',')]
    args.loss_reset_step = 30
    args.eval_step = 10000
    args.print_step = 10
    args.data_root += args.dataset + '/'

    ## Define the experiment name; it is used to name the checkpoint and log files
    day = (time.strftime('%m-%d', time.localtime(time.time())))
    args.snapshot_pref = ('ucf101_CONV-SSD-{}-{}-bs-{}-{}-lr-{:05d}').format(
        args.dataset, args.modality, args.batch_size, args.basenet[:-14],
        int(args.lr * 100000)) + '_' + file_name + '_' + day
    print(args.snapshot_pref)

    if not os.path.isdir(args.save_root):
        os.makedirs(args.save_root)

    net = build_refine_ssd(300, args.num_classes)
    net = torch.nn.DataParallel(net, device_ids=args.gpus)

    if args.Finetune_SSD is True:
        print("load snapshot")
        pretrained_weights = "/data4/lilin/my_code/realtime/ucf24/rgb-ssd300_ucf24_120000.pth"
        pretrained_dict = torch.load(pretrained_weights)
        model_dict = net.state_dict()
        # 1. filter out keys that are not present in the current model
        pretrained_dict_2 = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite matching entries in the existing state dict
        model_dict.update(pretrained_dict_2)
        # 3. load the merged state dict into the network
        net.load_state_dict(model_dict)
    elif args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            net.load_state_dict(checkpoint['state_dict'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    elif args.modality == 'fastOF':
        print(
            'Download the pretrained Brox-flow model weights and place them at:::=> ',
            args.data_root + 'train_data/brox_wieghts.pth')
        pretrained_weights = args.data_root + 'train_data/brox_wieghts.pth'
        print('Loading base network...')
        net.load_state_dict(torch.load(pretrained_weights))
    else:
        vgg_weights = torch.load(args.data_root + 'train_data/' + args.basenet)
        print('Loading base network...')
        net.module.vgg.load_state_dict(vgg_weights)

    if args.cuda:
        net = net.cuda()

    # initialize newly added layers' weights with xavier method
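    # weights_init is assumed to be the Xavier initializer defined later in this
    # file (xavier_uniform on Conv2d weights, biases zeroed).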
    if args.Finetune_SSD is False and args.resume is None:
        print('Initializing weights for extra layers and HEADs...')
        net.module.clstm_1.apply(weights_init)
        net.module.clstm_2.apply(weights_init)
        net.module.extras_r.apply(weights_init)
        net.module.loc_r.apply(weights_init)
        net.module.conf_r.apply(weights_init)

        net.module.extras.apply(weights_init)
        net.module.loc.apply(weights_init)
        net.module.conf.apply(weights_init)

    parameter_dict = dict(net.named_parameters(
    ))  # Get parameters of the network as a dictionary keyed by name
    params = []

    # Set a different learning rate for bias parameters and set their weight_decay to 0
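    # Each dict appended to `params` below becomes its own optimizer parameter
    # group; the first 23 VGG modules are frozen entirely (requires_grad = False).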
    for name, param in parameter_dict.items():
        if name.find('vgg') > -1 and int(
                name.split('.')[2]) < 23:  # :and name.find('cell') <= -1
            param.requires_grad = False
            print(name, 'layer parameters will be fixed')
        else:
            if name.find('bias') > -1:
                print(
                    name, 'layer parameters will be trained @ {}'.format(
                        args.lr * 2))
                params += [{
                    'params': [param],
                    'lr': args.lr * 2,
                    'weight_decay': 0
                }]
            else:
                print(name,
                      'layer parameters will be trained @ {}'.format(args.lr))
                params += [{
                    'params': [param],
                    'lr': args.lr,
                    'weight_decay': args.weight_decay
                }]

    optimizer = optim.SGD(params,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
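    # The positional arguments to the loss below are assumed to follow the
    # ssd.pytorch MultiBoxLoss convention: overlap_thresh, prior_for_matching,
    # bkg_label, neg_mining, neg_pos ratio, neg_overlap, encode_target, use_gpu.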
    criterion = RecurrentMultiBoxLoss(args.num_classes, 0.5, True, 0, True, 3,
                                      0.5, False, args.cuda)
    scheduler = None
    # scheduler = LogLR(optimizer, lr_milestones=args.lr_milestones, total_epoch=args.epochs)
    scheduler = MultiStepLR(optimizer,
                            milestones=args.stepvalues,
                            gamma=args.gamma)
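    # MultiStepLR multiplies the learning rate by gamma at each milestone in
    # args.stepvalues (default 70000,90000); these are assumed to be iteration
    # counts, i.e. scheduler.step() is expected to be called once per batch
    # inside train().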
    print('Loading Dataset...')
    num_gpu = len(args.gpus)

    rootpath = args.data_root
    imgtype = args.modality
    imagesDir = rootpath + imgtype + '/'
    split = 1
    splitfile = rootpath + 'splitfiles/trainlist{:02d}.txt'.format(split)
    trainvideos = readsplitfile(splitfile)

    splitfile = rootpath + 'splitfiles/testlist{:02d}.txt'.format(split)
    testvideos = readsplitfile(splitfile)

    ####### val dataset does not need shuffle #######
    val_data_loader = []
    len_test = len(testvideos)
    random.shuffle(testvideos)
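    # The shuffle above only randomizes how the test videos are split across
    # GPUs; each per-GPU DataLoader below runs with shuffle=False.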
    for i in range(num_gpu):
        testvideos_temp = testvideos[int(i * len_test /
                                         num_gpu):int((i + 1) * len_test /
                                                      num_gpu)]
        val_dataset = UCF24Detection(args.data_root,
                                     'test',
                                     BaseTransform(args.ssd_dim, args.means),
                                     AnnotationTransform(),
                                     input_type=args.modality,
                                     full_test=False,
                                     videos=testvideos_temp,
                                     istrain=False)
        val_data_loader.append(
            data.DataLoader(val_dataset,
                            args.batch_size,
                            num_workers=args.num_workers,
                            shuffle=False,
                            collate_fn=detection_collate,
                            pin_memory=True,
                            drop_last=True))

    log_file = open(
        args.save_root + args.snapshot_pref + "_training_" + day + ".log", "w",
        1)
    log_file.write(args.snapshot_pref + '\n')

    for arg in vars(args):
        print(arg, getattr(args, arg))
        log_file.write(str(arg) + ': ' + str(getattr(args, arg)) + '\n')

    log_file.write(str(net))

    torch.cuda.synchronize()
    len_train = len(trainvideos)

    for epoch in range(args.start_epoch, args.epochs):
        ####### shuffle train dataset #######
        random.shuffle(trainvideos)
        train_data_loader = []
        for i in range(num_gpu):
            trainvideos_temp = trainvideos[int(i * len_train /
                                               num_gpu):int((i + 1) *
                                                            len_train /
                                                            num_gpu)]
            train_dataset = UCF24Detection(args.data_root,
                                           'train',
                                           SSDAugmentation(
                                               args.ssd_dim, args.means),
                                           AnnotationTransform(),
                                           input_type=args.modality,
                                           videos=trainvideos_temp,
                                           istrain=True)
            train_data_loader.append(
                data.DataLoader(train_dataset,
                                args.batch_size,
                                num_workers=args.num_workers,
                                shuffle=False,
                                collate_fn=detection_collate,
                                pin_memory=True,
                                drop_last=True))

        print("Train epoch_size: ", len(train_data_loader))
        print('Train SSD on', train_dataset.name)

        ########## train ###########
        train(train_data_loader, net, criterion, optimizer, scheduler, epoch,
              num_gpu)

        print('Saving state, epoch:', epoch)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'best_prec1': best_prec1,
            },
            epoch=epoch)

        #### log lr ###
        # scheduler.step()
        # evaluate on validation set
        if (
                epoch + 1
        ) % args.eval_freq == 0 or epoch == args.epochs - 1 or epoch == 0:  #
            torch.cuda.synchronize()
            tvs = time.perf_counter()
            mAP, ap_all, ap_strs = validate(args,
                                            net,
                                            val_data_loader,
                                            val_dataset,
                                            epoch,
                                            iou_thresh=args.iou_thresh,
                                            num_gpu=num_gpu)
            # remember best prec@1 and save checkpoint
            is_best = mAP > best_prec1
            best_prec1 = max(mAP, best_prec1)
            print('Saving state, epoch:', epoch)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': net.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, epoch)

            for ap_str in ap_strs:
                print(ap_str)
                log_file.write(ap_str + '\n')
            ptr_str = '\nMEANAP:::=>' + str(mAP) + '\n'
            print(ptr_str)
            log_file.write(ptr_str)

            torch.cuda.synchronize()
            t0 = time.perf_counter()
            prt_str = '\nValidation TIME::: {:0.3f}\n\n'.format(t0 - tvs)
            print(prt_str)
            log_file.write(prt_str)

    log_file.close()
def main():
    global args, log_file, best_prec1
    relative_path = '/data4/lilin/my_code'
    parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training')
    parser.add_argument('--version', default='v2', help='conv11_2(v2) or pool6(v1) as last layer')
    parser.add_argument('--basenet', default='vgg16_reducedfc.pth', help='pretrained base model')
    parser.add_argument('--dataset', default='ucf24', help='dataset name')
    parser.add_argument('--ssd_dim', default=300, type=int, help='Input Size for SSD')  # only support 300 now
    parser.add_argument('--modality', default='rgb', type=str,
                        help='Input type; default rgb, options are [rgb, brox, fastOF]')
    parser.add_argument('--jaccard_threshold', default=0.5, type=float, help='Min Jaccard index for matching')
    parser.add_argument('--batch_size', default=32, type=int, help='Batch size for training')
    parser.add_argument('--num_workers', default=0, type=int, help='Number of workers used in dataloading')
    parser.add_argument('--max_iter', default=120000, type=int, help='Number of training iterations')
    parser.add_argument('--man_seed', default=123, type=int, help='manual seed for reproducibility')
    parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model')
    parser.add_argument('--ngpu', default=1, type=int, help='Number of GPUs to use')
    parser.add_argument('--base_lr', default=0.0005, type=float, help='initial learning rate')
    parser.add_argument('--lr', default=0.0005, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD')
    parser.add_argument('--gamma', default=0.2, type=float, help='Gamma update for SGD')
    parser.add_argument('--log_iters', default=True, type=str2bool, help='Print the loss at each iteration')
    parser.add_argument('--visdom', default=False, type=str2bool, help='Use visdom for loss visualization')
    parser.add_argument('--data_root', default=relative_path + '/realtime/', help='Location of the dataset root directory')
    parser.add_argument('--save_root', default= relative_path + '/realtime/saveucf24/',
                        help='Location to save checkpoint models')
    parser.add_argument('--iou_thresh', default=0.5, type=float, help='Evaluation threshold')
    parser.add_argument('--conf_thresh', default=0.01, type=float, help='Confidence threshold for evaluation')
    parser.add_argument('--nms_thresh', default=0.45, type=float, help='NMS threshold')
    parser.add_argument('--topk', default=50, type=int, help='topk for evaluation')
    parser.add_argument('--clip_gradient', default=40, type=float, help='gradient clipping threshold')
    parser.add_argument('--resume', default=None,type=str, help='Resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
    parser.add_argument('--epochs', default=35, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--eval_freq', default=2, type=int, metavar='N', help='evaluation frequency (default: 2)')
    parser.add_argument('--snapshot_pref', type=str, default="ucf101_vgg16_ssd300_end2end")
    parser.add_argument('--lr_milestones', default=[-2, -5], type=float, help='learning rate milestones')
    parser.add_argument('--arch', type=str, default="VGG16")
    parser.add_argument('--Finetune_SSD', default=False, type=str2bool)
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument(
        '--step',
        type=int,
        default=[18, 27],
        nargs='+',
        help='the epoch where optimizer reduce the learning rate')
    parser.add_argument('--log_lr', default=False, type=str2bool, help='log the learning rate during training')
    parser.add_argument(
        '--print-log',
        type=str2bool,
        default=True,
        help='print logging or not')
    parser.add_argument(
        '--end2end',
        type=str2bool,
        default=False,
        help='train the network end-to-end (do not freeze early VGG layers)')

    ## Parse arguments
    args = parser.parse_args()

    print(__file__)

    print_log(args, this_file_name)
    ## set random seeds
    np.random.seed(args.man_seed)
    torch.manual_seed(args.man_seed)
    if args.cuda:
        torch.cuda.manual_seed_all(args.man_seed)

    if args.cuda and torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    args.cfg = v2
    args.train_sets = 'train'
    args.means = (104, 117, 123)
    num_classes = len(CLASSES) + 1
    args.num_classes = num_classes
    # args.step = [int(val) for val in args.step.split(',')]
    args.loss_reset_step = 30
    args.eval_step = 10000
    args.print_step = 10
    args.data_root += args.dataset + '/'

    ## Define the experiment name; it is used to name the checkpoint and log files
    args.snapshot_pref = ('ucf101_CONV-SSD-{}-{}-bs-{}-{}-lr-{:05d}').format(args.dataset,
                args.modality, args.batch_size, args.basenet[:-14], int(args.lr*100000)) # + '_' + file_name + '_' + day
    print_log(args, args.snapshot_pref)

    if not os.path.isdir(args.save_root):
        os.makedirs(args.save_root)

    net = build_ssd(300, args.num_classes)

    if args.Finetune_SSD is True:
        print_log(args, "load snapshot")
        pretrained_weights = "/home2/lin_li/zjg_code/realtime/ucf24/rgb-ssd300_ucf24_120000.pth"
        pretrained_dict = torch.load(pretrained_weights)
        model_dict = net.state_dict()
        # 1. filter out keys that are not present in the current model
        pretrained_dict_2 = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # pretrained_dict_2['vgg.25.bias'] = pretrained_dict['vgg.24.bias']
        # pretrained_dict_2['vgg.25.weight'] = pretrained_dict['vgg.24.weight']
        # pretrained_dict_2['vgg.27.bias'] = pretrained_dict['vgg.26.bias']
        # pretrained_dict_2['vgg.27.weight'] = pretrained_dict['vgg.26.weight']
        # pretrained_dict_2['vgg.29.bias'] = pretrained_dict['vgg.28.bias']
        # pretrained_dict_2['vgg.29.weight'] = pretrained_dict['vgg.28.weight']
        # pretrained_dict_2['vgg.32.bias'] = pretrained_dict['vgg.31.bias']
        # pretrained_dict_2['vgg.32.weight'] = pretrained_dict['vgg.31.weight']
        # pretrained_dict_2['vgg.34.bias'] = pretrained_dict['vgg.33.bias']
        # pretrained_dict_2['vgg.34.weight'] = pretrained_dict['vgg.33.weight']
        # 2. overwrite matching entries in the existing state dict
        model_dict.update(pretrained_dict_2)
        # 3. load the merged state dict into the network
        net.load_state_dict(model_dict)
    elif args.resume is not None:
        if os.path.isfile(args.resume):
            print_log(args, ("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            if args.end2end is False:
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            net.load_state_dict(checkpoint['state_dict'])
            print_log(args, ("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.evaluate, checkpoint['epoch'])))
        else:
            print_log(args, ("=> no checkpoint found at '{}'".format(args.resume)))

    elif args.modality == 'fastOF':
        print_log(args, 'Download the pretrained Brox-flow model weights and place them at:::=> ' + args.data_root + 'train_data/brox_wieghts.pth')
        pretrained_weights = args.data_root + 'train_data/brox_wieghts.pth'
        print_log(args, 'Loading base network...')
        net.load_state_dict(torch.load(pretrained_weights))
    else:
        vgg_weights = torch.load(args.data_root +'train_data/' + args.basenet)
        print_log(args, 'Loading base network...')
        net.vgg.load_state_dict(vgg_weights)

    if args.cuda:
        net = net.cuda()

    def xavier(param):
        init.xavier_uniform(param)

    def weights_init(m):
        if isinstance(m, nn.Conv2d):
            xavier(m.weight.data)
            m.bias.data.zero_()
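    # Note: init.xavier_uniform is the pre-0.4 PyTorch name; newer releases
    # rename it to init.xavier_uniform_ (in-place).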

    print_log(args, 'Initializing weights for extra layers and HEADs...')
    # initialize newly added layers' weights with xavier method
    if args.Finetune_SSD is False and args.resume is None:
        print_log(args, "init layers")
        net.clstm.apply(weights_init)
        net.extras.apply(weights_init)
        net.loc.apply(weights_init)
        net.conf.apply(weights_init)

    parameter_dict = dict(net.named_parameters()) # Get parameters of the network as a dictionary keyed by name
    params = []

    # Set a different learning rate for bias parameters and set their weight_decay to 0
    for name, param in parameter_dict.items():
        # if args.end2end is False and name.find('vgg') > -1 and int(name.split('.')[1]) < 23:# :and name.find('cell') <= -1
        #     param.requires_grad = False
        #     print_log(args, name + 'layer parameters will be fixed')
        # else:
        if name.find('bias') > -1:
            print_log(args, name + 'layer parameters will be trained @ {}'.format(args.lr*2))
            params += [{'params': [param], 'lr': args.lr*2, 'weight_decay': 0}]
        else:
            print_log(args, name + 'layer parameters will be trained @ {}'.format(args.lr))
            params += [{'params':[param], 'lr': args.lr, 'weight_decay':args.weight_decay}]

    optimizer = optim.SGD(params, lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)
    criterion = MultiBoxLoss(args.num_classes, 0.5, True, 0, True, 3, 0.5, False, args.cuda)

    scheduler = None
    # scheduler = MultiStepLR(optimizer, milestones=args.step, gamma=args.gamma)
    rootpath = args.data_root
    split = 1
    splitfile = rootpath + 'splitfiles/trainlist{:02d}.txt'.format(split)
    trainvideos = readsplitfile(splitfile)

    splitfile = rootpath + 'splitfiles/testlist{:02d}.txt'.format(split)
    testvideos = readsplitfile(splitfile)


    print_log(args, 'Loading Dataset...')
    # train_dataset = UCF24Detection(args.data_root, args.train_sets, SSDAugmentation(args.ssd_dim, args.means),
    #                                AnnotationTransform(), input_type=args.modality)
    # val_dataset = UCF24Detection(args.data_root, 'test', BaseTransform(args.ssd_dim, args.means),
    #                              AnnotationTransform(), input_type=args.modality,
    #                              full_test=False)

    # train_data_loader = data.DataLoader(train_dataset, args.batch_size, num_workers=args.num_workers,
    #                               shuffle=False, collate_fn=detection_collate, pin_memory=True)
    # val_data_loader = data.DataLoader(val_dataset, args.batch_size, num_workers=args.num_workers,
    #                              shuffle=False, collate_fn=detection_collate, pin_memory=True)

    len_test = len(testvideos)
    random.shuffle(testvideos)
    testvideos_temp = testvideos
    val_dataset = UCF24Detection(args.data_root, 'test', BaseTransform(args.ssd_dim, args.means),
                                 AnnotationTransform(), input_type=args.modality,
                                 full_test=False,
                                 videos=testvideos_temp,
                                 istrain=False)
    val_data_loader = data.DataLoader(val_dataset, args.batch_size, num_workers=args.num_workers,
                                           shuffle=False, collate_fn=detection_collate, pin_memory=True,
                                           drop_last=True)


    # print_log(args, "train epoch_size: " + str(len(train_data_loader)))
    # print_log(args, 'Training SSD on' + train_dataset.name)

    print_log(args, args.snapshot_pref)
    for arg in vars(args):
        print(arg, getattr(args, arg))
        print_log(args, str(arg)+': '+str(getattr(args, arg)))

    print_log(args, str(net))
    len_train = len(trainvideos)
    torch.cuda.synchronize()
    for epoch in range(args.start_epoch, args.epochs):

        random.shuffle(trainvideos)
        trainvideos_temp = trainvideos
        train_dataset = UCF24Detection(args.data_root, 'train', SSDAugmentation(args.ssd_dim, args.means),
                                       AnnotationTransform(),
                                       input_type=args.modality,
                                       videos=trainvideos_temp,
                                       istrain=True)
        train_data_loader = data.DataLoader(train_dataset, args.batch_size, num_workers=args.num_workers,
                                                 shuffle=False, collate_fn=detection_collate, pin_memory=True, drop_last=True)

        train(train_data_loader, net, criterion, optimizer, epoch, scheduler)
        print_log(args, 'Saving state, epoch:' + str(epoch))

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': net.state_dict(),
            'best_prec1': best_prec1,
        }, epoch = epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            torch.cuda.synchronize()
            tvs = time.perf_counter()
            mAP, ap_all, ap_strs = validate(args, net, val_data_loader, val_dataset, epoch, iou_thresh=args.iou_thresh)
            # remember best prec@1 and save checkpoint
            is_best = mAP > best_prec1
            best_prec1 = max(mAP, best_prec1)
            print_log(args, 'Saving state, epoch:' +str(epoch))
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'best_prec1': best_prec1,
            }, is_best,epoch)

            for ap_str in ap_strs:
                print(ap_str)
                print_log(args, ap_str)
            ptr_str = '\nMEANAP:::=>'+str(mAP)
            print(ptr_str)
            # log_file.write()
            print_log(args, ptr_str)

            torch.cuda.synchronize()
            t0 = time.perf_counter()
            prt_str = '\nValidation TIME::: {:0.3f}\n\n'.format(t0-tvs)
            print(prt_str)
            # log_file.write(ptr_str)
            print_log(args, prt_str)
def train(args, net, optimizer, criterion, scheduler):
    log_file = open(args.save_root+"training.log", "w", 1)
    log_file.write(args.exp_name+'\n')
    for arg in vars(args):
        print(arg, getattr(args, arg))
        log_file.write(str(arg)+': '+str(getattr(args, arg))+'\n')
    log_file.write(str(net))
    net.train()

    # loss counters
    batch_time = AverageMeter()
    losses = AverageMeter()
    loc_losses = AverageMeter()
    cls_losses = AverageMeter()

    print('Loading Dataset...')
    # train_dataset = UCF24Detection(args.data_root, args.train_sets, SSDAugmentation(args.ssd_dim, args.means),
    #                                AnnotationTransform(), input_type=args.input_type)
    # val_dataset = UCF24Detection(args.data_root, 'test', BaseTransform(args.ssd_dim, args.means),
    #                              AnnotationTransform(), input_type=args.input_type,
    #                              full_test=False)
    # epoch_size = len(train_dataset) // args.batch_size
    # print('Training SSD on', train_dataset.name)

    if args.visdom:

        import visdom
        viz = visdom.Visdom()
        viz.port = args.vis_port
        viz.env = args.exp_name
        # initialize visdom loss plot
        lot = viz.line(
            X=torch.zeros((1,)).cpu(),
            Y=torch.zeros((1, 6)).cpu(),
            opts=dict(
                xlabel='Iteration',
                ylabel='Loss',
                title='Current SSD Training Loss',
                legend=['REG', 'CLS', 'AVG', 'S-REG', ' S-CLS', ' S-AVG']
            )
        )
        # initialize visdom meanAP and class APs plot
        legends = ['meanAP']
        for cls in CLASSES:
            legends.append(cls)
        val_lot = viz.line(
            X=torch.zeros((1,)).cpu(),
            Y=torch.zeros((1,args.num_classes)).cpu(),
            opts=dict(
                xlabel='Iteration',
                ylabel='Mean AP',
                title='Current SSD Validation mean AP',
                legend=legends
            )
        )


    batch_iterator = None
    rootpath = args.data_root
    imgtype = args.modality
    imagesDir = rootpath + imgtype + '/'
    split = 1
    splitfile = rootpath + 'splitfiles/trainlist{:02d}.txt'.format(split)
    trainvideos = readsplitfile(splitfile)

    splitfile = rootpath + 'splitfiles/testlist{:02d}.txt'.format(split)
    testvideos = readsplitfile(splitfile)

    # train_data_loader = data.DataLoader(train_dataset, args.batch_size, num_workers=args.num_workers,
    #                               shuffle=False, collate_fn=detection_collate, pin_memory=True)
    # val_data_loader = data.DataLoader(val_dataset, args.batch_size, num_workers=args.num_workers,
    #                              shuffle=False, collate_fn=detection_collate, pin_memory=True)
    val_data_loader = []
    len_test = len(testvideos)
    random.shuffle(testvideos)
    testvideos_temp = testvideos
    val_dataset = UCF24Detection(args.data_root, 'test', BaseTransform(args.ssd_dim, args.means),
                                 AnnotationTransform(), input_type=args.modality,
                                 full_test=False,
                                 videos=testvideos_temp,
                                 istrain=False)
    val_data_loader = data.DataLoader(val_dataset, args.batch_size, num_workers=args.num_workers,
                                           shuffle=False, collate_fn=detection_collate, pin_memory=True,
                                           drop_last=True)

    itr_count = 0
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    iteration = 0

    # train_shuffle = []
    # ii = len(train_data_loader)
    # ii = 0
    # for iteration in range(len(train_data_loader)):
    #     if not batch_iterator:
    #         batch_iterator = iter(train_data_loader)
    #     # load train data
    #     images, targets, img_indexs = next(batch_iterator)
    #     train_shuffle.append([images, targets, img_indexs])

    len_train = len(trainvideos)
    while iteration <= args.max_iter:
        # for i, (images, targets, img_indexs) in enumerate(train_data_loader):
        random.shuffle(trainvideos)
        trainvideos_temp = trainvideos
        train_dataset = UCF24Detection(args.data_root, 'train', SSDAugmentation(args.ssd_dim, args.means),
                                       AnnotationTransform(),
                                       input_type=args.modality,
                                       videos=trainvideos_temp,
                                       istrain=True)
        train_data_loader = data.DataLoader(train_dataset, args.batch_size, num_workers=args.num_workers,
                                                 shuffle=False, collate_fn=detection_collate, pin_memory=True, drop_last=True)

        for i, item in enumerate(train_data_loader):
            images = item[0]
            targets = item[1]
            img_indexs = item[2]

            if iteration > args.max_iter:
                break
            iteration += 1
            if args.cuda:
                images = Variable(images.cuda())
                targets = [Variable(anno.cuda(), volatile=True) for anno in targets]
            else:
                images = Variable(images)
                targets = [Variable(anno, volatile=True) for anno in targets]
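            # volatile=True is the pre-0.4 way of marking tensors that do not
            # need gradients; newer PyTorch replaces it with torch.no_grad().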
            # forward
            out = net(images, img_indexs)
            # backprop
            optimizer.zero_grad()

            loss_l, loss_c = criterion(out, targets)
            loss = loss_l + loss_c
            loss.backward()
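            # clip_grad_norm returns the total gradient norm measured before
            # clipping; if it exceeds the threshold, the gradients have already
            # been rescaled in place by clip_gradient / total_norm, which is
            # what the message below reports.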
            if args.clip_gradient is not None:
                total_norm = clip_grad_norm(net.parameters(), args.clip_gradient)
                if total_norm > args.clip_gradient:
                    print("clipping gradient: {} with coef {}".format(total_norm, args.clip_gradient / total_norm))

            optimizer.step()
            scheduler.step()
            loc_loss = loss_l.data[0]
            conf_loss = loss_c.data[0]
            # print('Loss data type ',type(loc_loss))
            loc_losses.update(loc_loss)
            cls_losses.update(conf_loss)
            losses.update((loc_loss + conf_loss)/2.0)


            if iteration % args.print_step == 0 and iteration>0:
                if args.visdom:
                    losses_list = [loc_losses.val, cls_losses.val, losses.val, loc_losses.avg, cls_losses.avg, losses.avg]
                    viz.line(X=torch.ones((1, 6)).cpu() * iteration,
                        Y=torch.from_numpy(np.asarray(losses_list)).unsqueeze(0).cpu(),
                        win=lot,
                        update='append')


                torch.cuda.synchronize()
                t1 = time.perf_counter()
                batch_time.update(t1 - t0)

                print_line = 'Iteration {:06d}/{:06d} loc-loss {:.3f}({:.3f}) cls-loss {:.3f}({:.3f}) ' \
                             'average-loss {:.3f}({:.3f}) Timer {:0.3f}({:0.3f})'.format(
                              iteration, args.max_iter, loc_losses.val, loc_losses.avg, cls_losses.val,
                              cls_losses.avg, losses.val, losses.avg, batch_time.val, batch_time.avg)

                torch.cuda.synchronize()
                t0 = time.perf_counter()
                log_file.write(print_line+'\n')
                print(print_line)

                # if args.visdom and args.send_images_to_visdom:
                #     random_batch_index = np.random.randint(images.size(0))
                #     viz.image(images.data[random_batch_index].cpu().numpy())
                itr_count += 1

                if itr_count % args.loss_reset_step == 0 and itr_count > 0:
                    loc_losses.reset()
                    cls_losses.reset()
                    losses.reset()
                    batch_time.reset()
                    print('Reset accumulators of ', args.exp_name,' at', itr_count*args.print_step)
                    itr_count = 0

            if (iteration % args.eval_step == 0 or iteration == 5000) and iteration>0:
                torch.cuda.synchronize()
                tvs = time.perf_counter()
                print('Saving state, iter:', iteration)
                torch.save(net.state_dict(), args.save_root+'ssd300_ucf24_' +
                           repr(iteration) + '.pth')

                net.eval() # switch net to evaluation mode
                mAP, ap_all, ap_strs = validate(args, net, val_data_loader, val_dataset, iteration, iou_thresh=args.iou_thresh)

                for ap_str in ap_strs:
                    print(ap_str)
                    log_file.write(ap_str+'\n')
                ptr_str = '\nMEANAP:::=>'+str(mAP)+'\n'
                print(ptr_str)
                log_file.write(ptr_str)

                if args.visdom:
                    aps = [mAP]
                    for ap in ap_all:
                        aps.append(ap)
                    viz.line(
                        X=torch.ones((1, args.num_classes)).cpu() * iteration,
                        Y=torch.from_numpy(np.asarray(aps)).unsqueeze(0).cpu(),
                        win=val_lot,
                        update='append'
                            )
                net.train() # Switch net back to training mode
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                prt_str = '\nValidation TIME::: {:0.3f}\n\n'.format(t0-tvs)
                print(prt_str)
                log_file.write(prt_str)

    log_file.close()