def main():
    """
    Main body of script.
    """
    args = parse_args()
    args.seed = init_rand(seed=args.seed)

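    # Set up file logging in the save directory and record the script arguments and package versions.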
    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

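    # Prepare the Keras context for the requested number of GPUs and get the batch size to use.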
    batch_size = prepare_ke_context(num_gpus=args.num_gpus,
                                    batch_size=args.batch_size)

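    # Build the network (optionally from pretrained weights or a resume checkpoint). Class count and
    # input resolution are read from the model when it exposes them; otherwise 1000 classes and the
    # requested input size are assumed.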
    net = prepare_model(model_name=args.model,
                        use_pretrained=args.use_pretrained,
                        pretrained_model_file_path=args.resume.strip())
    num_classes = net.classes if hasattr(net, "classes") else 1000
    input_image_size = net.in_size if hasattr(net, "in_size") else (args.input_size, args.input_size)

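    # Create the train/validation record iterators and wrap them as generators yielding batches for Keras.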
    train_data, val_data = get_data_rec(
        rec_train=args.rec_train,
        rec_train_idx=args.rec_train_idx,
        rec_val=args.rec_val,
        rec_val_idx=args.rec_val_idx,
        batch_size=batch_size,
        num_workers=args.num_workers,
        input_image_size=input_image_size,
        resize_inv_factor=args.resize_inv_factor)
    train_gen = get_data_generator(data_iterator=train_data,
                                   num_classes=num_classes)
    val_gen = get_data_generator(data_iterator=val_data,
                                 num_classes=num_classes)

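    # Attach the optimizer to the network, restoring optimizer state from args.resume_state when given.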
    net = prepare_trainer(net=net,
                          optimizer_name=args.optimizer_name,
                          momentum=args.momentum,
                          lr=args.lr,
                          num_gpus=args.num_gpus,
                          state_file_path=args.resume_state)

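    # Run the training loop; checkpoints are written as imagenet_<model>.h5 under the save directory.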
    train_net(net=net,
              train_gen=train_gen,
              val_gen=val_gen,
              train_num_examples=1281167,
              val_num_examples=50048,
              num_epochs=args.num_epochs,
              checkpoint_filepath=os.path.join(
                  args.save_dir, "imagenet_{}.h5".format(args.model)),
              start_epoch1=args.start_epoch)
Example #2
def main():
    """
    Main body of script.
    """
    args = parse_args()

    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    batch_size = prepare_ke_context(
        num_gpus=args.num_gpus,
        batch_size=args.batch_size)

    net = prepare_model(
        model_name=args.model,
        use_pretrained=args.use_pretrained,
        pretrained_model_file_path=args.resume.strip())
    num_classes = net.classes if hasattr(net, "classes") else 1000
    input_image_size = net.in_size if hasattr(net, "in_size") else (args.input_size, args.input_size)

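    # only_val=True builds just the validation pipeline; the returned train_data is not used here.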
    train_data, val_data = get_data_rec(
        rec_train=args.rec_train,
        rec_train_idx=args.rec_train_idx,
        rec_val=args.rec_val,
        rec_val_idx=args.rec_val_idx,
        batch_size=batch_size,
        num_workers=args.num_workers,
        input_image_size=input_image_size,
        resize_inv_factor=args.resize_inv_factor,
        only_val=True)
    val_gen = get_data_generator(
        data_iterator=val_data,
        num_classes=num_classes)

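    # Evaluation requires weights: either a pretrained model or a resume checkpoint must be supplied.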
    val_size = 50000
    assert (args.use_pretrained or args.resume.strip())
    test(
        net=net,
        val_gen=val_gen,
        val_size=val_size,
        batch_size=batch_size,
        num_gpus=args.num_gpus,
        calc_weight_count=True,
        extended_log=True)
Example #3
def main():
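    """
    Main body of script.
    """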
    args = parse_args()

    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    batch_size = prepare_ke_context(num_gpus=args.num_gpus,
                                    batch_size=args.batch_size)

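    # This variant fixes the class count up front and passes it to the model constructor explicitly.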
    num_classes = 1000
    net = prepare_model(model_name=args.model,
                        classes=num_classes,
                        use_pretrained=args.use_pretrained,
                        pretrained_model_file_path=args.resume.strip())

    train_data, val_data = get_data_rec(rec_train=args.rec_train,
                                        rec_train_idx=args.rec_train_idx,
                                        rec_val=args.rec_val,
                                        rec_val_idx=args.rec_val_idx,
                                        batch_size=batch_size,
                                        num_workers=args.num_workers)
    val_gen = get_data_generator(data_iterator=val_data,
                                 num_classes=num_classes)

    val_size = 50000
    assert (args.use_pretrained or args.resume.strip())
    test(net=net,
         val_gen=val_gen,
         val_size=val_size,
         batch_size=batch_size,
         num_gpus=args.num_gpus,
         calc_weight_count=True,
         extended_log=True)
Example #4
def main():
    args = parse_args()
    args.seed = init_rand(seed=args.seed)

    _, log_file_exist = initialize_logging(
        logging_dir_path=args.save_dir,
        logging_file_name=args.logging_file_name,
        script_args=args,
        log_packages=args.log_packages,
        log_pip_packages=args.log_pip_packages)

    batch_size = prepare_ke_context(num_gpus=args.num_gpus,
                                    batch_size=args.batch_size)

    net = prepare_model(model_name=args.model,
                        use_pretrained=args.use_pretrained,
                        pretrained_model_file_path=args.resume.strip())
    num_classes = net.classes if hasattr(net, 'classes') else 1000
    input_image_size = net.in_size if hasattr(net, 'in_size') else (args.input_size, args.input_size)

    train_data, val_data = get_data_rec(
        rec_train=args.rec_train,
        rec_train_idx=args.rec_train_idx,
        rec_val=args.rec_val,
        rec_val_idx=args.rec_val_idx,
        batch_size=batch_size,
        num_workers=args.num_workers,
        input_image_size=input_image_size,
        resize_inv_factor=args.resize_inv_factor)
    train_gen = get_data_generator(data_iterator=train_data,
                                   num_classes=num_classes)
    val_gen = get_data_generator(data_iterator=val_data,
                                 num_classes=num_classes)

    net = prepare_trainer(net=net,
                          optimizer_name=args.optimizer_name,
                          momentum=args.momentum,
                          lr=args.lr,
                          num_gpus=args.num_gpus,
                          state_file_path=args.resume_state)

    # if args.save_dir and args.save_interval:
    #     lp_saver = TrainLogParamSaver(
    #         checkpoint_file_name_prefix='imagenet_{}'.format(args.model),
    #         last_checkpoint_file_name_suffix="last",
    #         best_checkpoint_file_name_suffix=None,
    #         last_checkpoint_dir_path=args.save_dir,
    #         best_checkpoint_dir_path=None,
    #         last_checkpoint_file_count=2,
    #         best_checkpoint_file_count=2,
    #         checkpoint_file_save_callback=save_params,
    #         checkpoint_file_exts=('.h5', '.h5states'),
    #         save_interval=args.save_interval,
    #         num_epochs=args.num_epochs,
    #         param_names=['Val.Top1', 'Train.Top1', 'Val.Top5', 'Train.Loss', 'LR'],
    #         acc_ind=2,
    #         # bigger=[True],
    #         # mask=None,
    #         score_log_file_path=os.path.join(args.save_dir, 'score.log'),
    #         score_log_attempt_value=args.attempt,
    #         best_map_log_file_path=os.path.join(args.save_dir, 'best_map.log'))
    # else:
    #     lp_saver = None

    train_net(net=net,
              train_gen=train_gen,
              val_gen=val_gen,
              train_num_examples=1281167,
              val_num_examples=50048,
              num_epochs=args.num_epochs,
              checkpoint_filepath=os.path.join(
                  args.save_dir, 'imagenet_{}.h5'.format(args.model)),
              start_epoch1=args.start_epoch)
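
All of the examples above assume a parse_args() helper plus project-level utilities (initialize_logging, prepare_ke_context, prepare_model, get_data_rec, get_data_generator, prepare_trainer, train_net, test) defined elsewhere in the same scripts. The sketch below shows what such a parse_args() could look like, reconstructed only from the args.* attributes the snippets read; the flag names, defaults, and help strings are illustrative assumptions, not the original argument list.

import argparse


def parse_args():
    # Illustrative sketch only: flags and defaults are inferred from the attributes used in the
    # examples above (argparse maps "--foo-bar" to args.foo_bar) and may differ from the real script.
    parser = argparse.ArgumentParser(description="Train/evaluate an ImageNet classifier (Keras)")
    # model / weights
    parser.add_argument("--model", type=str, required=True, help="model name")
    parser.add_argument("--use-pretrained", action="store_true", help="load pretrained weights")
    parser.add_argument("--resume", type=str, default="", help="model checkpoint to resume from")
    parser.add_argument("--resume-state", type=str, default="", help="saved optimizer state")
    parser.add_argument("--input-size", type=int, default=224, help="input image size")
    # data
    parser.add_argument("--rec-train", type=str, default="", help="training record file")
    parser.add_argument("--rec-train-idx", type=str, default="", help="training record index file")
    parser.add_argument("--rec-val", type=str, default="", help="validation record file")
    parser.add_argument("--rec-val-idx", type=str, default="", help="validation record index file")
    parser.add_argument("--num-workers", type=int, default=4, help="data loading workers")
    parser.add_argument("--resize-inv-factor", type=float, default=0.875, help="inverted resize ratio")
    # training
    parser.add_argument("--num-gpus", type=int, default=0, help="number of GPUs to use")
    parser.add_argument("--batch-size", type=int, default=256, help="total batch size")
    parser.add_argument("--optimizer-name", type=str, default="nag", help="optimizer name")
    parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
    parser.add_argument("--momentum", type=float, default=0.9, help="optimizer momentum")
    parser.add_argument("--num-epochs", type=int, default=120, help="number of training epochs")
    parser.add_argument("--start-epoch", type=int, default=1, help="starting epoch when resuming")
    parser.add_argument("--seed", type=int, default=-1, help="random seed (-1 means generate one)")
    # logging / saving
    parser.add_argument("--save-dir", type=str, default="", help="directory for checkpoints and logs")
    parser.add_argument("--logging-file-name", type=str, default="train.log", help="log file name")
    parser.add_argument("--log-packages", type=str, default="keras", help="packages to log versions of")
    parser.add_argument("--log-pip-packages", type=str, default="", help="pip packages to log versions of")
    return parser.parse_args()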