def main():

    means = (104, 117, 123)  # BGR channel means (VOC-style); only these are supported for now

    exp_name = 'CONV-SSD-{}-{}-bs-{}-lr-{:05d}'.format(args.dataset,
                                                       args.input_type,
                                                       args.batch_size,
                                                       int(args.lr * 100000))
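    # e.g. with dataset='ucf24', input_type='rgb', batch_size=32 and lr=0.0005
    # this expands to 'CONV-SSD-ucf24-rgb-bs-32-lr-00050', since
    # int(0.0005 * 100000) == 50 and '{:05d}' zero-pads it to five digits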

    args.save_root += args.dataset + '/'
    args.data_root += args.dataset + '/'
    args.listid = '01'  ## would be useful for JHMDB-21
    print('Exp name', exp_name, args.listid)
    for iteration in [int(itr) for itr in args.eval_iter.split(',')]:
        log_file = open(
            args.save_root + 'cache/' + exp_name +
            "/testing-{:d}.log".format(iteration), "w", 1)
        log_file.write(exp_name + '\n')
        trained_model_path = args.save_root + 'cache/' + exp_name + '/ssd300_ucf24_' + repr(
            iteration) + '.pth'
        log_file.write(trained_model_path + '\n')
        num_classes = len(CLASSES) + 1  # action classes + 1 for background
        net = build_ssd("train", 300, num_classes)  # initialize SSD
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        vgg_weights = torch.load(trained_model_path)
        for k, v in vgg_weights.items():
            namekey = k[7:]  # remove `module.`
            new_state_dict[namekey] = v
        net.load_state_dict(new_state_dict)
        net = torch.nn.DataParallel(net, device_ids=[0, 1, 2, 3])
        net.eval()
        if args.cuda:
            net = net.cuda()
            cudnn.benchmark = True
        print('Finished loading model %d !' % iteration)
        # Load dataset
        dataset = UCF24Detection(args.data_root,
                                 'test',
                                 BaseTransform(args.ssd_dim, means),
                                 AnnotationTransform(),
                                 input_type=args.input_type,
                                 full_test=True)
        # evaluation
        torch.cuda.synchronize()
        tt0 = time.perf_counter()
        log_file.write('Testing net \n')
        mAP, ap_all, ap_strs = test_net(net, args.save_root, exp_name,
                                        args.input_type, dataset, iteration,
                                        num_classes)
        for ap_str in ap_strs:
            print(ap_str)
            log_file.write(ap_str + '\n')
        ptr_str = '\nMEANAP:::=>' + str(mAP) + '\n'
        print(ptr_str)
        log_file.write(ptr_str)

        torch.cuda.synchronize()
        print('Complete set time {:0.2f}'.format(time.perf_counter() - tt0))
        log_file.close()
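
The 'module.'-stripping loop above is the usual way to load a checkpoint saved from a DataParallel-wrapped model into a bare network. A minimal reusable sketch of the same idea (the helper name is ours, not from this codebase):

from collections import OrderedDict

import torch


def strip_module_prefix(state_dict):
    # DataParallel prefixes every parameter name with 'module.'; drop it
    stripped = OrderedDict()
    for key, value in state_dict.items():
        stripped[key[7:] if key.startswith('module.') else key] = value
    return stripped

# usage sketch: net.load_state_dict(strip_module_prefix(torch.load(path)))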
Example #2
def main():

    means = (104, 117, 123)  # BGR channel means (VOC-style); only these are supported for now

    exp_name = 'CONV-SSD-{}-{}-bs-{}-{}-lr-{:05d}'.format(
        args.dataset, args.input_type, args.batch_size, args.basenet[:-14],
        int(args.lr * 100000))

    args.save_root += args.dataset + '/'
    args.data_root += args.dataset + '/'
    args.listid = '01'  ## would be useful for JHMDB-21
    print('Exp name', exp_name, args.listid)
    # for iteration in [int(itr) for itr in args.eval_iter.split(',')]:
    # log_file = open(args.save_root + 'cache/' + exp_name + "/testing-{:d}.log".format(iteration), "w", 1)
    # log_file.write(exp_name + '\n')
    # trained_model_path = args.save_root + 'cache/' + exp_name + '/ssd300_ucf24_' + repr(iteration) + '.pth'
    trained_model_path = "/data-sdb/data/jiagang.zhu/realtime/ucf24/rgb-ssd300_ucf24_120000.pth"  ###0.6357
    # log_file.write(trained_model_path+'\n')
    num_classes = len(CLASSES) + 1  # action classes + 1 for background
    net = build_ssd(300, num_classes)  # initialize SSD
    net.load_state_dict(torch.load(trained_model_path))
    net.eval()
    if args.cuda:
        net = net.cuda()
        cudnn.benchmark = True
    # print('Finished loading model %d !' % iteration)
    # Load dataset
    dataset = UCF24Detection(args.data_root,
                             'test',
                             BaseTransform(args.ssd_dim, means),
                             AnnotationTransform(),
                             input_type=args.input_type,
                             full_test=False)  # full_test=True gives mAP 0.6357
    # evaluation
    torch.cuda.synchronize()
    tt0 = time.perf_counter()
    # log_file.write('Testing net \n')
    iteration = 100000
    mAP, ap_all, ap_strs = test_net(net, args.save_root, exp_name,
                                    args.input_type, dataset, iteration,
                                    num_classes)
    for ap_str in ap_strs:
        print(ap_str)
        # log_file.write(ap_str + '\n')
    ptr_str = '\nMEANAP:::=>' + str(mAP) + '\n'
    print(ptr_str)
    # log_file.write(ptr_str)

    torch.cuda.synchronize()
    print('Complete set time {:0.2f}'.format(time.perf_counter() - tt0))
Example #3
def main():

    means = (104, 117, 123)  # BGR channel means (VOC-style); only these are supported for now

    exp_name = 'CONV-SSD-{}-{}-bs-{}-{}-lr-{:05d}'.format(
        args.dataset, args.input_type, args.batch_size, args.basenet[:-14],
        int(args.lr * 100000))

    args.save_root += args.dataset + '/'
    args.data_root += args.dataset + '/'
    args.listid = '099'  ## would be useful for JHMDB-21
    print('Exp name', exp_name, args.listid)
    iteration = 0

    trained_model_path = "/data4/lilin/my_code/realtime/ucf24/rgb-ssd300_ucf24_120000.pth"

    num_classes = len(CLASSES) + 1  # action classes + 1 for background
    net = build_ssd(300, num_classes)  # initialize SSD
    net.load_state_dict(torch.load(trained_model_path))
    net.eval()
    if args.cuda:
        net = net.cuda()
        cudnn.benchmark = True
    print('Finished loading model %d !' % iteration)
    # Load dataset
    dataset = UCF24Detection(args.data_root,
                             'test',
                             BaseTransform(args.ssd_dim, means),
                             AnnotationTransform(),
                             input_type=args.input_type,
                             full_test=True)
    # evaluation
    torch.cuda.synchronize()
    tt0 = time.perf_counter()

    ptr_str = '\niou_thresh:::=>' + str(args.iou_thresh) + '\n'
    print(ptr_str)

    mAP, ap_all, ap_strs = te_net(net, args.save_root, exp_name,
                                  args.input_type, dataset, iteration,
                                  num_classes, args.iou_thresh)
    for ap_str in ap_strs:
        print(ap_str)

    ptr_str = '\nMEANAP:::=>' + str(mAP) + '\n'
    print(ptr_str)

    torch.cuda.synchronize()
    print('Complete set time {:0.2f}'.format(time.perf_counter() - tt0))
def action_detection_images(num_classes, means_bgr, li_color_class):

    exp_name = 'CONV-SSD-{}-{}-bs-{}-{}-lr-{:05d}'.format(
        args.dataset, args.input_type, args.batch_size, args.basenet[:-14],
        int(args.lr * 100000))
    print('Exp name', exp_name, args.listid)
    for iteration in [int(itr) for itr in args.eval_iter.split(',')]:
        log_file = open(
            args.save_root + 'cache/' + exp_name +
            "/testing-{:d}.log".format(iteration), "w", 1)
        log_file.write(exp_name + '\n')
        #trained_model_path = args.save_root + 'cache/' + exp_name + '/ssd300_ucf24_' + repr(iteration) + '.pth'
        trained_model_path = args.save_root + 'cache/' + exp_name + '/' + args.input_type + '-ssd300_ucf24_' + repr(
            iteration) + '.pth'
        log_file.write(trained_model_path + '\n')
        net = init_ssd(num_classes, trained_model_path, args.cuda)
        print('Finished loading model %d !' % iteration)
        # Load dataset
        dataset = UCF24Detection(args.data_root,
                                 'test',
                                 BaseTransform(args.ssd_dim, means_bgr),
                                 AnnotationTransform(),
                                 input_type=args.input_type,
                                 full_test=True)
        #print('dataset.CLASSES : ', dataset.CLASSES);   exit()
        # evaluation
        torch.cuda.synchronize()
        tt0 = time.perf_counter()
        log_file.write('Testing net \n')
        #mAP, ap_all, ap_strs = test_net(net, args.save_root, exp_name, args.input_type, dataset, iteration, num_classes)
        mAP, ap_all, ap_strs = test_net(net, args.save_root, exp_name,
                                        args.input_type, dataset, iteration,
                                        li_color_class, means_bgr,
                                        args.n_record, args.iou_thresh)
        for ap_str in ap_strs:
            log_file.write(ap_str + '\n')
        ptr_str = '\nMEANAP:::=>' + str(mAP) + '\n'
        print(ptr_str)
        log_file.write(ptr_str)

        torch.cuda.synchronize()
        print('Complete set time {:0.2f}'.format(time.perf_counter() - tt0))
        log_file.close()
    return
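
init_ssd is defined elsewhere in this codebase. Judging from the first example, it plausibly builds the SSD, loads the checkpoint and moves the net to the GPU; a sketch under those assumptions, not the actual implementation:

def init_ssd(num_classes, trained_model_path, use_cuda):
    # hypothetical reconstruction of init_ssd based on the loading code above
    net = build_ssd(300, num_classes)
    net.load_state_dict(torch.load(trained_model_path))
    net.eval()
    if use_cuda:
        net = net.cuda()
        cudnn.benchmark = True
    return net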
def train(args, net, optimizer, criterion, scheduler):
    log_file = open(args.save_root + "training.log", "w", 1)
    log_file.write(args.exp_name + '\n')
    for arg in vars(args):
        print(arg, getattr(args, arg))
        log_file.write(str(arg) + ': ' + str(getattr(args, arg)) + '\n')
    log_file.write(str(net))
    net.train()

    # loss counters
    batch_time = AverageMeter()
    losses = AverageMeter()
    loc_losses = AverageMeter()
    cls_losses = AverageMeter()

    print('Loading Dataset...')
    train_dataset = UCF24Detection(args.data_root,
                                   args.train_sets,
                                   SSDAugmentation(args.ssd_dim, args.means),
                                   AnnotationTransform(),
                                   input_type=args.input_type)
    val_dataset = UCF24Detection(args.data_root,
                                 'test',
                                 BaseTransform(args.ssd_dim, args.means),
                                 AnnotationTransform(),
                                 input_type=args.input_type,
                                 full_test=False)
    epoch_size = len(train_dataset) // args.batch_size
    print('Training SSD on', train_dataset.name)

    if args.visdom:

        import visdom
        viz = visdom.Visdom()
        viz.port = 8097
        viz.env = args.exp_name
        # initialize visdom loss plot
        lot = viz.line(X=torch.zeros((1, )).cpu(),
                       Y=torch.zeros((1, 6)).cpu(),
                       opts=dict(xlabel='Iteration',
                                 ylabel='Loss',
                                 title='Current SSD Training Loss',
                                 legend=[
                                     'REG', 'CLS', 'AVG', 'S-REG', ' S-CLS',
                                     ' S-AVG'
                                 ]))
        # initialize visdom meanAP and class APs plot
        legends = ['meanAP']
        for cls in CLASSES:
            legends.append(cls)
        val_lot = viz.line(X=torch.zeros((1, )).cpu(),
                           Y=torch.zeros((1, args.num_classes)).cpu(),
                           opts=dict(xlabel='Iteration',
                                     ylabel='Mean AP',
                                     title='Current SSD Validation mean AP',
                                     legend=legends))

    batch_iterator = None
    train_data_loader = data.DataLoader(train_dataset,
                                        args.batch_size,
                                        num_workers=args.num_workers,
                                        shuffle=True,
                                        collate_fn=detection_collate,
                                        pin_memory=True)
    val_data_loader = data.DataLoader(val_dataset,
                                      args.batch_size,
                                      num_workers=args.num_workers,
                                      shuffle=False,
                                      collate_fn=detection_collate,
                                      pin_memory=True)
    itr_count = 0
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for iteration in range(args.max_iter + 1):
        if (not batch_iterator) or (iteration % epoch_size == 0):
            # create batch iterator
            batch_iterator = iter(train_data_loader)

        # load train data
        images, targets, img_indexs = next(batch_iterator)
        if args.cuda:
            images = Variable(images.cuda())
            targets = [
                Variable(anno.cuda(), volatile=True) for anno in targets
            ]
        else:
            images = Variable(images)
            targets = [Variable(anno, volatile=True) for anno in targets]
        # forward
        out = net(images)
        # backprop
        optimizer.zero_grad()

        loss_l, loss_c = criterion(out, targets)
        loss = loss_l + loss_c
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        loc_loss = loss_l.data[0]
        conf_loss = loss_c.data[0]
        # print('Loss data type ',type(loc_loss))
        loc_losses.update(loc_loss)
        cls_losses.update(conf_loss)
        losses.update((loc_loss + conf_loss) / 2.0)

        if iteration % args.print_step == 0 and iteration > 0:
            if args.visdom:
                losses_list = [
                    loc_losses.val, cls_losses.val, losses.val, loc_losses.avg,
                    cls_losses.avg, losses.avg
                ]
                viz.line(X=torch.ones((1, 6)).cpu() * iteration,
                         Y=torch.from_numpy(
                             np.asarray(losses_list)).unsqueeze(0).cpu(),
                         win=lot,
                         update='append')

            torch.cuda.synchronize()
            t1 = time.perf_counter()
            batch_time.update(t1 - t0)

            print_line = 'Iteration {:06d}/{:06d} loc-loss {:.3f}({:.3f}) cls-loss {:.3f}({:.3f}) ' \
                         'average-loss {:.3f}({:.3f}) Timer {:0.3f}({:0.3f})'.format(
                          iteration, args.max_iter, loc_losses.val, loc_losses.avg, cls_losses.val,
                          cls_losses.avg, losses.val, losses.avg, batch_time.val, batch_time.avg)

            torch.cuda.synchronize()
            t0 = time.perf_counter()
            log_file.write(print_line + '\n')
            print(print_line)

            # if args.visdom and args.send_images_to_visdom:
            #     random_batch_index = np.random.randint(images.size(0))
            #     viz.image(images.data[random_batch_index].cpu().numpy())
            itr_count += 1

            if itr_count % args.loss_reset_step == 0 and itr_count > 0:
                loc_losses.reset()
                cls_losses.reset()
                losses.reset()
                batch_time.reset()
                print('Reset accumulators of ', args.exp_name, ' at',
                      itr_count * args.print_step)
                itr_count = 0

        if (iteration % args.eval_step == 0
                or iteration == 5000) and iteration > 0:
            torch.cuda.synchronize()
            tvs = time.perf_counter()
            print('Saving state, iter:', iteration)
            torch.save(
                net.state_dict(),
                args.save_root + 'ssd300_ucf24_' + repr(iteration) + '.pth')

            net.eval()  # switch net to evaluation mode
            mAP, ap_all, ap_strs = validate(args,
                                            net,
                                            val_data_loader,
                                            val_dataset,
                                            iteration,
                                            iou_thresh=args.iou_thresh)

            for ap_str in ap_strs:
                print(ap_str)
                log_file.write(ap_str + '\n')
            ptr_str = '\nMEANAP:::=>' + str(mAP) + '\n'
            print(ptr_str)
            log_file.write(ptr_str)

            if args.visdom:
                aps = [mAP]
                for ap in ap_all:
                    aps.append(ap)
                viz.line(X=torch.ones((1, args.num_classes)).cpu() * iteration,
                         Y=torch.from_numpy(
                             np.asarray(aps)).unsqueeze(0).cpu(),
                         win=val_lot,
                         update='append')
            net.train()  # Switch net back to training mode
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            prt_str = '\nValidation TIME::: {:0.3f}\n\n'.format(t0 - tvs)
            print(prt_str)
            log_file.write(prt_str)

    log_file.close()
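
AverageMeter comes from the repo's utilities; the usage above (.val, .avg, .update(), .reset()) matches the standard running-average tracker, sketched here as an assumption rather than the repo's exact code:

class AverageMeter(object):
    # track the latest value and the running average of a metric
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0.0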
def main():
    global args, log_file, best_prec1
    relative_path = '/data4/lilin/my_code'
    parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training')
    parser.add_argument('--version', default='v2', help='conv11_2(v2) or pool6(v1) as last layer')
    parser.add_argument('--basenet', default='vgg16_reducedfc.pth', help='pretrained base model')
    parser.add_argument('--dataset', default='ucf24', help='dataset name')
    parser.add_argument('--ssd_dim', default=300, type=int, help='Input Size for SSD')  # only support 300 now
    parser.add_argument('--modality', default='rgb', type=str,
                        help='Input type; default rgb, options are [rgb, brox, fastOF]')
    parser.add_argument('--jaccard_threshold', default=0.5, type=float, help='Min Jaccard index for matching')
    parser.add_argument('--batch_size', default=32, type=int, help='Batch size for training')
    parser.add_argument('--num_workers', default=0, type=int, help='Number of workers used in dataloading')
    parser.add_argument('--max_iter', default=120000, type=int, help='Number of training iterations')
    parser.add_argument('--man_seed', default=123, type=int, help='manual seed for reproducibility')
    parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model')
    parser.add_argument('--ngpu', default=1, type=int, help='Number of GPUs to use')
    parser.add_argument('--base_lr', default=0.0005, type=float, help='base learning rate for the optimizer')
    parser.add_argument('--lr', default=0.0005, type=float, help='initial learning rate for the per-layer parameter groups')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD')
    parser.add_argument('--gamma', default=0.2, type=float, help='Gamma update for SGD')
    parser.add_argument('--log_iters', default=True, type=str2bool, help='Print the loss at each iteration')
    parser.add_argument('--visdom', default=False, type=str2bool, help='Use visdom to for loss visualization')
    parser.add_argument('--data_root', default=relative_path + '/realtime/', help='Location of VOC root directory')
    parser.add_argument('--save_root', default=relative_path + '/realtime/saveucf24/',
                        help='Location to save checkpoint models')
    parser.add_argument('--iou_thresh', default=0.5, type=float, help='Evaluation threshold')
    parser.add_argument('--conf_thresh', default=0.01, type=float, help='Confidence threshold for evaluation')
    parser.add_argument('--nms_thresh', default=0.45, type=float, help='NMS threshold')
    parser.add_argument('--topk', default=50, type=int, help='topk for evaluation')
    parser.add_argument('--clip_gradient', default=40, type=float, help='gradients clip')
    parser.add_argument('--resume', default=None, type=str, help='Resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
    parser.add_argument('--epochs', default=35, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--eval_freq', default=2, type=int, metavar='N', help='evaluation frequency in epochs (default: 2)')
    parser.add_argument('--snapshot_pref', type=str, default="ucf101_vgg16_ssd300_end2end")
    parser.add_argument('--lr_milestones', default=[-2, -5], type=float, help='milestones for the LogLR schedule')
    parser.add_argument('--arch', type=str, default="VGG16")
    parser.add_argument('--Finetune_SSD', default=False, type=str2bool)
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument(
        '--step',
        type=int,
        default=[18, 27],
        nargs='+',
        help='the epoch where optimizer reduce the learning rate')
    parser.add_argument('--log_lr', default=False, type=str2bool, help='log the learning rate or not')
    parser.add_argument(
        '--print-log',
        type=str2bool,
        default=True,
        help='print logging or not')
    parser.add_argument(
        '--end2end',
        type=str2bool,
        default=False,
        help='train end-to-end or not')

    ## Parse arguments
    args = parser.parse_args()

    print(__file__)

    print_log(args, this_file_name)
    ## set random seeds
    np.random.seed(args.man_seed)
    torch.manual_seed(args.man_seed)
    if args.cuda:
        torch.cuda.manual_seed_all(args.man_seed)

    if args.cuda and torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    args.cfg = v2
    args.train_sets = 'train'
    args.means = (104, 117, 123)
    num_classes = len(CLASSES) + 1
    args.num_classes = num_classes
    # args.step = [int(val) for val in args.step.split(',')]
    args.loss_reset_step = 30
    args.eval_step = 10000
    args.print_step = 10
    args.data_root += args.dataset + '/'

    ## Define the experiment name; it will be used to name the save directory
    args.snapshot_pref = ('ucf101_CONV-SSD-{}-{}-bs-{}-{}-lr-{:05d}').format(args.dataset,
                args.modality, args.batch_size, args.basenet[:-14], int(args.lr*100000)) # + '_' + file_name + '_' + day
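    # e.g. 'ucf101_CONV-SSD-ucf24-rgb-bs-32-vgg16-lr-00050' with the defaults
    # above; args.basenet[:-14] trims 'vgg16_reducedfc.pth' down to 'vgg16'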
    print_log(args, args.snapshot_pref)

    if not os.path.isdir(args.save_root):
        os.makedirs(args.save_root)

    net = build_ssd(300, args.num_classes)

    if args.Finetune_SSD is True:
        print_log(args, "load snapshot")
        pretrained_weights = "/home2/lin_li/zjg_code/realtime/ucf24/rgb-ssd300_ucf24_120000.pth"
        pretrained_dict = torch.load(pretrained_weights)
        model_dict = net.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict_2 = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # pretrained_dict_2['vgg.25.bias'] = pretrained_dict['vgg.24.bias']
        # pretrained_dict_2['vgg.25.weight'] = pretrained_dict['vgg.24.weight']
        # pretrained_dict_2['vgg.27.bias'] = pretrained_dict['vgg.26.bias']
        # pretrained_dict_2['vgg.27.weight'] = pretrained_dict['vgg.26.weight']
        # pretrained_dict_2['vgg.29.bias'] = pretrained_dict['vgg.28.bias']
        # pretrained_dict_2['vgg.29.weight'] = pretrained_dict['vgg.28.weight']
        # pretrained_dict_2['vgg.32.bias'] = pretrained_dict['vgg.31.bias']
        # pretrained_dict_2['vgg.32.weight'] = pretrained_dict['vgg.31.weight']
        # pretrained_dict_2['vgg.34.bias'] = pretrained_dict['vgg.33.bias']
        # pretrained_dict_2['vgg.34.weight'] = pretrained_dict['vgg.33.weight']
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict_2)
        # 3. load the new state dict
        net.load_state_dict(model_dict)
    elif args.resume is not None:
        if os.path.isfile(args.resume):
            print_log(args, ("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            if args.end2end is False:
                args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            net.load_state_dict(checkpoint['state_dict'])
            print_log(args, ("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.evaluate, checkpoint['epoch'])))
        else:
            print_log(args, ("=> no checkpoint found at '{}'".format(args.resume)))

    elif args.modality == 'fastOF':
        print_log(args, 'Download pretrained Brox-flow model weights and place them at:::=> ' + args.data_root + 'train_data/brox_wieghts.pth')
        pretrained_weights = args.data_root + 'train_data/brox_wieghts.pth'
        print_log(args, 'Loading base network...')
        net.load_state_dict(torch.load(pretrained_weights))
    else:
        vgg_weights = torch.load(args.data_root +'train_data/' + args.basenet)
        print_log(args, 'Loading base network...')
        net.vgg.load_state_dict(vgg_weights)

    if args.cuda:
        net = net.cuda()

    def xavier(param):
        init.xavier_uniform(param)

    def weights_init(m):
        if isinstance(m, nn.Conv2d):
            xavier(m.weight.data)
            m.bias.data.zero_()

    print_log(args, 'Initializing weights for extra layers and HEADs...')
    # initialize newly added layers' weights with xavier method
    if args.Finetune_SSD is False and args.resume is None:
        print_log(args, "init layers")
        net.clstm.apply(weights_init)
        net.extras.apply(weights_init)
        net.loc.apply(weights_init)
        net.conf.apply(weights_init)

    parameter_dict = dict(net.named_parameters())  # get the network parameters as a dict keyed by name
    params = []

    # Set a different learning rate for bias parameters and zero their weight_decay
    for name, param in parameter_dict.items():
        # if args.end2end is False and name.find('vgg') > -1 and int(name.split('.')[1]) < 23:# :and name.find('cell') <= -1
        #     param.requires_grad = False
        #     print_log(args, name + 'layer parameters will be fixed')
        # else:
        if name.find('bias') > -1:
            print_log(args, name + ' layer parameters will be trained @ {}'.format(args.lr*2))
            params += [{'params': [param], 'lr': args.lr*2, 'weight_decay': 0}]
        else:
            print_log(args, name + ' layer parameters will be trained @ {}'.format(args.lr))
            params += [{'params': [param], 'lr': args.lr, 'weight_decay': args.weight_decay}]

    optimizer = optim.SGD(params, lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)
    criterion = MultiBoxLoss(args.num_classes, 0.5, True, 0, True, 3, 0.5, False, args.cuda)
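    # The positional args above follow the ssd.pytorch MultiBoxLoss convention
    # (assumed, not verified against this repo): num_classes, overlap_thresh,
    # prior_for_matching, bkg_label, neg_mining, neg_pos_ratio, neg_overlap,
    # encode_target, use_gpu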

    scheduler = None
    # scheduler = MultiStepLR(optimizer, milestones=args.step, gamma=args.gamma)
    rootpath = args.data_root
    split = 1
    splitfile = rootpath + 'splitfiles/trainlist{:02d}.txt'.format(split)
    trainvideos = readsplitfile(splitfile)

    splitfile = rootpath + 'splitfiles/testlist{:02d}.txt'.format(split)
    testvideos = readsplitfile(splitfile)


    print_log(args, 'Loading Dataset...')
    # train_dataset = UCF24Detection(args.data_root, args.train_sets, SSDAugmentation(args.ssd_dim, args.means),
    #                                AnnotationTransform(), input_type=args.modality)
    # val_dataset = UCF24Detection(args.data_root, 'test', BaseTransform(args.ssd_dim, args.means),
    #                              AnnotationTransform(), input_type=args.modality,
    #                              full_test=False)

    # train_data_loader = data.DataLoader(train_dataset, args.batch_size, num_workers=args.num_workers,
    #                               shuffle=False, collate_fn=detection_collate, pin_memory=True)
    # val_data_loader = data.DataLoader(val_dataset, args.batch_size, num_workers=args.num_workers,
    #                              shuffle=False, collate_fn=detection_collate, pin_memory=True)

    len_test = len(testvideos)
    random.shuffle(testvideos)
    testvideos_temp = testvideos
    val_dataset = UCF24Detection(args.data_root, 'test', BaseTransform(args.ssd_dim, args.means),
                                 AnnotationTransform(), input_type=args.modality,
                                 full_test=False,
                                 videos=testvideos_temp,
                                 istrain=False)
    val_data_loader = data.DataLoader(val_dataset, args.batch_size, num_workers=args.num_workers,
                                           shuffle=False, collate_fn=detection_collate, pin_memory=True,
                                           drop_last=True)


    # print_log(args, "train epoch_size: " + str(len(train_data_loader)))
    # print_log(args, 'Training SSD on' + train_dataset.name)

    print_log(args, args.snapshot_pref)
    for arg in vars(args):
        print(arg, getattr(args, arg))
        print_log(args, str(arg)+': '+str(getattr(args, arg)))

    print_log(args, str(net))
    len_train = len(trainvideos)
    torch.cuda.synchronize()
    for epoch in range(args.start_epoch, args.epochs):

        random.shuffle(trainvideos)
        trainvideos_temp = trainvideos
        train_dataset = UCF24Detection(args.data_root, 'train', SSDAugmentation(args.ssd_dim, args.means),
                                       AnnotationTransform(),
                                       input_type=args.modality,
                                       videos=trainvideos_temp,
                                       istrain=True)
        train_data_loader = data.DataLoader(train_dataset, args.batch_size, num_workers=args.num_workers,
                                                 shuffle=False, collate_fn=detection_collate, pin_memory=True, drop_last=True)

        train(train_data_loader, net, criterion, optimizer, epoch, scheduler)
        print_log(args, 'Saving state, epoch:' + str(epoch))

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': net.state_dict(),
            'best_prec1': best_prec1,
        }, epoch = epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            torch.cuda.synchronize()
            tvs = time.perf_counter()
            mAP, ap_all, ap_strs = validate(args, net, val_data_loader, val_dataset, epoch, iou_thresh=args.iou_thresh)
            # remember best prec@1 and save checkpoint
            is_best = mAP > best_prec1
            best_prec1 = max(mAP, best_prec1)
            print_log(args, 'Saving state, epoch:' +str(epoch))
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'best_prec1': best_prec1,
            }, is_best,epoch)

            for ap_str in ap_strs:
                print(ap_str)
                print_log(args, ap_str)
            ptr_str = '\nMEANAP:::=>'+str(mAP)
            print(ptr_str)
            # log_file.write()
            print_log(args, ptr_str)

            torch.cuda.synchronize()
            t0 = time.perf_counter()
            prt_str = '\nValidation TIME::: {:0.3f}\n\n'.format(t0-tvs)
            print(prt_str)
            print_log(args, prt_str)
def main():
    global my_dict, keys, k_len, arr, xxx, args, log_file, best_prec1

    parser = argparse.ArgumentParser(
        description='Single Shot MultiBox Detector Training')
    parser.add_argument('--version',
                        default='v2',
                        help='conv11_2(v2) or pool6(v1) as last layer')
    parser.add_argument('--basenet',
                        default='vgg16_reducedfc.pth',
                        help='pretrained base model')
    parser.add_argument('--dataset',
                        default='ucf24',
                        help='dataset name')
    parser.add_argument('--ssd_dim',
                        default=300,
                        type=int,
                        help='Input Size for SSD')  # only support 300 now
    parser.add_argument(
        '--modality',
        default='rgb',
        type=str,
        help='Input type; default rgb, options are [rgb, brox, fastOF]')
    parser.add_argument('--jaccard_threshold',
                        default=0.5,
                        type=float,
                        help='Min Jaccard index for matching')
    parser.add_argument('--batch_size',
                        default=40,
                        type=int,
                        help='Batch size for training')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='Number of workers used in dataloading')
    parser.add_argument('--max_iter',
                        default=120000,
                        type=int,
                        help='Number of training iterations')
    parser.add_argument('--man_seed',
                        default=123,
                        type=int,
                        help='manual seed for reproducibility')
    parser.add_argument('--cuda',
                        default=True,
                        type=str2bool,
                        help='Use cuda to train model')
    parser.add_argument('--ngpu',
                        default=1,
                        type=int,
                        help='Number of GPUs to use')
    parser.add_argument('--lr',
                        '--learning-rate',
                        default=0.0005,
                        type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--stepvalues',
                        default='70000,90000',
                        type=str,
                        help='iter number when learning rate to be dropped')
    parser.add_argument('--weight_decay',
                        default=5e-4,
                        type=float,
                        help='Weight decay for SGD')
    parser.add_argument('--gamma',
                        default=0.2,
                        type=float,
                        help='Gamma update for SGD')
    parser.add_argument('--log_iters',
                        default=True,
                        type=str2bool,
                        help='Print the loss at each iteration')
    parser.add_argument('--visdom',
                        default=False,
                        type=str2bool,
                        help='Use visdom to for loss visualization')
    parser.add_argument('--data_root',
                        default=relative_path + 'realtime/',
                        help='Location of VOC root directory')
    parser.add_argument('--save_root',
                        default=relative_path + 'realtime/saveucf24/',
                        help='Location to save checkpoint models')

    parser.add_argument('--iou_thresh',
                        default=0.5,
                        type=float,
                        help='Evaluation threshold')
    parser.add_argument('--conf_thresh',
                        default=0.01,
                        type=float,
                        help='Confidence threshold for evaluation')
    parser.add_argument('--nms_thresh',
                        default=0.45,
                        type=float,
                        help='NMS threshold')
    parser.add_argument('--topk',
                        default=50,
                        type=int,
                        help='topk for evaluation')
    parser.add_argument('--clip_gradient',
                        default=40,
                        type=float,
                        help='gradients clip')
    parser.add_argument('--resume',
                        default=None,
                        type=str,
                        help='Resume from checkpoint')
    parser.add_argument('--start_epoch',
                        default=0,
                        type=int,
                        help='start epoch')
    parser.add_argument('--epochs',
                        default=35,
                        type=int,
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--eval_freq',
                        default=2,
                        type=int,
                        metavar='N',
                        help='evaluation frequency in epochs (default: 2)')
    parser.add_argument('--snapshot_pref',
                        type=str,
                        default="ucf101_vgg16_ssd300_")
    parser.add_argument('--lr_milestones',
                        default=[-2, -5],
                        type=float,
                        help='milestones for the LogLR schedule')
    parser.add_argument('--arch', type=str, default="VGG16")
    parser.add_argument('--Finetune_SSD', default=False, type=str2bool)
    parser.add_argument('-e',
                        '--evaluate',
                        dest='evaluate',
                        action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--gpus', nargs='+', type=int, default=[0, 1, 2, 3])

    print(__file__)
    file_name = (__file__).split('/')[-1]
    file_name = file_name.split('.')[0]
    print(file_name)
    ## Parse arguments
    args = parser.parse_args()
    ## set random seeds
    np.random.seed(args.man_seed)
    torch.manual_seed(args.man_seed)
    if args.cuda:
        torch.cuda.manual_seed_all(args.man_seed)

    if args.cuda and torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    args.cfg = v2
    args.train_sets = 'train'
    args.means = (104, 117, 123)
    num_classes = len(CLASSES) + 1
    args.num_classes = num_classes
    args.stepvalues = [int(val) for val in args.stepvalues.split(',')]
    args.loss_reset_step = 30
    args.eval_step = 10000
    args.print_step = 10
    args.data_root += args.dataset + '/'

    ## Define the experiment name; it will be used to name the save directory
    day = (time.strftime('%m-%d', time.localtime(time.time())))
    args.snapshot_pref = ('ucf101_CONV-SSD-{}-{}-bs-{}-{}-lr-{:05d}').format(
        args.dataset, args.modality, args.batch_size, args.basenet[:-14],
        int(args.lr * 100000)) + '_' + file_name + '_' + day
    print(args.snapshot_pref)

    if not os.path.isdir(args.save_root):
        os.makedirs(args.save_root)

    net = build_refine_ssd(300, args.num_classes)
    net = torch.nn.DataParallel(net, device_ids=args.gpus)

    if args.Finetune_SSD is True:
        print("load snapshot")
        pretrained_weights = "/data4/lilin/my_code/realtime/ucf24/rgb-ssd300_ucf24_120000.pth"
        pretrained_dict = torch.load(pretrained_weights)
        model_dict = net.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict_2 = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict_2)
        # 3. load the new state dict
        net.load_state_dict(model_dict)
    elif args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            net.load_state_dict(checkpoint['state_dict'])
            print(("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    elif args.modality == 'fastOF':
        print(
            'Download pretrained Brox-flow model weights and place them at:::=> ',
            args.data_root + 'train_data/brox_wieghts.pth')
        pretrained_weights = args.data_root + 'train_data/brox_wieghts.pth'
        print('Loading base network...')
        net.load_state_dict(torch.load(pretrained_weights))
    else:
        vgg_weights = torch.load(args.data_root + 'train_data/' + args.basenet)
        print('Loading base network...')
        net.module.vgg.load_state_dict(vgg_weights)

    if args.cuda:
        net = net.cuda()

    # initialize newly added layers' weights with xavier method
    if args.Finetune_SSD is False and args.resume is None:
        print('Initializing weights for extra layers and HEADs...')
        net.module.clstm_1.apply(weights_init)
        net.module.clstm_2.apply(weights_init)
        net.module.extras_r.apply(weights_init)
        net.module.loc_r.apply(weights_init)
        net.module.conf_r.apply(weights_init)

        net.module.extras.apply(weights_init)
        net.module.loc.apply(weights_init)
        net.module.conf.apply(weights_init)

    parameter_dict = dict(net.named_parameters())  # get the network parameters as a dict keyed by name
    params = []

    # Set a different learning rate for bias parameters and zero their weight_decay
    for name, param in parameter_dict.items():
        if name.find('vgg') > -1 and int(
                name.split('.')[2]) < 23:  # :and name.find('cell') <= -1
            param.requires_grad = False
            print(name, 'layer parameters will be fixed')
        else:
            if name.find('bias') > -1:
                print(
                    name, 'layer parameters will be trained @ {}'.format(
                        args.lr * 2))
                params += [{
                    'params': [param],
                    'lr': args.lr * 2,
                    'weight_decay': 0
                }]
            else:
                print(name,
                      'layer parameters will be trained @ {}'.format(args.lr))
                params += [{
                    'params': [param],
                    'lr': args.lr,
                    'weight_decay': args.weight_decay
                }]

    optimizer = optim.SGD(params,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    criterion = RecurrentMultiBoxLoss(args.num_classes, 0.5, True, 0, True, 3,
                                      0.5, False, args.cuda)
    scheduler = None
    # scheduler = LogLR(optimizer, lr_milestones=args.lr_milestones, total_epoch=args.epochs)
    scheduler = MultiStepLR(optimizer,
                            milestones=args.stepvalues,
                            gamma=args.gamma)
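    # With the default stepvalues '70000,90000' and gamma=0.2, the learning
    # rate decays 0.0005 -> 1e-4 -> 2e-5 at those steps (assuming one
    # scheduler.step() per iteration)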
    print('Loading Dataset...')
    num_gpu = len(args.gpus)

    rootpath = args.data_root
    imgtype = args.modality
    imagesDir = rootpath + imgtype + '/'
    split = 1
    splitfile = rootpath + 'splitfiles/trainlist{:02d}.txt'.format(split)
    trainvideos = readsplitfile(splitfile)

    splitfile = rootpath + 'splitfiles/testlist{:02d}.txt'.format(split)
    testvideos = readsplitfile(splitfile)

    ####### val loaders are not shuffled; testvideos are shuffled once to balance the per-GPU splits #######
    val_data_loader = []
    len_test = len(testvideos)
    random.shuffle(testvideos)
    for i in range(num_gpu):
        testvideos_temp = testvideos[int(i * len_test /
                                         num_gpu):int((i + 1) * len_test /
                                                      num_gpu)]
        val_dataset = UCF24Detection(args.data_root,
                                     'test',
                                     BaseTransform(args.ssd_dim, args.means),
                                     AnnotationTransform(),
                                     input_type=args.modality,
                                     full_test=False,
                                     videos=testvideos_temp,
                                     istrain=False)
        val_data_loader.append(
            data.DataLoader(val_dataset,
                            args.batch_size,
                            num_workers=args.num_workers,
                            shuffle=False,
                            collate_fn=detection_collate,
                            pin_memory=True,
                            drop_last=True))

    log_file = open(
        args.save_root + args.snapshot_pref + "_training_" + day + ".log", "w",
        1)
    log_file.write(args.snapshot_pref + '\n')

    for arg in vars(args):
        print(arg, getattr(args, arg))
        log_file.write(str(arg) + ': ' + str(getattr(args, arg)) + '\n')

    log_file.write(str(net))

    torch.cuda.synchronize()
    len_train = len(trainvideos)

    for epoch in range(args.start_epoch, args.epochs):
        ####### shuffle train dataset #######
        random.shuffle(trainvideos)
        train_data_loader = []
        for i in range(num_gpu):
            trainvideos_temp = trainvideos[int(i * len_train /
                                               num_gpu):int((i + 1) *
                                                            len_train /
                                                            num_gpu)]
            train_dataset = UCF24Detection(args.data_root,
                                           'train',
                                           SSDAugmentation(
                                               args.ssd_dim, args.means),
                                           AnnotationTransform(),
                                           input_type=args.modality,
                                           videos=trainvideos_temp,
                                           istrain=True)
            train_data_loader.append(
                data.DataLoader(train_dataset,
                                args.batch_size,
                                num_workers=args.num_workers,
                                shuffle=False,
                                collate_fn=detection_collate,
                                pin_memory=True,
                                drop_last=True))

        print("Train epoch_size: ", len(train_data_loader))
        print('Train SSD on', train_dataset.name)

        ########## train ###########
        train(train_data_loader, net, criterion, optimizer, scheduler, epoch,
              num_gpu)

        print('Saving state, epoch:', epoch)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'best_prec1': best_prec1,
            },
            epoch=epoch)

        #### log lr ###
        # scheduler.step()
        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1 or epoch == 0:
            torch.cuda.synchronize()
            tvs = time.perf_counter()
            mAP, ap_all, ap_strs = validate(args,
                                            net,
                                            val_data_loader,
                                            val_dataset,
                                            epoch,
                                            iou_thresh=args.iou_thresh,
                                            num_gpu=num_gpu)
            # remember best prec@1 and save checkpoint
            is_best = mAP > best_prec1
            best_prec1 = max(mAP, best_prec1)
            print('Saving state, epoch:', epoch)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': net.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, epoch)

            for ap_str in ap_strs:
                print(ap_str)
                log_file.write(ap_str + '\n')
            ptr_str = '\nMEANAP:::=>' + str(mAP) + '\n'
            print(ptr_str)
            log_file.write(ptr_str)

            torch.cuda.synchronize()
            t0 = time.perf_counter()
            prt_str = '\nValidation TIME::: {:0.3f}\n\n'.format(t0 - tvs)
            print(prt_str)
            log_file.write(prt_str)

    log_file.close()
Example #8
def __init__(self, root, image_set, transform=None, target_transform=None,
             dataset_name='ucf24', input_type='rgb', full_test=False, *args, **kwargs):
    self.UCF24 = UCF24Detection(root, image_set, transform, target_transform,
                                dataset_name, input_type, full_test)
    super(OmniUCF24, self).__init__(self.UCF24, *args, **kwargs)
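
OmniUCF24 wraps a UCF24Detection instance and forwards it to its base class. A hypothetical instantiation (the root path is a placeholder and any extra base-class arguments are omitted):

dataset = OmniUCF24('/path/to/ucf24/', 'train',
                    transform=SSDAugmentation(300, (104, 117, 123)),
                    target_transform=AnnotationTransform(),
                    input_type='rgb')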
Example #9
import torch.utils.data as data
from torch.autograd import Variable

from data import v2, UCF24Detection, AnnotationTransform, detection_collate, CLASSES_JHMDB, BaseTransform
from utils.augmentations_ import SSDAugmentation


train_dataset = UCF24Detection("/mnt/data/Action/data/ucf24/ucf24/", 'train', SSDAugmentation(300, (104, 117, 123)),
                               AnnotationTransform(), input_type='rgb')
train_data_loader = data.DataLoader(train_dataset, 1, num_workers=1,
                                    shuffle=True, collate_fn=detection_collate, pin_memory=True)
for i, (images, targets, img_indexs) in enumerate(train_data_loader):
    images = Variable(images)
    targets = [Variable(anno, volatile=True) for anno in targets]
Example #10
def main():
    global my_dict, keys, k_len, arr, xxx, args, log_file, best_prec1

    parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training')
    parser.add_argument('--version', default='v2', help='conv11_2(v2) or pool6(v1) as last layer')
    parser.add_argument('--basenet', default='vgg16_reducedfc.pth', help='pretrained base model')
    parser.add_argument('--dataset', default='ucf24', help='dataset name')
    parser.add_argument('--ssd_dim', default=300, type=int, help='Input Size for SSD')  # only support 300 now
    parser.add_argument('--modality', default='rgb', type=str,
                        help='Input type; default rgb, options are [rgb, brox, fastOF]')
    parser.add_argument('--jaccard_threshold', default=0.5, type=float, help='Min Jaccard index for matching')
    parser.add_argument('--batch_size', default=32, type=int, help='Batch size for training')
    parser.add_argument('--num_workers', default=0, type=int, help='Number of workers used in dataloading')
    parser.add_argument('--max_iter', default=120000, type=int, help='Number of training iterations')
    parser.add_argument('--man_seed', default=123, type=int, help='manual seed for reproducibility')
    parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model')
    parser.add_argument('--ngpu', default=1, type=int, help='Number of GPUs to use')
    parser.add_argument('--lr', '--learning-rate', default=0.0005, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--stepvalues', default='70000,90000', type=str,
                        help='iter number when learning rate to be dropped')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD')
    parser.add_argument('--gamma', default=0.2, type=float, help='Gamma update for SGD')
    parser.add_argument('--log_iters', default=True, type=str2bool, help='Print the loss at each iteration')
    parser.add_argument('--visdom', default=False, type=str2bool, help='Use visdom to for loss visualization')
    parser.add_argument('--data_root', default='/data4/lilin/my_code/realtime/', help='Location of VOC root directory')
    parser.add_argument('--save_root', default='/data4/lilin/my_code/realtime/realtime-lstm/save',
                        help='Location to save checkpoint models')
    parser.add_argument('--iou_thresh', default=0.5, type=float, help='Evaluation threshold')
    parser.add_argument('--conf_thresh', default=0.01, type=float, help='Confidence threshold for evaluation')
    parser.add_argument('--nms_thresh', default=0.45, type=float, help='NMS threshold')
    parser.add_argument('--topk', default=50, type=int, help='topk for evaluation')
    parser.add_argument('--clip', default=40, type=float, help='gradients clip')
    # parser.add_argument('--resume', default="/data4/lilin/my_code/realtime/realtime-lstm/saveucf24/cache/CONV-SSD-ucf24-rgb-bs-32-vgg16-lr-00050/ssd300_ucf24_30000.pth",
    #                     type=str, help='Resume from checkpoint')
    parser.add_argument('--resume', default=None,
                        type=str, help='Resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
    parser.add_argument('--epochs', default=35, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--eval_freq', default=2, type=int, metavar='N', help='evaluation frequency in epochs (default: 2)')
    parser.add_argument('--snapshot_pref', type=str, default="ucf101_vgg16_ssd300_")

    print(__file__)
    file_name = (__file__).split('/')[-1]
    file_name = file_name.split('.')[0]
    print(file_name)
    ## Parse arguments

    args = parser.parse_args()
    ## set random seeds
    np.random.seed(args.man_seed)
    torch.manual_seed(args.man_seed)
    if args.cuda:
        torch.cuda.manual_seed_all(args.man_seed)

    if args.cuda and torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    args.cfg = v2
    args.train_sets = 'train'
    args.means = (104, 117, 123)
    num_classes = len(CLASSES) + 1
    args.num_classes = num_classes
    args.stepvalues = [int(val) for val in args.stepvalues.split(',')]
    args.loss_reset_step = 30
    args.eval_step = 10000
    args.print_step = 10

    ## Define the experiment name; it will be used for the save directory and the visdom ENV
    args.exp_name = 'CONV-SSD-{}-{}-bs-{}-{}-lr-{:05d}'.format(args.dataset,
                args.modality, args.batch_size, args.basenet[:-14], int(args.lr*100000))

    args.save_root += args.dataset+'/'
    args.save_root = args.save_root+'cache/'+args.exp_name+'/'

    if not os.path.isdir(args.save_root):
        os.makedirs(args.save_root)

    net = build_ssd(300, args.num_classes)

    # if args.has_snapshot is True:
    #     print ("load snapshot")
    #     pretrained_weights = "/data4/lilin/my_code/realtime/realtime-lstm/saveucf24/cache/CONV-SSD-ucf24-rgb-bs-32-vgg16-lr-00050/ssd300_ucf24_30000.pth"
    #     net.load_state_dict(torch.load(pretrained_weights))
    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            net.load_state_dict(checkpoint['state_dict'])
            print(("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.evaluate, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    elif args.modality == 'fastOF':
        print('Download pretrained Brox-flow model weights and place them at:::=> ', args.data_root + 'ucf24/train_data/brox_wieghts.pth')
        pretrained_weights = args.data_root + 'ucf24/train_data/brox_wieghts.pth'
        print('Loading base network...')
        net.load_state_dict(torch.load(pretrained_weights))
    else:
        vgg_weights = torch.load(args.data_root +'ucf24/train_data/' + args.basenet)
        print('Loading base network...')
        net.vgg.load_state_dict(vgg_weights)

    args.data_root += args.dataset + '/'

    if args.cuda:
        net = net.cuda()

    def xavier(param):
        init.xavier_uniform(param)

    def weights_init(m):
        if isinstance(m, nn.Conv2d):
            xavier(m.weight.data)
            m.bias.data.zero_()

    print('Initializing weights for extra layers and HEADs...')
    # initialize newly added layers' weights with xavier method
    if args.resume is None:
        net.extras.apply(weights_init)
        net.loc.apply(weights_init)
        net.conf.apply(weights_init)

    parameter_dict = dict(net.named_parameters())  # get the network parameters as a dict keyed by name
    params = []

    # Set a different learning rate for bias parameters and zero their weight_decay
    for name, param in parameter_dict.items():
        if name.find('bias') > -1:
            print(name, 'layer parameters will be trained @ {}'.format(args.lr*2))
            params += [{'params': [param], 'lr': args.lr*2, 'weight_decay': 0}]
        else:
            print(name, 'layer parameters will be trained @ {}'.format(args.lr))
            params += [{'params':[param], 'lr': args.lr, 'weight_decay':args.weight_decay}]

    optimizer = optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    criterion = MultiBoxLoss(args.num_classes, 0.5, True, 0, True, 3, 0.5, False, args.cuda)
    scheduler = None
    scheduler = MultiStepLR(optimizer, milestones=args.stepvalues, gamma=args.gamma)

    print('Loading Dataset...')
    train_dataset = UCF24Detection(args.data_root, args.train_sets, SSDAugmentation(args.ssd_dim, args.means),
                                   AnnotationTransform(), input_type=args.modality)
    val_dataset = UCF24Detection(args.data_root, 'test', BaseTransform(args.ssd_dim, args.means),
                                 AnnotationTransform(), input_type=args.modality,
                                 full_test=False)
    args.epoch_size = len(train_dataset) // args.batch_size

    train_data_loader = data.DataLoader(train_dataset, args.batch_size, num_workers=args.num_workers,
                                  shuffle=False, collate_fn=detection_collate, pin_memory=True)


    val_data_loader = data.DataLoader(val_dataset, args.batch_size, num_workers=args.num_workers,
                                 shuffle=False, collate_fn=detection_collate, pin_memory=True)
    print ("epoch_size: ", args.epoch_size)
    print('Training SSD on', train_dataset.name)

    my_dict = copy.deepcopy(train_data_loader.dataset.train_vid_frame)
    keys = list(my_dict.keys())
    k_len = len(keys)
    arr = np.arange(k_len)
    xxx = copy.deepcopy(train_data_loader.dataset.ids)

    log_file = open(args.save_root+"training.log", "w", 1)
    log_file.write(args.exp_name+'\n')

    for arg in vars(args):
        print(arg, getattr(args, arg))
        log_file.write(str(arg)+': '+str(getattr(args, arg))+'\n')
    log_file.write(str(net))

    torch.cuda.synchronize()

    for epoch in range(args.start_epoch, args.epochs):

        train(train_data_loader, net, criterion, optimizer, scheduler, epoch)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            torch.cuda.synchronize()
            tvs = time.perf_counter()
            mAP, ap_all, ap_strs = validate(args, net, val_data_loader, val_dataset, epoch, iou_thresh=args.iou_thresh)
            # remember best prec@1 and save checkpoint
            is_best = mAP > best_prec1
            best_prec1 = max(mAP, best_prec1)
            print('Saving state, epoch:', epoch)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)

            for ap_str in ap_strs:
                print(ap_str)
                log_file.write(ap_str+'\n')
            ptr_str = '\nMEANAP:::=>'+str(mAP)+'\n'
            print(ptr_str)
            log_file.write(ptr_str)

            torch.cuda.synchronize()
            t0 = time.perf_counter()
            prt_str = '\nValidation TIME::: {:0.3f}\n\n'.format(t0-tvs)
            print(prt_str)
            log_file.write(prt_str)

    log_file.close()