Example #1
0
        os.makedirs(output_dir)

    # sampler_batch = sampler(train_size, args.batch_size)
    # Wrap the roidb in roibatchLoader and iterate it through a shuffling
    # DataLoader; batch size and worker count come from the parsed args.
    dataset = roibatchLoader(roidb)
    loader_options = {
        'batch_size': args.batch_size,
        'num_workers': args.num_workers,
        'shuffle': True,
    }
    dataloader = torch.utils.data.DataLoader(dataset, **loader_options)

    # Load a second roidb from the source named by
    # args.trimmed_support_set_data.
    support_set_roidb = get_roidb(args.trimmed_support_set_data)

    # Initialize the network selected by args.net.
    if args.net == 'c3d':
        tdcnn_demo = c3d_tdcnn_fewshot(pretrained=True)
    elif args.net == 'res18':
        tdcnn_demo = resnet_tdcnn(depth=18, pretrained=True)
    elif args.net == 'res34':
        tdcnn_demo = resnet_tdcnn(depth=34, pretrained=True)
    elif args.net == 'res50':
        tdcnn_demo = resnet_tdcnn(depth=50, pretrained=True)
    elif args.net == 'eco':
        tdcnn_demo = eco_tdcnn(pretrained=True)
    else:
        # No model was built for this name, so the next statement would fail
        # on an unassigned `tdcnn_demo`. Raise a descriptive error instead of
        # only printing and falling through.
        raise ValueError(
            "network is not defined: {!r}".format(args.net))

    # Finish model setup, then print the model for inspection.
    tdcnn_demo.create_architecture()
    print(tdcnn_demo)

    params = []
    for key, value in dict(tdcnn_demo.named_parameters()).items():
        if value.requires_grad:
Example #2
0
    # Checkpoints are looked up under <load_dir>/<net>/<dataset>.
    input_dir = "/".join((args.load_dir, args.net, args.dataset))
    if not os.path.exists(input_dir):
        missing_dir_msg = (
            'There is no input directory for loading network from ' +
            input_dir)
        raise Exception(missing_dir_msg)

    # File name pattern: tdcnn_<checksession>_<checkepoch>_<checkpoint>.pth
    checkpoint_filename = 'tdcnn_{}_{}_{}.pth'.format(
        args.checksession, args.checkepoch, args.checkpoint)
    load_name = os.path.join(input_dir, checkpoint_filename)

    # Initialize the network selected by args.net.
    if args.net == 'c3d':
        tdcnn_demo = c3d_tdcnn(class_agnostic=cfg.AGNOSTIC, pretrained=False)
    elif args.net == 'res34':
        tdcnn_demo = resnet_tdcnn(depth=34,
                                  class_agnostic=cfg.AGNOSTIC,
                                  pretrained=False)
    elif args.net == 'res50':
        tdcnn_demo = resnet_tdcnn(depth=50,
                                  class_agnostic=cfg.AGNOSTIC,
                                  pretrained=False)
    else:
        # No model was built for this name, so continuing would fail on an
        # unassigned `tdcnn_demo`. Raise a descriptive error instead of
        # printing and dropping into the debugger.
        raise ValueError(
            "network is not defined: {!r}".format(args.net))

    # Finish model setup before the checkpoint is loaded into it.
    tdcnn_demo.create_architecture()

    # Read the checkpoint file and restore the model from the state stored
    # under its 'model' key.
    print("load checkpoint %s" % (load_name))
    # NOTE(review): torch.load uses pickle-based deserialization; only load
    # checkpoint files from trusted sources.
    checkpoint = torch.load(load_name)
    tdcnn_demo.load_state_dict(checkpoint['model'])
    if 'pooling_mode' in checkpoint.keys():
    # Checkpoints are looked up under <load_dir>/<net>/<dataset>.
    input_dir = "/".join((args.load_dir, args.net, args.dataset))
    if not os.path.exists(input_dir):
        missing_dir_msg = (
            'There is no input directory for loading network from ' +
            input_dir)
        raise Exception(missing_dir_msg)

    # File name pattern: tdcnn_<checksession>_<checkepoch>_<checkpoint>.pth
    checkpoint_filename = 'tdcnn_{}_{}_{}.pth'.format(
        args.checksession, args.checkepoch, args.checkpoint)
    load_name = os.path.join(input_dir, checkpoint_filename)

    # Initialize the network selected by args.net.
    if args.net == 'c3d':
        tdcnn_demo = c3d_tdcnn(pretrained=False)
    elif args.net == 'res34':
        tdcnn_demo = resnet_tdcnn(depth=34, pretrained=False)
    elif args.net == 'res50':
        tdcnn_demo = resnet_tdcnn(depth=50, pretrained=False)
    else:
        # No model was built for this name, so continuing would fail on an
        # unassigned `tdcnn_demo`. Raise a descriptive error instead of
        # printing and dropping into the debugger.
        raise ValueError(
            "network is not defined: {!r}".format(args.net))

    # Finish model setup, then print the model for inspection.
    tdcnn_demo.create_architecture()
    print(tdcnn_demo)

    # Read the checkpoint file and restore the model from the state stored
    # under its 'model' key.
    print("load checkpoint %s" % load_name)
    # NOTE(review): torch.load uses pickle-based deserialization; only load
    # checkpoint files from trusted sources.
    checkpoint = torch.load(load_name)
    model_state = checkpoint['model']
    tdcnn_demo.load_state_dict(model_state)

    # Override the configured pooling mode only when the checkpoint records one.
    if 'pooling_mode' in checkpoint:
        cfg.POOLING_MODE = checkpoint['pooling_mode']
    print('load model successfully!')