コード例 #1
0
ファイル: agents.py プロジェクト: keimikami/dqn
    def __init__(self, num_actions):
        """Set up the agent: replay memory, epsilon schedule, and Q-networks."""
        self.num_actions = num_actions
        self.step = 0
        self.replay_buffer = ReplayBuffer(NUM_REPLAY_MEMORY)

        # Epsilon is annealed linearly from INITIAL to FINAL over
        # EXPLORATION_STEPS steps; this is the per-step decrement.
        self.epsilon = INITIAL_EPSILON
        self.epsilon_step = (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORATION_STEPS

        # Online network plus a structurally identical target network.
        net_args = (FRAME_WIDTH, FRAME_HEIGHT, 1, self.num_actions)
        self.model = build_network(*net_args)
        self.target_network = build_network(*net_args)
コード例 #2
0
def get_model(args,
              num_classes,
              test=False,
              channel_last=False,
              mixup=None,
              channels=4,
              spatial_size=224,
              label_smoothing=0,
              ctx_for_loss=None):
    """
    Create computation graph and variables.

    Builds the network graph, optionally applying mixup to the inputs, and
    returns a namedtuple bundling the input/label variables, prediction,
    loss, top-1 error and hidden activations.
    """
    from models import build_network
    from utils.loss import softmax_cross_entropy_with_label_smoothing

    # Normalize spatial_size to an (H, W) pair.
    if hasattr(spatial_size, '__len__'):
        assert len(spatial_size) == 2, \
            f'Spatial size must be a scalar or a tuple of two ints. Given {spatial_size}'
        spatial_shape = tuple(spatial_size)
    else:
        spatial_shape = (spatial_size, spatial_size)

    # Input layout depends on channel_last (NHWC vs NCHW).
    if channel_last:
        input_shape = (args.batch_size,) + spatial_shape + (channels,)
    else:
        input_shape = (args.batch_size, channels) + spatial_shape
    image = nn.Variable(input_shape)
    label = nn.Variable([args.batch_size, 1])

    # Keep handles to the raw (pre-mixup) variables for feeding and error.
    in_image, in_label = image, label
    if mixup is not None:
        image, label = mixup.mix_data(image, label)

    pred, hidden = build_network(image,
                                 num_classes,
                                 args.arch,
                                 test=test,
                                 channel_last=channel_last)
    pred.persistent = True

    def _loss_and_error(pred, in_label, label, label_smoothing):
        # Loss is computed against the (possibly mixed) label; error against
        # the original label.
        loss = F.mean(
            softmax_cross_entropy_with_label_smoothing(pred, label,
                                                       label_smoothing))
        error = F.sum(F.top_n_error(pred, in_label, n=1))
        return loss, error

    # Use specified context if possible.
    # We use it when we pass float32 context to avoid nan issue
    if ctx_for_loss is not None:
        with nn.context_scope(ctx_for_loss):
            loss, error = _loss_and_error(pred, in_label, label,
                                          label_smoothing)
    else:
        loss, error = _loss_and_error(pred, in_label, label, label_smoothing)

    Model = namedtuple('Model',
                       ['image', 'label', 'pred', 'loss', 'error', 'hidden'])
    return Model(in_image, in_label, pred, loss, error, hidden)
コード例 #3
0
 def build_network(self):
     """Construct the detection network from this object's configuration."""
     ######################
     # BUILD TARGET ASSIGNER
     ######################
     # Bird's-eye-view extent (x/y min and max) of the point-cloud range.
     bv_range = self.voxel_generator.point_cloud_range[[0, 1, 3, 4]]
     box_coder = build_box_coder(self.config.BOX_CODER)
     target_assigner = build_target_assigner(self.config.TARGET_ASSIGNER,
                                             bv_range, box_coder)
     ######################
     # BUILD NET
     ######################
     # Force Xavier initialization before constructing the net.
     self.model_cfg.XAVIER = True
     return build_network(self.model_cfg, self.voxel_generator,
                          target_assigner)
コード例 #4
0
def main():
    """Classify a single input image and log its top-5 predictions."""
    args = get_args()

    # Setup: pick the computation context (defaults to cudnn).
    from nnabla.ext_utils import get_extension_context
    extension_module = args.context if args.context is not None else "cudnn"  # TODO: Hard coded!!!
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Load parameters and deduce layout/channel configuration from them.
    channel_last, channels = load_parameters_and_config(
        args.weights, args.type_config)
    logger.info('Parameter configuration is deduced as:')
    logger.info(f'* channel_last={channel_last}')
    logger.info(f'* channels={channels}')

    # Read and preprocess the input image, then wrap it as an NdArray.
    image = read_image_with_preprocess(args.input_image,
                                       args.norm_config,
                                       channel_last=channel_last,
                                       channels=channels,
                                       spatial_size=args.spatial_size)
    img = nn.NdArray.from_numpy_array(image)

    # Forward pass in test mode and extraction of the top-5 class indices.
    from models import build_network
    pred, _ = build_network(img,
                            args.num_classes,
                            args.arch,
                            test=True,
                            channel_last=channel_last)
    prob = F.softmax(pred)
    top5_index = F.sort(prob, reverse=True, only_index=True)[:, :5]

    # Report the result with human-readable labels.
    labels = read_labels(args.labels)
    logger.info(f'Top-5 prediction:')
    for i in top5_index.data[0]:
        idx = int(i)
        logger.info(f'* {idx} {labels[idx]}: {prob.data[0, idx] * 100:.2f}')
コード例 #5
0
ファイル: train.py プロジェクト: zxsted/async-rl-noreward
    def __init__(self, num_action, n_step=8, discount=0.99):
        """Build the value/policy/ICM networks and the n-step rollout state."""
        self.num_action = num_action
        self.n_step = n_step
        self.discount = discount

        # Networks share one architecture; the fourth output is unused here.
        # NOTE: observation_shape and screen are module-level globals.
        self.value_net, self.policy_net, self.load_net, _ = build_network(
            observation_shape, num_action)
        self.icm = build_icm_model(screen, (num_action,))

        self.value_net.compile(optimizer='rmsprop', loss='mse')
        self.policy_net.compile(optimizer='rmsprop', loss='mse')
        # dummy loss
        self.load_net.compile(optimizer='rmsprop', loss='mse',
                              loss_weights=[0.5, 1.])
        # The ICM model outputs its own loss, so the Keras loss just passes
        # the prediction through.
        self.icm.compile(optimizer="rmsprop",
                         loss=lambda y_true, y_pred: y_pred)

        # Current and previous observation buffers.
        self.observations = np.zeros(observation_shape)
        self.last_observations = np.zeros_like(self.observations)

        # Rolling window of the last n_step transitions.
        self.n_step_data = deque(maxlen=n_step)
コード例 #6
0
ファイル: train.py プロジェクト: zxsted/async-rl-noreward
    def __init__(self, action_space, batch_size=32, swap_freq=200):
        """Build the trainable net + ICM and the training bookkeeping buffers."""
        from keras.optimizers import RMSprop

        n_actions = action_space.num_discrete_space
        # NOTE: observation_shape, screen and args are module-level globals.
        _, _, self.train_net, advantage = build_network(observation_shape,
                                                        n_actions)
        self.icm = build_icm_model(screen, (n_actions,))

        self.train_net.compile(
            optimizer=RMSprop(epsilon=0.1, rho=0.99),
            loss=[value_loss(), policy_loss(advantage, args.beta)])
        # The ICM model outputs its own loss; the Keras loss passes it through.
        self.icm.compile(optimizer="rmsprop",
                         loss=lambda y_true, y_pred: y_pred)

        # Rolling windows of recent losses/values for logging.
        self.pol_loss = deque(maxlen=25)
        self.val_loss = deque(maxlen=25)
        self.values = deque(maxlen=25)

        # Swap weights every swap_freq updates.
        self.swap_freq = swap_freq
        self.swap_counter = self.swap_freq

        self.batch_size = batch_size
        self.unroll = np.arange(self.batch_size)
        self.targets = np.zeros((self.batch_size, n_actions))
        self.counter = 0
コード例 #7
0
y_test = to_categorical(y_test, num_classes)

# Optional input standardization layer factory; None means no normalization.
normalization = None
if args.normalize:
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    std[std == 0] = 1.0  # Do not modify points where variance is null

    def normalization(xi):
        # Wrap (x - mean) / std in a Lambda layer; mean/std are passed via
        # `arguments` so they are baked into the layer config.
        return Lambda(lambda image, mu, std: (image - mu) / std,
                      arguments={'mu': mean, 'std': std})(xi)

# We build the requested model
model = models.build_network(args.model, input_shape, num_classes,
                             normalization, args.dropout, args.L2)
model.summary()

# Callbacks


def generate_unique_logpath(logdir, raw_run_name):
    """Return the first '<logdir>/<raw_run_name>-<i>' that is not yet a directory."""
    i = 0
    while True:
        candidate = os.path.join(logdir, f"{raw_run_name}-{i}")
        if not os.path.isdir(candidate):
            return candidate
        i += 1

コード例 #8
0
ファイル: main.py プロジェクト: npalff/deeplearning-lectures
def train(args):
    """
    Training of the algorithm.

    Builds the model requested by ``args.model``, trains it for 50 epochs
    with TensorBoard logging and best-model checkpointing in a fresh log
    directory, then reloads the best checkpoint and reports test metrics.
    """

    # --model is mandatory: bail out early with a usage hint.
    if (not args.model):
        print("--model is required for training. Call with -h for help")
        sys.exit(-1)

    train_data, val_data, test_data, normalization, input_shape, num_classes = data.get_data(
        args.normalize, args.data_augment)

    # We build the requested model
    model = models.build_network(args.model, input_shape, num_classes,
                                 normalization, args.dropout, args.L2)
    model.summary()

    # Callbacks: TensorBoard logging into a unique per-run directory.
    logpath = generate_unique_logpath(args.logdir, args.model)
    tbcb = TensorBoard(log_dir=logpath)

    print("=" * 20)
    print("The logs will be saved in {}".format(logpath))
    print("=" * 20)

    # Keep only the best model (by validation metric) on disk.
    checkpoint_filepath = os.path.join(logpath, "best_model.h5")
    checkpoint_cb = ModelCheckpoint(checkpoint_filepath, save_best_only=True)

    # Write down the summary of the experiment
    summary_text = """
    ## Executed command

    {command}

    ## Args

    {args}

    ## Architecture

    """.format(command=" ".join(sys.argv), args=args)
    with open(os.path.join(logpath, "summary.txt"), 'w') as f:
        f.write(summary_text)

    # Mirror the summary into TensorBoard's text dashboard.
    writer = tf.summary.create_file_writer(os.path.join(logpath, 'summary'))
    with writer.as_default():
        tf.summary.text("Summary", summary_text, 0)

    # Compilation
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    # Training: with augmentation, train_data is a generator and needs
    # steps_per_epoch; otherwise it is an (X, y) pair fed with batch_size.
    if args.data_augment:
        ffit = functools.partial(model.fit,
                                 train_data,
                                 steps_per_epoch=50000 // 128)
    else:
        ffit = functools.partial(model.fit, *train_data, batch_size=128)

    ffit(epochs=50,
         verbose=1,
         validation_data=val_data,
         callbacks=[tbcb, checkpoint_cb])

    # Strip the saved optimizer state from the checkpoint — presumably to
    # avoid load_model issues/warnings when reloading. TODO confirm.
    with h5py.File(checkpoint_filepath, 'a') as f:
        if 'optimizer_weights' in f.keys():
            del f['optimizer_weights']

    # Evaluation of the best model
    model = load_model(checkpoint_filepath)
    score = model.evaluate(*test_data, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])
コード例 #9
0
def train(cfg_file=None,
          model_dir=None,
          result_path=None,
          create_folder=False,
          display_step=50,
          summary_step=5,
          pickle_result=False):
    """Train a detection network described by ``cfg_file``.

    Alternates training (``STEPS_PER_EVAL`` steps at a time) with a full
    evaluation pass over the eval dataset, checkpointing along the way.
    Restores the latest checkpoint from ``model_dir`` if one exists.

    NOTE(review): ``create_folder`` and ``summary_step`` are accepted but
    never used in this body.
    """
    model_dir = Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    # Evaluation-time checkpoints are kept separately (and kept forever).
    eval_checkpoint_dir = model_dir / 'eval_checkpoints'
    eval_checkpoint_dir.mkdir(parents=True, exist_ok=True)
    if result_path is None:
        result_path = model_dir / 'results'
    # Keep a backup of the config next to the checkpoints.
    config_file_bkp = "pipeline.config"
    shutil.copyfile(cfg_file, str(model_dir / config_file_bkp))

    config = cfg_from_yaml_file(cfg_file, cfg)
    input_cfg = config.TRAIN_INPUT_READER
    eval_input_cfg = config.EVAL_INPUT_READER
    model_cfg = config.MODEL
    train_cfg = config.TRAIN_CONFIG
    class_names = config.CLASS_NAMES
    ######################
    # BUILD VOXEL GENERATOR
    ######################
    voxel_generator = core.build_voxel_generator(config.VOXEL_GENERATOR)
    ######################
    # BUILD TARGET ASSIGNER
    ######################
    # Bird's-eye-view x/y extent of the point-cloud range.
    bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]]
    box_coder = core.build_box_coder(config.BOX_CODER)
    target_assigner_cfg = config.TARGET_ASSIGNER
    target_assigner = core.build_target_assigner(target_assigner_cfg, bv_range,
                                                 box_coder)
    ######################
    # BUILD NET
    ######################
    center_limit_range = model_cfg.POST_PROCESSING.post_center_limit_range
    net = models.build_network(model_cfg, voxel_generator, target_assigner)
    net.cuda()
    # net_train = torch.nn.DataParallel(net).cuda()
    print("num_trainable parameters:", len(list(net.parameters())))
    # for n, p in net.named_parameters():
    #     print(n, p.shape)

    ######################
    # BUILD OPTIMIZER
    ######################
    # we need global_step to create lr_scheduler, so restore net first.
    libs.tools.try_restore_latest_checkpoints(model_dir, [net])
    gstep = net.get_global_step() - 1
    optimizer_cfg = train_cfg.OPTIMIZER
    if train_cfg.ENABLE_MIXED_PRECISION:
        # fp16 weights; metrics and norm layers stay in fp32 for stability.
        net.half()
        net.metrics_to_float()
        net.convert_norm_to_float(net)
    optimizer = core.build_optimizer(optimizer_cfg, net.parameters())
    if train_cfg.ENABLE_MIXED_PRECISION:
        loss_scale = train_cfg.LOSS_SCALE_FACTOR
        mixed_optimizer = libs.tools.MixedPrecisionWrapper(
            optimizer, loss_scale)
    else:
        mixed_optimizer = optimizer

    # must restore optimizer AFTER using MixedPrecisionWrapper
    libs.tools.try_restore_latest_checkpoints(model_dir, [mixed_optimizer])
    lr_scheduler = core.build_lr_schedules(optimizer_cfg, optimizer, gstep)
    if train_cfg.ENABLE_MIXED_PRECISION:
        float_dtype = torch.float16
    else:
        float_dtype = torch.float32
    ######################
    # PREPARE INPUT
    ######################
    # NOTE(review): eval_dataset is built from input_cfg rather than
    # eval_input_cfg (only the eval dataloader below uses eval_input_cfg) —
    # confirm this is intended.
    dataset = core.build_input_reader(input_cfg,
                                      model_cfg,
                                      training=True,
                                      voxel_generator=voxel_generator,
                                      target_assigner=target_assigner)
    eval_dataset = core.build_input_reader(input_cfg,
                                           model_cfg,
                                           training=False,
                                           voxel_generator=voxel_generator,
                                           target_assigner=target_assigner)

    def _worker_init_fn(worker_id):
        # Give each DataLoader worker a distinct, time-derived numpy seed.
        time_seed = np.array(time.time(), dtype=np.int32)
        np.random.seed(time_seed + worker_id)
        print(f"WORKER {worker_id} seed:", np.random.get_state()[1][0])

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=input_cfg.BATCH_SIZE,
                                             shuffle=True,
                                             num_workers=input_cfg.NUM_WORKERS,
                                             pin_memory=False,
                                             collate_fn=merge_second_batch,
                                             worker_init_fn=_worker_init_fn)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_input_cfg.BATCH_SIZE,
        shuffle=False,
        num_workers=eval_input_cfg.NUM_WORKERS,
        pin_memory=False,
        collate_fn=merge_second_batch)
    data_iter = iter(dataloader)

    ######################
    # TRAINING
    ######################
    log_path = model_dir / 'log.txt'
    logf = open(log_path, 'a')
    # logf.write(proto_str)
    logf.write("\n")

    total_step_elapsed = 0
    # Resume-aware: only the steps remaining after restore are scheduled.
    remain_steps = train_cfg.STEPS - net.get_global_step()
    t = time.time()
    ckpt_start_time = t
    #total_loop = train_cfg.STEPS // train_cfg.STEPS_PER_EVAL + 1
    total_loop = remain_steps // train_cfg.STEPS_PER_EVAL + 1
    clear_metrics_every_epoch = train_cfg.CLEAR_METRICS_EVERY_EPOCH

    # Drop the extra partial loop when STEPS divides evenly.
    if train_cfg.STEPS % train_cfg.STEPS_PER_EVAL == 0:
        total_loop -= 1
    mixed_optimizer.zero_grad()
    try:

        for _ in range(total_loop):
            # Last loop may run fewer than STEPS_PER_EVAL steps.
            if total_step_elapsed + train_cfg.STEPS_PER_EVAL > train_cfg.STEPS:
                steps = train_cfg.STEPS % train_cfg.STEPS_PER_EVAL
            else:
                steps = train_cfg.STEPS_PER_EVAL

            for step in range(steps):

                # Scheduler is stepped once per iteration, before the
                # optimizer update.
                lr_scheduler.step()
                try:
                    example = next(data_iter)
                except StopIteration:
                    # Dataset exhausted: start a new epoch (optionally
                    # clearing running metrics) and fetch the first batch.
                    print("end epoch")
                    if clear_metrics_every_epoch:
                        net.clear_metrics()
                    data_iter = iter(dataloader)
                    example = next(data_iter)
                example_torch = example_convert_to_torch(example,
                                                         float_dtype,
                                                         device="cuda:0")
                batch_size = example["anchors"].shape[0]

                ret_dict = net(example_torch)

                # Unpack the loss components returned by the network.
                # box_preds = ret_dict["box_preds"]
                cls_preds = ret_dict["cls_preds"]
                loss = ret_dict["loss"].mean()
                cls_loss_reduced = ret_dict["cls_loss_reduced"].mean()
                loc_loss_reduced = ret_dict["loc_loss_reduced"].mean()
                cls_pos_loss = ret_dict["cls_pos_loss"]
                cls_neg_loss = ret_dict["cls_neg_loss"]
                loc_loss = ret_dict["loc_loss"]
                cls_loss = ret_dict["cls_loss"]
                dir_loss_reduced = ret_dict["dir_loss_reduced"]
                cared = ret_dict["cared"]
                labels = example_torch["labels"]
                # Static loss scaling for fp16; the wrapper un-scales grads.
                if train_cfg.ENABLE_MIXED_PRECISION:
                    loss *= loss_scale

                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 10.0)
                mixed_optimizer.step()
                mixed_optimizer.zero_grad()
                net.update_global_step()
                net_metrics = net.update_metrics(cls_loss_reduced,
                                                 loc_loss_reduced, cls_preds,
                                                 labels, cared)
                step_time = (time.time() - t)
                t = time.time()
                metrics = {}
                # Positive/negative anchor counts for the first sample only.
                num_pos = int((labels > 0)[0].float().sum().cpu().numpy())
                num_neg = int((labels == 0)[0].float().sum().cpu().numpy())
                if 'anchors_mask' not in example_torch:
                    num_anchors = example_torch['anchors'].shape[1]
                else:
                    num_anchors = int(example_torch['anchors_mask'][0].sum())
                global_step = net.get_global_step()
                # Periodic console/file logging of a flattened metrics dict.
                if global_step % display_step == 0:
                    loc_loss_elem = [
                        float(loc_loss[:, :, i].sum().detach().cpu().numpy() /
                              batch_size) for i in range(loc_loss.shape[-1])
                    ]
                    metrics["step"] = global_step
                    metrics["steptime"] = step_time
                    metrics.update(net_metrics)
                    metrics["loss"] = {}
                    metrics["loss"]["loc_elem"] = loc_loss_elem
                    metrics["loss"]["cls_pos_rt"] = float(
                        cls_pos_loss.detach().cpu().numpy())
                    metrics["loss"]["cls_neg_rt"] = float(
                        cls_neg_loss.detach().cpu().numpy())

                    if model_cfg.BACKBONE.use_direction_classifier:
                        metrics["loss"]["dir_rt"] = float(
                            dir_loss_reduced.detach().cpu().numpy())
                    metrics["num_vox"] = int(example_torch["voxels"].shape[0])
                    metrics["num_pos"] = int(num_pos)
                    metrics["num_neg"] = int(num_neg)
                    metrics["num_anchors"] = int(num_anchors)
                    metrics["lr"] = float(
                        mixed_optimizer.param_groups[0]['lr'])

                    metrics["image_idx"] = example['image_idx'][0]
                    flatted_metrics = flat_nested_json_dict(metrics)
                    flatted_summarys = flat_nested_json_dict(metrics, "/")
                    # for k,v in flatted_summarys.items():
                    #     if isinstance(v,(list,tuple)):
                    #         v = {str(i): e for i,e in enumerate(v)}
                    #         writer.add_scalars(k,v,global_step)
                    #     else:
                    #         writer.add_scalars(k,v,global_step)
                    metrics_str_list = []
                    for k, v in flatted_metrics.items():
                        if isinstance(v, float):
                            metrics_str_list.append(f"{k}={v:.3}")
                        elif isinstance(v, (list, tuple)):
                            if v and isinstance(v[0], float):
                                v_str = ', '.join([f"{e:.3}" for e in v])
                                metrics_str_list.append(f"{k}=[{v_str}]")
                            else:
                                metrics_str_list.append(f"{k}={v}")
                        else:
                            metrics_str_list.append(f"{k}={v}")
                    log_str = ', '.join(metrics_str_list)
                    print(log_str, file=logf)
                    print(log_str)
                # Time-based checkpointing.
                ckpt_elasped_time = time.time() - ckpt_start_time
                if ckpt_elasped_time > train_cfg.SAVE_CHECKPOINTS_SECS:
                    libs.tools.save_models(model_dir, [net, optimizer],
                                           net.get_global_step())
                    ckpt_start_time = time.time()

            total_step_elapsed += steps
            libs.tools.save_models(model_dir, [net, optimizer],
                                   net.get_global_step())
            # Ensure that all evaluation points are saved forever
            libs.tools.save_models(eval_checkpoint_dir, [net, optimizer],
                                   net.get_global_step(),
                                   max_to_keep=100)

            # ---- Evaluation pass over the full eval dataset ----
            net.eval()
            result_path_step = result_path / f"step_{net.get_global_step()}"
            result_path_step.mkdir(parents=True, exist_ok=True)
            print("#################################")
            print("#################################", file=logf)
            print("# EVAL")
            print("# EVAL", file=logf)
            print("#################################")
            print("#################################", file=logf)
            print("Generate output labels...")
            print("Generate output labels...", file=logf)
            t = time.time()
            dt_annos = []
            prog_bar = ProgressBar()
            prog_bar.start(len(eval_dataset) // eval_input_cfg.BATCH_SIZE + 1)
            for example in iter(eval_dataloader):
                example = example_convert_to_torch(example, float_dtype)
                # Either accumulate annotations in memory (pickle_result) or
                # write KITTI-format label files to disk.
                if pickle_result:
                    dt_annos += predict_kitti_to_anno(net, example,
                                                      class_names,
                                                      center_limit_range,
                                                      model_cfg.LIDAR_INPUT)
                else:
                    _predict_kitti_to_file(net, example, result_path_step,
                                           class_names, center_limit_range,
                                           model_cfg.LIDAR_INPUT)
                prog_bar.print_bar()
            sec_per_ex = len(eval_dataset) / (time.time() - t)
            print(f"avg forward time per example: {net.avg_forward_time:.3f}")
            print(
                f"avg postprocess time per example: {net.avg_postprocess_time:.3f}"
            )
            print(f'generate label finished({sec_per_ex:.2f}/s). start eval:')
            print(f'generate label finished({sec_per_ex:.2f}/s). start eval:',
                  file=logf)
            gt_annos = [
                info["annos"] for info in eval_dataset.dataset.kitti_infos
            ]
            # NOTE(review): when pickle_result is True, `result` is never
            # assigned before the prints below — that path would raise
            # NameError. Confirm whether official eval should also run there.
            if not pickle_result:
                dt_annos = kitti.get_label_annos(result_path_step)
                result, mAPbbox, mAPbev, mAP3d, mAPaos = get_official_eval_result(
                    gt_annos, dt_annos, class_names, return_data=True)
            print(result, file=logf)
            print(result)

            result = get_coco_eval_result(gt_annos, dt_annos, class_names)
            print(result, file=logf)
            print(result)
            net.train()

    except Exception as e:
        # NOTE(review): the exception is swallowed after a final checkpoint
        # save — no re-raise and `e` is unused, so training failures are
        # silent. Consider re-raising after saving.
        libs.tools.save_models(model_dir, [net, optimizer],
                               net.get_global_step())
コード例 #10
0
def convert(config_path, weights_file, trt_path, max_voxel_num=12000):
    """Convert a trained VoxelNet model (specified by a config file) to TensorRT.

    Builds the network described by ``config_path``, restores ``weights_file``
    and converts the PFN and backbone (RPN) sub-networks to TensorRT engines
    saved under ``trt_path``.

    Args:
        config_path: YAML pipeline config file.
        weights_file: Path to the trained PyTorch checkpoint.
        trt_path: Output directory for the generated ``.trt`` engines.
        max_voxel_num: Voxel count used to size the dummy input when
            compiling the PFN engine.
    """

    trt_path = pathlib.Path(trt_path)
    model_logs_path = trt_path / 'model_logs'
    model_logs_path.mkdir(parents=True, exist_ok=True)

    # Back up the config and the weights next to the generated engines.
    config_file_bkp = 'pipeline.config'
    shutil.copyfile(config_path, str(model_logs_path / config_file_bkp))
    shutil.copyfile(weights_file,
                    str(model_logs_path / pathlib.Path(weights_file).name))

    config = cfg_from_yaml_file(config_path, cfg)
    model_cfg = config.MODEL

    ######################
    # BUILD VOXEL GENERATOR
    ######################
    voxel_generator = build_voxel_generator(config.VOXEL_GENERATOR)
    ######################
    # BUILD TARGET ASSIGNER
    ######################
    # Bird's-eye-view x/y extent of the point-cloud range.
    bv_range = voxel_generator.point_cloud_range[[0, 1, 3, 4]]
    box_coder = build_box_coder(config.BOX_CODER)
    target_assigner_cfg = config.TARGET_ASSIGNER
    target_assigner = build_target_assigner(target_assigner_cfg, bv_range,
                                            box_coder)
    ######################
    # BUILD NET
    ######################
    model_cfg.XAVIER = True
    net = build_network(model_cfg, voxel_generator, target_assigner)
    net.cuda()

    # Restore weights; strict=False tolerates keys absent from this graph.
    state_dict = torch.load(weights_file)
    net.load_state_dict(state_dict, strict=False)
    net.eval()

    # Paths of the TensorRT engines to be generated.
    pfn_trt_path = str(trt_path / "pfn.trt")
    backbone_trt_path = str(trt_path / "backbone.trt")

    # Dummy input data used to compile the TensorRT engine for the PFN.
    example_tensor = generate_tensor_list(max_voxel_num,
                                          float_type=torch.float32,
                                          device='cuda')

    print(
        '----------------------------------------------------------------------------'
    )
    print(
        "************ TensorRT: The PFN subnetwork is being transformed *************"
    )
    print(
        '----------------------------------------------------------------------------'
    )
    pfn_trt = torch2trt(net.pfn,
                        example_tensor,
                        fp16_mode=True,
                        max_workspace_size=1 << 20)
    torch.save(pfn_trt.state_dict(), pfn_trt_path)

    print(
        '------------------------------------------------------------------------------'
    )
    print(
        "******** TensorRT: The BackBone subnetwork(RPN) is being transformed *********"
    )
    print(
        '------------------------------------------------------------------------------'
    )
    # Feature-map size = (point-cloud extent / voxel size), reversed to
    # (z, y, x) order.
    pc_range = np.array(config.VOXEL_GENERATOR.POINT_CLOUD_RANGE)
    vs = np.array(config.VOXEL_GENERATOR.VOXEL_SIZE)
    # BUG FIX: `np.int` was deprecated in NumPy 1.20 and removed in 1.24;
    # use a concrete integer dtype instead.
    fp_size = ((pc_range[3:] - pc_range[:3]) / vs)[::-1].astype(np.int64)
    rpn_input = torch.ones((1, 64, fp_size[1], fp_size[2]),
                           dtype=torch.float32,
                           device='cuda')
    rpn_trt = torch2trt(net.rpn, [rpn_input],
                        fp16_mode=True,
                        max_workspace_size=1 << 20)
    torch.save(rpn_trt.state_dict(), backbone_trt_path)

    print("Done!")