コード例 #1
0
def main():
    """Parse the command line, fold it into the global ``cfg`` and run.

    Order of application: ``--cfg`` files, then ``--set`` pairs, then the
    numpy seed (unless ``args.randomize``), then the individual option
    overrides below, and finally the weights/resume settings.  Dispatches to
    ``test_net`` when ``args.test`` is truthy, otherwise to ``train_net``.
    """
    args = parse_args()

    print('Called with args:')
    print(args)

    # GPU selection via theano.sandbox.cuda.use(args.gpu_id) is disabled.

    if args.cfg_files is not None:
        for path in args.cfg_files:
            cfg_from_file(path)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)
    if not args.randomize:
        np.random.seed(cfg.CONST.RNG_SEED)

    # (config key, CLI value) pairs, applied in this order whenever the value
    # was actually supplied on the command line.
    option_overrides = [
        ('CONST.BATCH_SIZE', args.batch_size),
        ('TRAIN.NUM_ITERATION', args.iter),
        ('TRAIN.SAVE_FREQ', args.save_freq),
        ('TRAIN.VALIDATION_FREQ', args.valid_freq),
        ('TRAIN.NAN_CHECK_FREQ', args.nan_check_freq),
        ('NET_NAME', args.net_name),
        ('CONST.NETWORK_CLASS', args.model_name),
        ('DATASET', args.dataset),
        ('TEST.EXP_NAME', args.exp),
        ('DIR.OUT_PATH', args.out_path),
    ]
    for cfg_key, cli_value in option_overrides:
        if cli_value is not None:
            cfg_from_list([cfg_key, cli_value])

    # Supplying weights switches on resume-training from args.init_iter.
    if args.weights is not None:
        resume_settings = ['CONST.WEIGHTS', args.weights,
                           'TRAIN.RESUME_TRAIN', True,
                           'TRAIN.INITIAL_ITERATION', int(args.init_iter)]
        cfg_from_list(resume_settings)

    print('Using config:')
    pprint.pprint(cfg)

    run_phase = test_net if args.test else train_net
    run_phase()
コード例 #2
0
def main():
    """Parse the command line, fold it into the global ``cfg`` and run.

    ``--cfg`` files and ``--set`` pairs are applied first.  The numpy seed is
    then fixed unless ``args.randomize`` is set.  After that, every supplied
    option overrides its config key, and weights (if any) switch on
    resume-training.  Runs ``test_net`` when ``args.test`` is truthy, else
    ``train_net``.
    """
    args = parse_args()

    print('Called with args:')
    print(args)

    # GPU selection through theano (sandbox.cuda.use / gpuarray.use) is disabled.

    if args.cfg_files is not None:
        for yaml_path in args.cfg_files:
            cfg_from_file(yaml_path)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)
    if not args.randomize:
        np.random.seed(cfg.CONST.RNG_SEED)

    # Ordered mapping from config key to the CLI value that overrides it.
    cli_to_cfg = (
        ('CONST.BATCH_SIZE', args.batch_size),
        ('TRAIN.NUM_ITERATION', args.iter),
        ('NET_NAME', args.net_name),
        ('CONST.NETWORK_CLASS', args.model_name),
        ('DATASET', args.dataset),
        ('TEST.EXP_NAME', args.exp),
        ('DIR.OUT_PATH', args.out_path),
        ('DIR.TB_PATH', args.tb_path),
        ('CONST.dynamic_dict', args.dyna_dict),
        ('TRAIN.DEFAULT_LEARNING_RATE', args.learn_rate),
    )
    for key, value in cli_to_cfg:
        if value is None:
            continue
        cfg_from_list([key, value])

    if args.weights is not None:
        cfg_from_list(['CONST.WEIGHTS', args.weights,
                       'TRAIN.RESUME_TRAIN', True,
                       'TRAIN.INITIAL_ITERATION', int(args.init_iter)])

    print('Using config:')
    pprint.pprint(cfg)

    if args.test:
        test_net()
    else:
        train_net()
コード例 #3
0
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('dataset',
                        help='dataset ('
                        'shapenet'
                        ', '
                        'primitives'
                        ')')
    parser.add_argument('embeddings_path',
                        help='path to text embeddings pickle file')
    parser.add_argument('--metric',
                        help='path to text embeddings pickle file',
                        default='minkowski',
                        type=str)
    parser.add_argument('--cfg',
                        dest='cfg_files',
                        action='append',
                        help='optional config file',
                        default=None,
                        type=str)
    args = parser.parse_args()

    # modify default config if requested
    if args.cfg_files is not None:
        for cfg_file in args.cfg_files:
            cfg_from_file(cfg_file)

    cfg_from_list(['CONST.DATASET', args.dataset])

    with open(args.embeddings_path, 'rb') as f:
        embeddings_dict = pickle.load(f)

    if os.path.basename(args.embeddings_path) == 'text_embeddings.p':
        subdir = 'text'
    elif ((os.path.basename(args.embeddings_path) == 'shape_embeddings.p')
          or (os.path.basename(args.embeddings_path)
              == 'modified_shape_embeddings.p')):
        subdir = 'shape'
    else:
        subdir = 'unspecified'
    render_dir = os.path.join(os.path.dirname(args.embeddings_path),
                              'nearest_neighbor_renderings', subdir)
    np.random.seed(1234)
    compute_metrics(args.dataset,
                    embeddings_dict,
                    metric=args.metric,
                    concise=render_dir)
コード例 #4
0
def main():
    parser = argparse.ArgumentParser(description="SMDL: SubModular Dataloader")
    parser.add_argument("--cfg",
                        dest='cfg_file',
                        default='./config/smdl.yml',
                        type=str,
                        help="An optional config file"
                        " to be loaded")
    args = parser.parse_args()

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    if not os.path.exists('datasets'):
        os.makedirs('datasets')
    if not os.path.exists('output'):
        os.makedirs('output')

    timestamp = time.strftime("%m%d_%H%M%S")
    cfg.timestamp = timestamp

    output_dir = './output/' + cfg.run_label + '_' + cfg.timestamp
    cfg.output_dir = output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        os.makedirs(output_dir + '/models')
        os.makedirs(output_dir + '/plots')
        os.makedirs(output_dir + '/logs')
        os.makedirs(output_dir + '/accuracies')

    logging.basicConfig(filename=output_dir + '/logs/smdl_' + timestamp +
                        '.log',
                        level=logging.DEBUG,
                        format='%(levelname)s:\t%(message)s')

    log(pprint.pformat(cfg))

    gpu_list = cfg.gpu_ids.split(',')
    gpus = [int(iter) for iter in gpu_list]
    torch.cuda.set_device(gpus[0])
    torch.backends.cudnn.benchmark = True

    if cfg.seed != 0:
        np.random.seed(cfg.seed)
        torch.backends.cudnn.deterministic = True
        torch.manual_seed(cfg.seed)

    submodular_training(gpus)
コード例 #5
0
    def __init__(self):
        """Build the PointRCNN detector and wire up the ROS topics.

        Reads module-level settings defined elsewhere in the file:
        ``cfg_file``, ``pointrcnn_weight``, ``pc_topic``, ``pack_name``,
        ``is_viz`` and ``is_tf``.
        """
        # dnn: load the config and force the RPN + RCNN inference pipeline.
        cfg_from_file(cfg_file)
        cfg.RCNN.ENABLED = True
        cfg.RPN.ENABLED = cfg.RPN.FIXED = True
        cfg.RPN.LOC_XZ_FINE = False
        # One [min, max] pair per point-cloud axis.
        self.pc_roi = [[-25, 25], [-3, 2], [-25, 25]]
        self.down_sample = {'axis': 0, 'depth': self.pc_roi[0][1] / 2}  # [axis,depth]
        self.mode = 'TEST'
        with torch.no_grad():
            self.model = PointRCNN(num_classes=2, use_xyz=True, mode=self.mode)
            self.model.cuda()
            self.model.eval()
            load_checkpoint(model=self.model, optimizer=None, filename=pointrcnn_weight)

        # ros: optional visualisation publishers (only when is_viz).
        self.pc_pub = rospy.Publisher(pack_name + "/networks_input", PointCloud2, queue_size=1) if is_viz else None
        self.mk_pub = rospy.Publisher(pack_name + "/networks_output", MarkerArray, queue_size=1) if is_viz else None
        # Fixed 4x4 velodyne -> camera-style transform when is_tf, else identity.
        self.Tr_velo_kitti_cam = np.array([0.0, - 1.0, 0.0, 0.0,
                                           0.0, 0.0, -1.0, 1.5,
                                           1.0, 0.0, 0.0, 0.0,
                                           0.0, 0.0, 0.0, 1.0]).reshape(4, 4) if is_tf else np.identity(4)

        # Subscribe last.  rospy delivers callbacks on its own threads as soon
        # as the subscriber exists, so self.pc_cb (which presumably reads the
        # publishers / transform above) must not fire before they are set.
        self.pc_sub = rospy.Subscriber(pc_topic, PointCloud2, self.pc_cb, queue_size=1, buff_size=2 ** 24)
コード例 #6
0
ファイル: eval_rcnn.py プロジェクト: zuoym15/PointRCNN
        logger=logger)

    test_loader = DataLoader(test_set,
                             batch_size=args.batch_size,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=args.workers,
                             collate_fn=test_set.collate_batch)

    return test_loader


if __name__ == "__main__":
    # merge config and log to file
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)
    cfg.TAG = os.path.splitext(os.path.basename(args.cfg_file))[0]

    if args.eval_mode == 'rpn':
        cfg.RPN.ENABLED = True
        cfg.RCNN.ENABLED = False
        root_result_dir = os.path.join('../', 'output', 'rpn', cfg.TAG)
        ckpt_dir = os.path.join('../', 'output', 'rpn', cfg.TAG, 'ckpt')
    elif args.eval_mode == 'rcnn':
        cfg.RCNN.ENABLED = True
        cfg.RPN.ENABLED = cfg.RPN.FIXED = True
        root_result_dir = os.path.join('../', 'output', 'rcnn', cfg.TAG)
        ckpt_dir = os.path.join('../', 'output', 'rcnn', cfg.TAG, 'ckpt')
    elif args.eval_mode == 'rcnn_offline':
コード例 #7
0
        snapshot_folder = os.path.join(cfg.ROOT_DIR, 'snapshot')
        return os.path.join(snapshot_folder, name + "_" + str(epoch) + ".pth")


def parse_args():
    """Read the tester's command-line options.

    Supports ``--folder`` (run directory, default ``None``) and ``--resume``
    (int, default ``-1``).  If the script is started with no arguments at
    all, the usage text is printed and the process exits with status 1.
    """
    cli = argparse.ArgumentParser(description='Image Captioning')
    cli.add_argument('--folder', dest='folder', default=None, type=str)
    cli.add_argument("--resume", type=int, default=-1)

    nothing_passed = len(sys.argv) == 1
    if nothing_passed:
        cli.print_help()
        sys.exit(1)

    return cli.parse_args()


if __name__ == '__main__':
    # Script entry: load the run folder's config.yml (if given) and evaluate
    # the snapshot selected by --resume.
    args = parse_args()
    print('Called with args:')
    print(args)
    if args.folder is not None:
        cfg_from_file(os.path.join(args.folder, 'config.yml'))
        # Only override ROOT_DIR when a run folder was given.  Assigning None
        # would clobber the config default, and ROOT_DIR is later passed to
        # os.path.join when building snapshot paths.
        cfg.ROOT_DIR = args.folder

    tester = Tester(args)
    tester.eval(args.resume)
コード例 #8
0
ファイル: api.py プロジェクト: nevermore3/plateRecognition
 def __init__(self, cfg_filepath=".\\cfgs\\easypr.yml"):
     """Load the EasyPR YAML config and keep a handle to the shared ``cfg``.

     :param cfg_filepath: path of the YAML config file.  The historical
         default is spelled with Windows separators; it is translated to the
         native separator so the default also resolves on POSIX systems.
     """
     import os

     if cfg_filepath == ".\\cfgs\\easypr.yml":
         # On Windows this yields the same string; elsewhere './cfgs/easypr.yml'.
         cfg_filepath = os.path.join(".", "cfgs", "easypr.yml")
     cfg_from_file(cfg_filepath)
     self._cfg = cfg
コード例 #9
0
    parser.add_argument("--test_bs", type=int, default=-1)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    print('Called with args:')
    print(args)
    args.resume = -1
    if args.config is not None:
        cfg_from_file(args.config)
    os.environ["TORCH_HOME"] = cfg.MODEL.PRETRAINED_PATH
    if args.multi_crop:
        cfg.AUG.MULTI_CROP_TEST = True
    if args.test_crop_size > 0:
        cfg.AUG.TEST_CROP = [args.test_crop_size, args.test_crop_size]
    if args.test_resize_size > 0:
        cfg.AUG.RESIZE = [args.test_resize_size, args.test_resize_size]
    if args.test_bs > 0:
        cfg.TEST.BATCH_SIZE = args.test_bs
    trainer = Trainer(args)
    netG_dict = torch.load(args.netG_model_path,
                           map_location=lambda storage, loc: storage)
    current_state = trainer.netG.state_dict()
    keys = list(current_state.keys())
    for key in keys:
コード例 #10
0
                                    list(cfg.TRAIN.MOMS), cfg.TRAIN.DIV_FACTOR,
                                    cfg.TRAIN.PCT_START)
    else:
        lr_scheduler = lr_sched.LambdaLR(optimizer,
                                         lr_lbmd,
                                         last_epoch=last_iter)

    bnm_scheduler = train_utils.BNMomentumScheduler(model,
                                                    bnm_lmbd,
                                                    last_epoch=last_iter)
    return lr_scheduler, bnm_scheduler


if __name__ == "__main__":
    # NOTE(review): `args` and `exp_id` are module-level names defined
    # outside this excerpt.
    if args.cfg_file is not None:
        # args.cfg_file is used as a path *prefix*: the three stage configs
        # are appended to it verbatim, so a directory argument needs a
        # trailing separator.
        cfg_from_file(args.cfg_file + 'weaklyRPN.yaml')
        cfg_from_file(args.cfg_file + 'weaklyRCNN.yaml')
        cfg_from_file(args.cfg_file + 'weaklyIOUN.yaml')
        # Derive the run tag only when a prefix was given; previously this
        # ran unconditionally and os.path.basename(None) raised TypeError.
        cfg.TAG = os.path.splitext(os.path.basename(args.cfg_file))[0]

    # Enable only the IOUN stage; RPN and RCNN are switched off.
    cfg.RCNN.ENABLED = False
    cfg.IOUN.ENABLED = True
    cfg.RPN.ENABLED = cfg.RPN.FIXED = False
    root_result_dir = os.path.join('../', 'output', 'ioun', cfg.TAG + exp_id)

    # An explicit --output_dir replaces the derived result directory.
    if args.output_dir is not None:
        root_result_dir = args.output_dir
    os.makedirs(root_result_dir, exist_ok=True)

    log_file = os.path.join(root_result_dir, 'log_train.txt')
    logger = create_logger(log_file)
コード例 #11
0
        snapshot_folder = os.path.join(cfg.ROOT_DIR, 'snapshot')
        return os.path.join(snapshot_folder, name + "_" + str(epoch) + ".pth")

def parse_args():
    """Parse the tester's command line.

    Options: ``--folder`` (run directory, default ``None``), ``--resume``
    (int, default ``-1``) and ``--config`` (config file name, default
    ``config.yml``).  Invoking the script with no arguments prints the usage
    text and exits with status 1.
    """
    arg_parser = argparse.ArgumentParser(description='Image Captioning')
    option_specs = (
        (('--folder',), dict(dest='folder', default=None, type=str)),
        (("--resume",), dict(type=int, default=-1)),
        (('--config',), dict(default='config.yml')),
    )
    for flags, options in option_specs:
        arg_parser.add_argument(*flags, **options)

    if len(sys.argv) == 1:
        arg_parser.print_help()
        sys.exit(1)

    return arg_parser.parse_args()

if __name__ == '__main__':
    # Script entry: load <folder>/<config> (if a folder is given) and evaluate
    # the snapshot selected by --resume.
    args = parse_args()
    print('Called with args:')
    print(args)

    if args.folder is not None:
        cfg_from_file(os.path.join(args.folder, args.config))
        # Only override ROOT_DIR when a run folder was given.  Assigning None
        # would clobber the config default, and ROOT_DIR is later passed to
        # os.path.join when building snapshot paths.
        cfg.ROOT_DIR = args.folder

    tester = Tester(args)
    tester.eval(args.resume)
コード例 #12
0
ファイル: rpn.py プロジェクト: NehilDanis/PointRCNN
            'rpn_cls': rpn_cls,
            'rpn_reg': rpn_reg,
            'backbone_xyz': backbone_xyz,
            'backbone_features': backbone_features
        }

        return ret_dict


if __name__ == '__main__':
    # Ad-hoc debug entry point: load the default config and pick a stage
    # configuration.  Imports are local to this debug path.
    # NOTE(review): ipdb, torch, RCNNNet, save_config_to_file and
    # cfg_from_list are unused in the lines shown here — presumably used
    # further down; confirm before removing.
    import ipdb
    from lib.config import cfg, cfg_from_file, save_config_to_file, cfg_from_list
    import torch
    from lib.net.rcnn_net import RCNNNet
    cfg_file = 'tools/cfgs/default.yaml'
    cfg_from_file(cfg_file)
    # Run tag = config file name without directory or extension.
    cfg.TAG = os.path.splitext(os.path.basename(cfg_file))[0]

    # Hard-coded stage selector; as written only the 'rcnn' branch runs.
    # NOTE(review): cfg_file is relative to the repo root while the output
    # paths start with '../' — confirm the expected working directory.
    train_mode = 'rcnn'
    if train_mode == 'rpn':
        # RPN stage only.
        cfg.RPN.ENABLED = True
        cfg.RCNN.ENABLED = False
        root_result_dir = os.path.join('../', 'output', 'rpn', cfg.TAG)
    elif train_mode == 'rcnn':
        # RCNN with the RPN enabled and RPN.FIXED set (presumably frozen).
        cfg.RCNN.ENABLED = True
        cfg.RPN.ENABLED = cfg.RPN.FIXED = True
        root_result_dir = os.path.join('../', 'output', 'rcnn', cfg.TAG)
    elif train_mode == 'rcnn_offline':
        # RCNN without the RPN (presumably on precomputed proposals).
        # Shares its output directory with the 'rcnn' mode.
        cfg.RCNN.ENABLED = True
        cfg.RPN.ENABLED = False
        root_result_dir = os.path.join('../', 'output', 'rcnn', cfg.TAG)
コード例 #13
0
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # Script entry: resolve and load the config, optionally point the raw
    # image loader at --test_dir, then evaluate the --resume snapshot.
    args = parse_args()
    print('Called with args:')
    print(args)

    if args.folder is not None:
        # Use args.config as given when that path exists; otherwise resolve
        # it relative to the run folder.
        if not os.path.exists(args.config):
            config_path = os.path.join(args.folder, args.config)
        else:
            config_path = args.config
        cfg_from_file(config_path)
        # Only override ROOT_DIR when a run folder was given.  Assigning None
        # would discard the config's value and leave ROOT_DIR unusable for
        # later path joins.
        cfg.ROOT_DIR = args.folder

    if args.test_raw_image:
        # Raw-image evaluation: all inputs live under args.test_dir.
        cfg.RAW_DATA_LOADER.TEST_IMG_DIR = os.path.join(
            args.test_dir, 'images')
        cfg.RAW_DATA_LOADER.TEST_ATT_FEATS = os.path.join(
            args.test_dir, 'vg/features')
        cfg.RAW_DATA_LOADER.TEST_PROCESSEDIMG_DIR = os.path.join(
            args.test_dir, 'vg/images')

    tester = Tester(args)
    tester.eval(args.resume)
コード例 #14
0
def main():
    """
    Launch the random, submodular and full-dataset SMILe runs in parallel.

    This function:
        - Creates a random class-list that will be used by all the methods,
        - Saves it to file,
        - Creates yaml files with run configurations,
        - Executes each of the runs and waits for all of them to finish.
    :return: None
    """

    # Retrieving the arguments
    parser = argparse.ArgumentParser(
        description="SMILe: SubModular Incremental Learning")
    parser.add_argument("--cfg",
                        dest='cfg_file',
                        default='./config/smile.yml',
                        type=str,
                        help="An optional config file"
                        " to be loaded")
    args = parser.parse_args()

    # Updating the configuration object
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    # Creating the class list: one random class ordering per repeat round.
    class_list = []
    classes = range(0, cfg.dataset.total_num_classes)
    for _ in range(cfg.repeat_rounds):
        classes = np.random.permutation(classes)
        class_list.append(classes)

    # Saving the class list
    output_dir = './run_sandbox/'
    cfg.output_dir = output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Pickle output is bytes: text mode ('w') fails under Python 3 and can
    # mangle line endings on Windows.
    with open(os.path.join(output_dir, 'class_list.pkl'), 'wb') as f:
        pickle.dump(class_list, f)

    # Start Smiling :)
    python_bin = '/raid/joseph/il/opy27/bin/python'
    gpus = [5, 6, 7]
    current_run_label = cfg['run_label']
    cfg['load_class_list_from_file'] = True

    # NOTE(review): the three YAMLs are dumped from the same mutated cfg, so
    # settings accumulate; e.g. the full-dataset run inherits
    # sampling_strategy='submodular'.  Confirm this is intended.

    # YAML for Random Run
    cfg['gpu_ids'] = str(gpus[1])
    cfg['run_label'] = 'random_' + current_run_label
    cfg['sampling_strategy'] = 'random'
    with open(os.path.join(output_dir, 'random.yml'), 'w') as outfile:
        yaml.dump(dict(cfg), outfile, default_flow_style=False)

    # YAML for SubModular Run
    cfg['gpu_ids'] = str(gpus[0])
    cfg['run_label'] = 'submodular_' + current_run_label
    cfg['sampling_strategy'] = 'submodular'
    with open(os.path.join(output_dir, 'submodular.yml'), 'w') as outfile:
        yaml.dump(dict(cfg), outfile, default_flow_style=False)

    # YAML for Full-Dataset Run
    cfg['gpu_ids'] = str(gpus[2])
    cfg['run_label'] = 'full_dataset_' + current_run_label
    cfg['use_all_exemplars'] = True
    with open(os.path.join(output_dir, 'full_dataset.yml'), 'w') as outfile:
        yaml.dump(dict(cfg), outfile, default_flow_style=False)

    run_yamls = ['./run_sandbox/full_dataset.yml',
                 './run_sandbox/random.yml',
                 './run_sandbox/submodular.yml']
    # The children's stdout used to go to PIPEs that were drained one process
    # at a time by communicate(), so the 2nd/3rd run blocked as soon as its
    # pipe buffer filled while the 1st was still going.  That output was
    # discarded anyway, so send it straight to the null device.
    with open(os.devnull, 'w') as devnull:
        processes = [
            subprocess.Popen([python_bin, 'smile.py', '--cfg', run_yaml],
                             stdout=devnull)
            for run_yaml in run_yamls
        ]

        print('All processes started normally.')

        for process in processes:
            process.wait()

    print('Finishing Launcher.')
コード例 #15
0
    dataset = KittiDataset('/raid/meng/Dataset/Kitti/object',
                           split=split,
                           noise='label_noise')
else:
    dataset = KittiDataset('/raid/meng/Dataset/Kitti/object', split=split)
# Single-sample, in-order loader over the dataset selected above.
data_loader = DataLoader(dataset,
                         batch_size=1,
                         shuffle=False,
                         pin_memory=True,
                         num_workers=1)
# Output directory for the generated box dataset.  os.mkdir only creates the
# last path component, so its parent must already exist.
save_dir = '/raid/meng/Dataset/Kitti/object/training/boxes_dataset'
if not os.path.exists(save_dir):
    os.mkdir(save_dir)
# NOTE(review): the checkpoint comes from the PointRCNN4_weak tree while the
# config below comes from PointRCNN1.1_weak — confirm the two match.
ckpt_file = '/raid/meng/Pointcloud_Detection/PointRCNN4_weak/output/rpn/weaklyRPN0500/410_floss03_8000/ckpt/checkpoint_iter_07620.pth'
cfg_from_file(
    '/raid/meng/Pointcloud_Detection/PointRCNN1.1_weak/tools/cfgs/weaklyRPN.yaml'
)

# Config override plus script-level tunables.
# NOTE(review): the constants below are consumed outside this excerpt; their
# exact meaning and units are not visible here.
cfg.RPN.SCORE_THRESH = 0.1
PROP_DIST = 0.3
BACKGROUND_ADDING = False
BACK_THRESH = 0.3
COSINE_DISTANCE = False
COS_THRESH = 0.3

# Imported only after cfg has been loaded (presumably because the model reads
# cfg when it is built).
from lib.net.point_rcnn import PointRCNN
# Inference-mode model sized by the dataset's class count, moved to the GPU.
model = PointRCNN(num_classes=data_loader.dataset.num_class,
                  use_xyz=True,
                  mode='TEST')
model.cuda()
# No map_location: tensors are restored onto the devices they were saved from.
checkpoint = torch.load(ckpt_file)
コード例 #16
0
def main():
    """Set up devices, configuration and logging, then build the model and train.

    Relies on the module-level ``args`` namespace, parsed outside this
    function.  It reads ``cuda``, ``device_ids``, ``cfg_file`` and
    ``set_cfgs``; the whole namespace is also saved into the run folder.
    """
    # --- device selection ---------------------------------------------------
    if not args.cuda:
        cfg.TRAIN.DEVICE = torch.device('cpu')
    else:
        # Restrict the visible GPUs *before* the first CUDA query.  The CUDA
        # runtime reads CUDA_VISIBLE_DEVICES when it initialises, so exporting
        # it after torch.cuda.is_available() may have no effect.
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(ids) for ids in args.device_ids)
        # Explicit check rather than `assert`, which `python -O` strips.
        if not torch.cuda.is_available():
            raise RuntimeError("Not enough GPU")
        torch.backends.cudnn.benchmark = True
        cfg.CUDA = True
        cfg.TRAIN.DEVICE = torch.device('cuda:0')
        print("Let's use", torch.cuda.device_count(), "GPUs!")

    # --- configuration: YAML file, then --set pairs, then remaining CLI args --
    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)
    cfg_from_args(args)
    print_cfg(cfg)

    # --- run folder and logging ---------------------------------------------
    if not cfg.TRAIN.NO_SAVE:
        run_folder = create_folder_for_run(cfg.TRAIN.RUNS_FOLDER)
        # Log both to <run_folder>/<NAME>.log and to stdout.
        logging.basicConfig(level=logging.INFO,
                            format='%(message)s',
                            handlers=[
                                logging.FileHandler(os.path.join(run_folder, f'{cfg.TRAIN.NAME}.log')),
                                logging.StreamHandler(sys.stdout)
                            ])

        # Snapshot the effective config (as YAML text) and the raw args.
        with open(os.path.join(run_folder, 'config_and_args.pkl'), 'wb') as f:
            blob = {'cfg': yaml.dump(cfg), 'args': args}
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        with open(os.path.join(run_folder, 'args.txt'), 'w') as f:
            for item in vars(args):
                f.write(item + ":" + str(getattr(args, item)) + '\n')
        logging.info('×' * 40)

        shutil.copy(args.cfg_file, os.path.join(run_folder, cfg.TRAIN.NAME) + '_cfg')
        logging.info('save config and args in runs folder:\n %s' % run_folder)
    else:
        logging.basicConfig(level=logging.INFO)

    # --- data, model, training ---------------------------------------------
    loader = Random_DataLoader()

    model = get_Model()
    model = nn.DataParallel(model)
    model.to(cfg.TRAIN.DEVICE)

    logging.info(model)

    Trainer(loader, model)
コード例 #17
0
ファイル: main.py プロジェクト: zehuiw/text2shape
def modify_args(args):
    """Modify the default config based on the command line arguments.

    Applies the following, in order:

    1. every ``--cfg`` file;
    2. the numpy seed, skipped when randomizing (test runs always randomize);
    3. one override per CLI option that was supplied;
    4. test-time overrides;
    5. finally the ``set_cfgs`` key/value list, so it wins over everything
       above.

    Args:
        args: parsed command-line namespace; only read here, never modified.

    Returns:
        None.  All changes go through ``cfg_from_file`` / ``cfg_from_list``
        (presumably mutating the shared ``cfg``).
    """
    # modify default config if requested
    if args.cfg_files is not None:
        for cfg_file in args.cfg_files:
            cfg_from_file(cfg_file)
    randomize = args.randomize
    if args.test:  # Always randomize in test phase
        randomize = True
    if not randomize:
        np.random.seed(cfg.CONST.RNG_SEED)

    # NOTE: Unfortunately order matters here
    # Boolean flags are compared with `is True`, so only an explicit True
    # triggers their override; valued options use `is not None`.
    if args.lba_only is True:
        cfg_from_list(['LBA.COSINE_DIST', False])
    if args.metric_learning_only is True:
        cfg_from_list(['LBA.NO_LBA', True])
    if args.non_inverted_loss is True:
        cfg_from_list(['LBA.INVERTED_LOSS', False])
    if args.dataset is not None:
        cfg_from_list(['CONST.DATASET', args.dataset])
    if args.lba_mode is not None:
        cfg_from_list(['LBA.MODEL_TYPE', args.lba_mode])
    if args.lba_test_mode is not None:
        cfg_from_list(['LBA.TEST_MODE', args.lba_test_mode])
        # cfg_from_list(['LBA.N_CAPTIONS_PER_MODEL', 1])  # NOTE: Important!
    if args.shapenet_ct_classifier is True:
        cfg_from_list(
            ['CONST.SHAPENET_CT_CLASSIFIER', args.shapenet_ct_classifier])
    if args.visit_weight is not None:
        cfg_from_list(['LBA.VISIT_WEIGHT', args.visit_weight])
    if args.lba_unnormalize is True:
        cfg_from_list(['LBA.NORMALIZE', False])
    if args.improved_wgan is True:
        cfg_from_list(['CONST.IMPROVED_WGAN', args.improved_wgan])
    if args.synth_embedding is True:
        cfg_from_list(['CONST.SYNTH_EMBEDDING', args.synth_embedding])
    if args.all_tuples is True:
        cfg_from_list(['CONST.TEST_ALL_TUPLES', args.all_tuples])
    if args.reed_classifier is True:
        cfg_from_list(['CONST.REED_CLASSIFIER', args.reed_classifier])
    if args.noise_dist is not None:
        cfg_from_list(['GAN.NOISE_DIST', args.noise_dist])
    if args.uniform_max is not None:
        cfg_from_list(['GAN.NOISE_UNIF_ABS_MAX', args.uniform_max])
    if args.num_critic_steps is not None:
        cfg_from_list(['WGAN.NUM_CRITIC_STEPS', args.num_critic_steps])
    if args.intense_training_freq is not None:
        cfg_from_list(
            ['WGAN.INTENSE_TRAINING_FREQ', args.intense_training_freq])
    if args.match_loss_coeff is not None:
        cfg_from_list(['WGAN.MATCH_LOSS_COEFF', args.match_loss_coeff])
    if args.fake_match_loss_coeff is not None:
        cfg_from_list(
            ['WGAN.FAKE_MATCH_LOSS_COEFF', args.fake_match_loss_coeff])
    if args.fake_mismatch_loss_coeff is not None:
        cfg_from_list(
            ['WGAN.FAKE_MISMATCH_LOSS_COEFF', args.fake_mismatch_loss_coeff])
    if args.gp_weight is not None:
        cfg_from_list(['WGAN.GP_COEFF', args.gp_weight])
    if args.text2text_weight is not None:
        cfg_from_list(['WGAN.TEXT2TEXT_WEIGHT', args.text2text_weight])
    if args.shape2shape_weight is not None:
        cfg_from_list(['WGAN.SHAPE2SHAPE_WEIGHT', args.shape2shape_weight])
    if args.learning_rate is not None:
        cfg_from_list(['TRAIN.LEARNING_RATE', args.learning_rate])
    if args.critic_lr_multiplier is not None:
        cfg_from_list(
            ['GAN.D_LEARNING_RATE_MULTIPLIER', args.critic_lr_multiplier])
    if args.decay_steps is not None:
        cfg_from_list(['TRAIN.DECAY_STEPS', args.decay_steps])
    if args.queue_capacity is not None:
        cfg_from_list(['CONST.QUEUE_CAPACITY', args.queue_capacity])
    if args.n_minibatch_test is not None:
        cfg_from_list(['CONST.N_MINIBATCH_TEST', args.n_minibatch_test])
    if args.noise_size is not None:
        cfg_from_list(['GAN.NOISE_SIZE', args.noise_size])
    if args.batch_size is not None:
        cfg_from_list(['CONST.BATCH_SIZE', args.batch_size])
    if args.summary_freq is not None:
        cfg_from_list(['TRAIN.SUMMARY_FREQ', args.summary_freq])
    if args.num_epochs is not None:
        cfg_from_list(['TRAIN.NUM_EPOCHS', args.num_epochs])
    if args.model is not None:
        cfg_from_list(['NETWORK', args.model])
    if args.optimizer is not None:
        cfg_from_list(['TRAIN.OPTIMIZER', args.optimizer])
    if args.critic_optimizer is not None:
        cfg_from_list(['GAN.D_OPTIMIZER', args.critic_optimizer])
    if args.ckpt_path is not None:
        cfg_from_list(['DIR.CKPT_PATH', args.ckpt_path])
    if args.lba_ckpt_path is not None:
        cfg_from_list(['END2END.LBA_CKPT_PATH', args.lba_ckpt_path])
    if args.val_ckpt_path is not None:
        cfg_from_list(['DIR.VAL_CKPT_PATH', args.val_ckpt_path])
    if args.log_path is not None:
        cfg_from_list(['DIR.LOG_PATH', args.log_path])
    if args.augment_max is not None:
        cfg_from_list(['TRAIN.AUGMENT_MAX', args.augment_max])
    # Test-mode settings come after the generic overrides above, so they
    # replace any --augment_max / --batch_size given on the command line.
    if args.test:
        cfg_from_list(['TRAIN.AUGMENT_MAX', 0])
        cfg_from_list(['CONST.BATCH_SIZE', 1])
        cfg_from_list(['LBA.N_CAPTIONS_PER_MODEL', 1])  # NOTE: Important!
        cfg_from_list(['LBA.N_PRIMITIVE_SHAPES_PER_CATEGORY',
                       1])  # NOTE: Important!
    if args.test_npy:
        cfg_from_list(['CONST.BATCH_SIZE', 1])

    # To overwrite default variables, put the set_cfgs after all argument initializations
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)
コード例 #18
0
def main():
    """Train a detection network from command-line settings.

    Steps:

    * Applies the CLI data options and the config file / ``--set`` pairs.
    * Sizes the word embeddings and optionally fixes the random seeds.
    * Loads the training and (unflipped) validation roidbs.
    * Builds the requested backbone and resolves the pretrained weights.
    * Calls ``train_net``.

    Raises:
        NotImplementedError: ``args.net`` is not a supported backbone.
        ValueError: ``args.weights`` names a directory without a checkpoint.
    """
    args = parse_args()
    # NOTE(review): these are written before cfg_from_file/cfg_from_list run,
    # so a config file that also sets DATA_DIR / CONTEXT_FUSION takes
    # precedence over the CLI values — confirm this is intended.
    cfg.DATA_DIR = args.data_dir
    cfg.CONTEXT_FUSION = args.context_fusion

    print('------ called with args: -------')
    pprint.pprint(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    # Embedding size follows GloVe when initialised from it and told to keep
    # that size; otherwise it comes from the command line.
    if cfg.INIT_BY_GLOVE and cfg.KEEP_AS_GLOVE_DIM:
        cfg.EMBED_DIM = cfg.GLOVE_DIM
    else:
        cfg.EMBED_DIM = args.embed_dim

    print("Using config:")
    pprint.pprint(cfg)

    if not args.randomize:
        # fix the random seeds (numpy and tensorflow) for reproducibility
        np.random.seed(cfg.RNG_SEED)
        tf.set_random_seed(cfg.RNG_SEED)

    imdb, roidb = GetRoidb(args.imdb_name)

    output_dir = get_output_dir(imdb, args.tag)
    print("output will be saved to `{:s}`".format(output_dir))

    # tensorboard directory where the summaries are saved during training
    tb_dir = get_output_tb_dir(imdb, args.tag)
    print('TensorFlow summaries will be saved to `{:s}`'.format(tb_dir))

    # also add validation set, but with no flipping image; restore the flag
    # even if loading the validation roidb fails.
    orgflip = cfg.TRAIN.USE_FLIPPED
    cfg.TRAIN.USE_FLIPPED = False
    try:
        _, valroidb = GetRoidb(args.imdbval_name)
    finally:
        cfg.TRAIN.USE_FLIPPED = orgflip

    # load network
    if args.net == 'vgg16':
        net = vgg16()
    elif args.net == 'res50':
        net = resnetv1(num_layers=50)
    elif args.net == 'res101':
        net = resnetv1(num_layers=101)
    elif args.net == 'res152':
        net = resnetv1(num_layers=152)
    else:
        raise NotImplementedError

    if args.weights and not args.weights.endswith('.ckpt'):
        # args.weights names a checkpoint directory: resolve its latest file.
        # get_checkpoint_state returns None when no checkpoint is found.
        # Other failures (e.g. unreadable metadata) are reported the same way,
        # but no longer swallow KeyboardInterrupt/SystemExit as a bare
        # `except:` did.
        try:
            ckpt = tf.train.get_checkpoint_state(args.weights)
        except Exception:
            ckpt = None
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise ValueError("NO checkpoint found in {}".format(args.weights))
        pretrained_model = ckpt.model_checkpoint_path
    else:
        pretrained_model = args.weights

    # TODO: "imdb" may not be useful during training
    train_net(net, imdb, roidb, valroidb, output_dir, tb_dir,
              pretrained_model=pretrained_model,
              max_iters=args.max_iters)
コード例 #19
0
def main():
    """Train a detection network, optionally with RAM-limited roidb loading.

    Loads the config and optionally fixes the numpy/TensorFlow seeds.  Both
    roidbs go through ``combined_roidb``, or through ``get_roidb_limit_ram``
    when ``cfg.LIMIT_RAM`` is set; the validation set never uses flipped
    images.  Finally the requested backbone is built and handed to
    ``train_net``.
    """
    args = parse_args()

    print('------ called with args: -------')
    pprint.pprint(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print("runing with LIMIT_RAM: {}".format(cfg.LIMIT_RAM))

    print("Using config:")
    pprint.pprint(cfg)

    if not args.randomize:
        # Seed numpy and TensorFlow so runs are reproducible.
        np.random.seed(cfg.RNG_SEED)
        tf.set_random_seed(cfg.RNG_SEED)

    def load_roidb(imdb_name):
        # cfg.LIMIT_RAM is checked at call time, once per roidb.
        if cfg.LIMIT_RAM:
            return get_roidb_limit_ram(imdb_name)
        return combined_roidb(imdb_name)

    imdb, roidb = load_roidb(args.imdb_name)

    output_dir = get_output_dir(imdb, args.tag)
    print("output will be saved to `{:s}`".format(output_dir))

    # Where TensorBoard summaries are written during training.
    tb_dir = get_output_tb_dir(imdb, args.tag)
    print('TensorFlow summaries will be saved to `{:s}`'.format(tb_dir))

    # Validation roidb is built without flipped images; the flag is restored.
    saved_flip = cfg.TRAIN.USE_FLIPPED
    cfg.TRAIN.USE_FLIPPED = False
    _, valroidb = load_roidb(args.imdbval_name)
    cfg.TRAIN.USE_FLIPPED = saved_flip

    # Lazily-constructed backbones keyed by their --net name.
    backbone_factories = {
        'vgg16': lambda: vgg16(),
        'res50': lambda: resnetv1(num_layers=50),
        'res101': lambda: resnetv1(num_layers=101),
        'res152': lambda: resnetv1(num_layers=152),
    }
    if args.net not in backbone_factories:
        raise NotImplementedError
    net = backbone_factories[args.net]()

    # TODO: "imdb" may not be useful during training
    train_net(net, imdb, roidb, valroidb, output_dir, tb_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters)