Esempio n. 1
0
def main(args):
    """Export a detection model to ONNX, optionally reporting its FLOPs.

    The number of classes is taken from the dataset named by ``args.dataset``
    when given, otherwise from ``args.classes_num``. When ``args.show_flops``
    is set, the network is instrumented before the export and the measured
    average computational complexity is logged afterwards.
    """
    # Class count comes either from the reference dataset or straight from CLI.
    classes_num = (get_dataset(args.dataset, False, False, None).classes_num
                   if args.dataset else args.classes_num)

    model = locate(args.model)(classes_num, fc_detection_head=False)
    load_checkpoint(model, args.ckpt)

    count_flops = args.show_flops
    if count_flops:
        # Instrument the model so the export pass is measured.
        model = add_flops_counting_methods(model)
        model.reset_flops_count()
        model.start_flops_count()

    graph = onnx_export(model, args.input_size, args.output_file,
                        check=args.check, verbose=args.verbose)
    if args.verbose:
        logging.info(graph)

    if count_flops:
        model.stop_flops_count()
        logging.info('Computational complexity: {}'.format(flops_to_string(model.compute_average_flops_cost())))
        if args.verbose:
            print_model_with_flops(model)
Esempio n. 2
0
def main(args):
    """Fetch model-zoo checkpoints and convert them to PyTorch/ONNX/IR.

    Reads the YAML model-zoo manifest (``args.model_zoo``), downloads every
    listed checkpoint into ``<output_dir>/raw``, then per target optionally:
    converts the raw weights to a ``.pth`` file (``<output_dir>/converted``),
    exports the model to ONNX (``<output_dir>/onnx``) and runs the OpenVINO
    Model Optimizer to produce IR files (``<output_dir>/ir``).

    Processing is best-effort: a failure for one target is logged and the
    loop moves on to the next one.
    """
    args.output_dir = osp.abspath(args.output_dir)

    # Supported storage backends for raw checkpoint files.
    downloaders = dict(web=download_file_from_web,
                       google_drive=download_file_from_google_drive)

    with open(args.model_zoo, 'rt') as f:
        # safe_load: the manifest is plain data. yaml.load without an explicit
        # Loader is deprecated since PyYAML 5.1 (an error in PyYAML 6) and can
        # execute arbitrary Python objects from the file.
        model_zoo = yaml.safe_load(f)['models']

    logging.info('Models to be fetched:')
    for target in model_zoo:
        logging.info('\t{} ({}, {})'.format(colored(target['name'], 'green'),
                                            target['dataset'],
                                            target['framework']))
    logging.info('')

    for target in model_zoo:
        output_file = osp.join(args.output_dir, 'raw', target['dst_file'])
        mkdir_for_file(output_file)
        logging.info('Fetching {} ({}, {})'.format(
            colored(target['name'], 'green'), target['dataset'],
            target['framework']))
        try:
            # Download weights.
            if target['storage_type'] in downloaders:
                downloaders[target['storage_type']](target['url'], output_file)
            else:
                logging.warning(
                    'No downloaders available for storage {}'.format(
                        target['storage_type']))
                continue

            # Convert weights.
            logging.info('Downloaded to {}'.format(output_file))
            if target.get('weights_converter', None):
                logging.info('Converting weights...')
                # The converter is referenced by dotted path in the manifest.
                converter = locate(target['weights_converter'])
                if converter is None:
                    logging.warning('Invalid weights converter {}'.format(
                        target['weights_converter']))
                    continue
                output_converted_file = osp.join(args.output_dir, 'converted',
                                                 target['dst_file'])
                output_converted_file = osp.splitext(
                    output_converted_file)[0] + '.pth'
                mkdir_for_file(output_converted_file)
                try:
                    converter(output_file, output_converted_file)
                    logging.info('Converted weights file saved to {}'.format(
                        output_converted_file))
                except Exception as ex:
                    logging.warning('Failed to convert weights.')
                    logging.warning(ex)
                    continue

                if target.get('convert_to_ir', False):
                    # Convert to ONNX.
                    logging.info('Exporting to ONNX...')
                    output_onnx_file = osp.join(args.output_dir, 'onnx',
                                                target['dst_file'])
                    output_onnx_file = osp.splitext(
                        output_onnx_file)[0] + '.onnx'
                    mkdir_for_file(output_onnx_file)
                    # NOTE(review): class count 81 is hard-coded here
                    # (presumably COCO 80 classes + background) — confirm
                    # against the manifest's targets.
                    net = locate(target['model'])(81, fc_detection_head=False)
                    load_checkpoint(net, output_converted_file, verbose=False)
                    onnx_export(net, target['input_size'], output_onnx_file)
                    logging.info(
                        'ONNX file is saved to {}'.format(output_onnx_file))

                    # Convert to IR via the Model Optimizer CLI.
                    logging.info('Converting to IR...')
                    output_ir_dir = osp.join(args.output_dir, 'ir',
                                             target['dst_file'])
                    mkdir_for_file(output_ir_dir)
                    # mkdir_for_file created the parent; IR files go there.
                    output_ir_dir = osp.dirname(output_ir_dir)
                    status = call([
                        args.model_optimizer, '--framework', 'onnx',
                        '--input_model', output_onnx_file, '--output_dir',
                        output_ir_dir, '--input', 'im_data,im_info',
                        '--output', 'boxes,scores,classes,batch_ids,raw_masks',
                        '--mean_values', 'im_data{},im_info[0,0,0]'.format(
                            str(target['mean_pixel']).replace(' ', ''))
                    ])
                    if status:
                        logging.warning('Failed to convert model to IR.')
                    else:
                        logging.info(
                            'IR files saved to {}'.format(output_ir_dir))

        except Exception as ex:
            # Deliberate best-effort: keep fetching remaining targets.
            logging.warning(repr(ex))
Esempio n. 3
0
def main(args):
    """Run an instance-segmentation demo on a video stream or image list.

    Loads the network with either the 'pytorch' or the 'openvino' backend,
    feeds frames through it, post-processes and visualizes detections with
    OpenCV, and finally logs timing / complexity statistics. Pressing Esc in
    the result window stops the loop early.
    """
    # Preprocessing pipeline shared by the reference dataset and demo data.
    transforms = Compose([
        Resize(max_size=args.fit_max_image_size,
               window_size=args.fit_window_size,
               size=args.size),
        ToTensor(),
        Normalize(mean=args.mean_pixel, std=[1., 1., 1.], rgb=args.rgb),
    ])
    # The dataset is used below as the source of class names / class count.
    dataset = get_dataset(args.dataset, False, False, transforms)
    logging.info(dataset)
    batch_size = 1

    logging.info('Using {} backend'.format(args.backend))

    logging.info('Loading network...')
    if args.backend == 'pytorch':
        net = locate(args.pytorch_model_class)(dataset.classes_num)
        net.eval()
        load_checkpoint(net, args.checkpoint_file_path)
        if torch.cuda.is_available():
            net = net.cuda()
        # Instrument the model so average FLOPs can be reported at the end.
        net = add_flops_counting_methods(net)
        net.reset_flops_count()
        net.start_flops_count()
    elif args.backend == 'openvino':
        net = MaskRCNNOpenVINO(
            args.openvino_model_path,
            args.checkpoint_file_path,
            device=args.device,
            plugin_dir=args.plugin_dir,
            cpu_extension_lib_path=args.cpu_extension,
            collect_perf_counters=args.show_performance_counters)
    else:
        raise ValueError('Unknown backend "{}"'.format(args.backend))

    viz = Visualizer(dataset.classes,
                     confidence_threshold=args.prob_threshold,
                     show_boxes=args.show_boxes,
                     show_scores=args.show_scores)

    # inference_timer measures the forward pass only; timer measures full
    # frame-to-frame latency and drives the on-screen FPS readout.
    inference_timer = Timer(cuda_sync=True, warmup=1)
    timer = Timer(cuda_sync=False, warmup=1)
    timer.tic()

    logging.info('Configuring data source...')
    if args.video:
        # A purely numeric value selects a camera index rather than a path.
        try:
            args.video = int(args.video)
        except ValueError:
            pass
        demo_dataset = VideoDataset(args.video,
                                    labels=dataset.classes,
                                    transforms=transforms)
        num_workers = 0
        # Instances are tracked across frames only for video input.
        tracker = StaticIOUTracker()
    else:
        demo_dataset = ImagesDataset(args.images,
                                     labels=dataset.classes,
                                     transforms=transforms)
        num_workers = 1
        tracker = None

    data_loader = torch.utils.data.DataLoader(demo_dataset,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              shuffle=False,
                                              collate_fn=collate)

    logging.info('Processing data...')
    # Endless sources (live camera) report sys.maxsize; tqdm total=0 then.
    frames_num = len(demo_dataset)
    for data_batch in tqdm(
            iter(data_loader),
            total=frames_num if frames_num != sys.maxsize else 0):
        im_data = data_batch['im_data']
        im_info = data_batch['im_info']
        if torch.cuda.is_available():
            im_data = [i.cuda() for i in im_data]
            im_info = [i.cuda() for i in im_info]
        with torch.no_grad(), inference_timer:
            boxes, classes, scores, _, masks = net(im_data, im_info)

        # batch_size is 1, so only the first (and only) sample is used.
        meta = data_batch['meta'][0]
        # Map detections back from the resized input to the original frame.
        scores, classes, boxes, masks = postprocess(
            scores,
            classes,
            boxes,
            masks,
            im_h=meta['original_size'][0],
            im_w=meta['original_size'][1],
            im_scale_y=meta['processed_size'][0] / meta['original_size'][0],
            im_scale_x=meta['processed_size'][1] / meta['original_size'][1],
            full_image_masks=True,
            encode_masks=False,
            confidence_threshold=args.prob_threshold)

        masks_ids = tracker(masks, classes) if tracker is not None else None
        image = data_batch['original_image'][0]
        visualization = viz(image,
                            boxes,
                            classes,
                            scores,
                            segms=masks,
                            ids=masks_ids)
        fps = 1 / timer.toc()
        if args.show_fps:
            visualization = cv2.putText(visualization,
                                        'FPS: {:>2.2f}'.format(fps), (30, 30),
                                        cv2.FONT_HERSHEY_SIMPLEX, 1,
                                        (0, 0, 255), 2)
        cv2.imshow('result', visualization)
        # Esc (key code 27) stops the demo.
        key = cv2.waitKey(args.delay)
        if key == 27:
            break
        timer.tic()

    if inference_timer.average_time > 0:
        logging.info('Average inference FPS: {:3.2f}'.format(
            1 / inference_timer.average_time))

    if args.backend == 'pytorch':
        net.stop_flops_count()
        if args.show_flops:
            logging.info('Average FLOPs:  {}'.format(
                flops_to_string(net.compute_average_flops_cost())))
        if args.show_layers_flops:
            logging.info('Thorough computational complexity statistics:')
            print_model_with_flops(net)
        if torch.cuda.is_available():
            logging.info('GPU memory footprint:')
            logging.info('\tMax allocated: {:.2f} MiB'.format(
                torch.cuda.max_memory_allocated() / 1024**2))
            # NOTE(review): torch.cuda.max_memory_cached is deprecated in
            # newer PyTorch in favor of max_memory_reserved — confirm the
            # torch version this project pins before changing.
            logging.info('\tMax cached:    {:.2f} MiB'.format(
                torch.cuda.max_memory_cached() / 1024**2))
    else:
        if args.show_performance_counters:
            net.print_performance_counters()

    cv2.destroyAllWindows()
    del net
Esempio n. 4
0
def _export_part(net, export_fn, part_name, args):
    """Export one sub-network via *export_fn*, optionally measuring FLOPs.

    When ``args.show_flops`` is set, the network is instrumented before the
    export and the average computational complexity is logged afterwards.
    The exported graph is logged when ``args.verbose`` is set.

    :param net: sub-network (module) to export
    :param export_fn: callable taking the (possibly instrumented) network and
        returning a printable graph description
    :param part_name: human-readable part name used in the complexity log line
    :param args: parsed CLI arguments (uses ``show_flops`` and ``verbose``)
    """
    if args.show_flops:
        net = add_flops_counting_methods(net)
        net.reset_flops_count()
        net.start_flops_count()

    printable_graph = export_fn(net)
    if args.verbose:
        logging.info(printable_graph)

    if args.show_flops:
        net.stop_flops_count()
        logging.info(
            'Computational complexity of {}: {}'.format(
                part_name,
                flops_to_string(net.compute_average_flops_cost())))
        if args.verbose:
            print_model_with_flops(net)


def main():
    """ Does export to onnx. """

    args = parse_args()

    with open(args.model) as file:
        config = json.load(file)
    text_spotter = make_text_detector(**config['model'])(
        2, fc_detection_head=False, shape=args.input_size)
    load_checkpoint(text_spotter, args.ckpt)
    text_spotter.export_mode = True

    # Export of text detection part (Mask-RCNN subgraph).
    _export_part(
        text_spotter,
        lambda net: onnx_export(net,
                                args.input_size,
                                args.output_folder,
                                check=args.check,
                                verbose=args.verbose),
        'text detection part', args)

    feature_size = text_spotter.text_recogn_head.input_feature_size

    # Export of text recognition encoder.
    _export_part(
        text_spotter.text_recogn_head.encoder,
        lambda net: export_to_onnx_text_recognition_encoder(
            net, feature_size, args.output_folder),
        'text recognition encoder part', args)

    # Export of text recognition decoder.
    _export_part(
        text_spotter.text_recogn_head.decoder,
        lambda net: export_to_onnx_text_recognition_decoder(
            net, feature_size, args.output_folder),
        'text recognition decoder part', args)
Esempio n. 5
0
    def __init__(self):
        """Configure the instance-segmentation-security-0050 training run.

        Sets up logging, the COCO 2017 train/val data loaders, the model,
        the SGD optimizer with a warm-up multi-step LR schedule, and loads
        pre-trained ResNet-50 backbone weights.
        """
        super().__init__()
        self.identifier = 'instance-segmentation-security-0050'
        self.description = 'Training of instance-segmentation-security-0050'
        # Repository root is one level above this file's directory.
        self.root_directory = osp.join(osp.dirname(osp.abspath(__file__)), '..')
        self.run_directory = self.create_run_directory(osp.join(self.root_directory, 'outputs'))

        setup_logging(file_path=osp.join(self.run_directory, 'log.txt'))

        logger.info('Running {}'.format(self.identifier))
        logger.info(self.description)
        logger.info('Working directory "{}"'.format(self.run_directory))

        self.batch_size = 32
        self.virtual_iter_size = 1

        # Training dataset: multi-scale resize + horizontal flip augmentation.
        training_transforms = Compose(
            [
                RandomResize(mode='size', heights=(416, 448, 480, 512, 544), widths=(416, 448, 480, 512, 544)),
                RandomHorizontalFlip(prob=0.5),
                ToTensor(),
                # Per-channel mean subtraction in BGR order (rgb=False);
                # values are the common Detectron/Caffe pixel means.
                Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], rgb=False),
            ],
        )
        training_dataset_name = 'coco_2017_train'
        logger.info('Training dataset {}'.format(training_dataset_name))
        training_dataset = get_dataset(training_dataset_name, True, True, training_transforms)
        logger.info(training_dataset)
        self.training_data_loader = torch.utils.data.DataLoader(
            training_dataset, batch_size=self.batch_size, num_workers=0,
            shuffle=True, drop_last=True, collate_fn=collate
        )

        # Validation datasets (fixed-size resize, no augmentation).
        validation_transforms = Compose(
            [
                Resize(size=[480, 480]),
                ToTensor(),
                Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], rgb=False),
            ]
        )
        validation_datasets = []
        validation_dataset_name = 'coco_2017_val'
        logger.info('Validation dataset #{}: {}'.format(len(validation_datasets) + 1, validation_dataset_name))
        validation_datasets.append(get_dataset(validation_dataset_name, False, False, validation_transforms))
        logger.info(validation_datasets[-1])

        self.validation_data_loaders = []
        for validation_dataset in validation_datasets:
            self.validation_data_loaders.append(torch.utils.data.DataLoader(
                validation_dataset,
                batch_size=1, num_workers=8,
                shuffle=False, drop_last=False, collate_fn=collate)
            )
        self.validate_every = 10000

        # Training and validation must agree on the label space.
        for validation_dataset in validation_datasets:
            assert training_dataset.classes_num == validation_dataset.classes_num

        # Model and optimizer.
        logger.info('Model:')
        self.model = Model(training_dataset.classes_num)
        logger.info(self.model)

        self.training_iterations_num = 270000
        lr_scheduler_milestones = [220000, 250000]
        base_lr = 0.02
        weight_decay = 0.0001
        logger.info('Optimizer:')
        self.optimizer = torch.optim.SGD(self.setup_optimizer(self.model, base_lr, weight_decay),
                                         lr=base_lr, weight_decay=weight_decay, momentum=0.9)
        logger.info(self.optimizer)
        logger.info('Learning Rate scheduler:')
        # Linear warm-up for the first 1000 iterations, then step decay.
        self.lr_scheduler = MultiStepLRWithWarmUp(
            self.optimizer,
            milestones=lr_scheduler_milestones,
            warmup_iters=1000,
            warmup_method='linear',
            warmup_factor_base=0.333,
            gamma=0.1,
            last_epoch=0
        )
        logger.info(self.lr_scheduler)

        self.start_step = 0
        # Pre-trained backbone weights are mandatory; fail fast with a hint.
        checkpoint_file_path = osp.join(self.root_directory, 'data', 'pretrained_models',
                                        'converted', 'imagenet', 'detectron', 'resnet50.pth')
        if not osp.exists(checkpoint_file_path):
            raise IOError('Initial checkpoint file "{}" does not exist. '
                          'Please fetch pre-trained backbone networks using '
                          'tools/download_pretrained_weights.py script first.'.format(checkpoint_file_path))
        logger.info('Loading weights from "{}"'.format(checkpoint_file_path))
        load_checkpoint(self.model.backbone, checkpoint_file_path)

        # Loggers and misc. stuff.
        self.loggers = [TextLogger(logger),
                        TensorboardLogger(self.run_directory)]
        self.log_every = 50

        self.checkpoint_every = 10000
Esempio n. 6
0
def main(args):
    """ Tests text spotter.

    Runs the text spotter over the whole dataset with either the 'pytorch'
    or the 'openvino' backend, post-processes the per-batch outputs, and
    evaluates the accumulated detections with the dataset's evaluator.
    Also reports timing and (optionally) FLOPs / performance counters.
    """

    transforms = Compose(
        [
            Resize(size=args.size),
            ToTensor(),
            Normalize(mean=args.mean_pixel, std=args.std_pixel, rgb=args.rgb),
        ]
    )
    dataset = get_dataset(args.dataset, False, False, transforms,
                          alphabet_decoder=AlphabetDecoder())
    logging.info(dataset)
    num_workers = args.num_workers

    inference_timer = Timer()

    logging.info('Using {} backend'.format(args.backend))

    logging.info('Loading network...')
    batch_size = 1
    if args.backend == 'pytorch':
        # NOTE(review): despite its name, args.pytorch_model_class is read
        # here as a JSON config file path — confirm against the arg parser.
        with open(args.pytorch_model_class) as file:
            config = json.load(file)
        net = make_text_detector(**config['model'])(dataset.classes_num,
                                                    force_max_output_size=False, shape=args.size)
        net.eval()
        load_checkpoint(net, args.checkpoint_file_path)
        if torch.cuda.is_available():
            net = net.cuda()
        # Instrument for FLOPs accounting before wrapping in DataParallel.
        net = add_flops_counting_methods(net)
        net.reset_flops_count()
        net.start_flops_count()
        if torch.cuda.is_available():
            torch.backends.cudnn.deterministic = True
            # .cuda() again is a no-op here (already moved above).
            net = net.cuda()
            net = ShallowDataParallel(net)
    elif args.backend == 'openvino':
        net = TextMaskRCNNOpenVINO(args.openvino_detector_model_path,
                                   args.openvino_encoder_model_path,
                                   args.openvino_decoder_model_path,
                                   collect_perf_counters=args.show_performance_counters)
    else:
        raise ValueError('Unknown backend "{}"'.format(args.backend))

    logging.info('Using batch size {}'.format(batch_size))
    logging.info('Number of prefetching processes {}'.format(num_workers))
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=collate
    )

    logging.info('Processing dataset...')
    # Per-image results accumulated over the whole dataset for evaluation.
    boxes_all = []
    masks_all = []
    classes_all = []
    scores_all = []
    text_probs_all = []
    for data_batch in tqdm(iter(data_loader)):
        batch_meta = data_batch['meta']
        actual_batch_size = len(batch_meta)
        with torch.no_grad(), inference_timer:
            boxes, classes, scores, batch_ids, masks, text_probs = net(**data_batch)

        # Map detections back from the resized inputs to original image sizes.
        im_heights = [meta['original_size'][0] for meta in batch_meta]
        im_widths = [meta['original_size'][1] for meta in batch_meta]
        im_scale_y = [meta['processed_size'][0] / meta['original_size'][0] for meta in batch_meta]
        im_scale_x = [meta['processed_size'][1] / meta['original_size'][1] for meta in batch_meta]
        scores, classes, boxes, masks, text_probs = postprocess_batch(
            batch_ids, scores, classes, boxes, masks, text_probs, actual_batch_size,
            im_h=im_heights,
            im_w=im_widths,
            im_scale_y=im_scale_y,
            im_scale_x=im_scale_x,
            full_image_masks=True, encode_masks=True,
            confidence_threshold=args.prob_threshold)

        boxes_all.extend(boxes)
        masks_all.extend(masks)
        classes_all.extend(classes)
        scores_all.extend(scores)
        text_probs_all.extend(text_probs)

    # DataLoader worker teardown can raise ConnectionResetError; ignore it.
    try:
        del data_loader
    except ConnectionResetError:
        pass

    logging.info('Evaluating results...')
    evaluation_results = dataset.evaluate(scores_all, classes_all, boxes_all, masks_all,
                                          text_probs_all, dump='dump', visualize=args.visualize)
    logging.info(evaluation_results)

    logging.info('Average inference time {}'.format(inference_timer.average_time))

    if args.backend == 'pytorch':
        if torch.cuda.is_available():
            # Unwrap ShallowDataParallel to reach the instrumented module.
            net = net.module
        net.stop_flops_count()
        if args.show_flops:
            logging.info(
                'Average FLOPs:  {}'.format(flops_to_string(net.compute_average_flops_cost())))
        if args.show_layers_flops:
            logging.info('Thorough computational complexity statistics:')
            print_model_with_flops(net)
    else:
        if args.show_performance_counters:
            net.print_performance_counters()

    del net
    def load_checkpoint(net,
                        optimizer,
                        load_ckpt=None,
                        load_backbone=None,
                        resume=False):
        """Load model (and optionally optimizer/backbone) state from disk.

        :param net: network to load weights into
        :param optimizer: optimizer whose state is restored when resuming
        :param load_ckpt: path to a full training checkpoint, or None
        :param load_backbone: path to backbone-only weights, or None
        :param resume: when True, also restore step counter, optimizer state
            and SGD momentum buffers from the checkpoint
        :return: the training step to start from (0 unless resuming)
        """
        start_step = 0
        if load_ckpt:
            logging.info('loading checkpoint "{}"'.format(load_ckpt))
            # map_location keeps tensors on CPU regardless of save device.
            checkpoint = torch.load(load_ckpt,
                                    map_location=lambda storage, loc: storage)
            weight_utils.load_rcnn_ckpt(net, checkpoint['model'])
            if resume:
                start_step = checkpoint['step']
                optimizer.load_state_dict(checkpoint['optimizer'])

                # 'optimizer_corrected' (when present) maps saved parameter
                # tensors to their momentum buffers, letting buffers be
                # re-matched to the live parameters explicitly.
                corrected_matcher = checkpoint.get('optimizer_corrected', {})
                if len(corrected_matcher) > 0:
                    for group in optimizer.param_groups:
                        for param_original in group['params']:
                            if 'momentum_buffer' in optimizer.state[
                                    param_original]:
                                # Match a saved parameter to the live one by
                                # identical shape AND identical values, then
                                # adopt its momentum buffer.
                                for param_loaded, buffer in corrected_matcher.items(
                                ):
                                    shapes_are_equal = np.array_equal(
                                        list(param_original.shape),
                                        list(param_loaded.shape))
                                    if shapes_are_equal and torch.all(
                                            torch.eq(
                                                param_original.data,
                                                param_loaded.data.cuda())):
                                        optimizer.state[param_original][
                                            'momentum_buffer'] = buffer.cuda()
                                        break
                else:
                    # If a checkpoint does not have additional dictionary with matched
                    # parameters and its buffers, just match them by shapes
                    if len(
                            optimizer.state
                    ) > 0:  # It means that a checkpoint has momentum_buffer
                        # used_buffers marks state entries already consumed so
                        # two same-shaped params don't grab the same buffer.
                        used_buffers = {}
                        copy_buf = None
                        for p in optimizer.state.keys():
                            used_buffers[p] = False
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                for p, buffer in optimizer.state.items():
                                    if 'momentum_buffer' not in buffer:
                                        continue
                                    if np.array_equal(list(param.shape), list(buffer['momentum_buffer'].shape)) and not \
                                       used_buffers[p]:
                                        # Swap the two momentum buffers so the
                                        # shape-matched one lands on `param`.
                                        copy_buf = optimizer.state[param][
                                            'momentum_buffer'].cuda()
                                        optimizer.state[param][
                                            'momentum_buffer'] = buffer[
                                                'momentum_buffer'].cuda()
                                        optimizer.state[p][
                                            'momentum_buffer'] = copy_buf.cuda(
                                            )
                                        # NOTE(review): marks `param` used, not
                                        # `p` — looks inconsistent with the
                                        # `used_buffers[p]` check above; verify
                                        # intent before changing.
                                        used_buffers[param] = True
                        del used_buffers
                        del copy_buf
                logging.info('Resume training from {} step'.format(start_step))

            # Free the (possibly large) checkpoint dict before training.
            del checkpoint
            torch.cuda.empty_cache()

        if load_backbone:
            logging.info(
                'loading backbone weights from "{}"'.format(load_backbone))
            assert hasattr(net, 'backbone')
            weight_utils.load_checkpoint(net.backbone, load_backbone)

        return start_step
    def __init__(self, work_dir, config):
        """Configure a text-spotter training run from a config dictionary.

        :param work_dir: root working directory; falls back to the parent of
            this file's directory when falsy
        :param config: parsed configuration with 'model', transform specs,
            dataset names and a 'training_details' section
        """
        super().__init__()
        self.identifier = config['identifier']
        self.description = config['description']
        self.root_directory = work_dir if work_dir else osp.join(
            osp.dirname(osp.abspath(__file__)), '..')
        self.run_directory = self.create_run_directory(
            osp.join(self.root_directory, 'models'))

        setup_logging(file_path=osp.join(self.run_directory, 'log.txt'))

        logger.info('Running {}'.format(self.identifier))
        logger.info(self.description)
        logger.info('Working directory "{}"'.format(self.run_directory))

        self.batch_size = config['training_details']['batch_size']
        self.virtual_iter_size = config['training_details'][
            'virtual_iter_size']

        model_class = make_text_detector(**config['model'])

        alphabet_decoder = AlphabetDecoder()

        # Training dataset: transform classes are looked up by name in this
        # module and instantiated with their kwargs from the config.
        training_transforms = Compose([
            getattr(sys.modules[__name__], k)(**v)
            for k, v in config['training_transforms'].items()
        ] + [AlphabetDecodeTransform(alphabet_decoder)])

        training_dataset_name = config['training_dataset_name']
        logger.info('Training dataset {}'.format(training_dataset_name))
        training_dataset = get_dataset(training_dataset_name,
                                       True,
                                       True,
                                       training_transforms,
                                       alphabet_decoder=alphabet_decoder,
                                       remove_images_without_text=True)
        logger.info(training_dataset)
        self.training_data_loader = torch.utils.data.DataLoader(
            training_dataset,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=True,
            drop_last=True,
            collate_fn=collate)

        # Validation datasets (no augmentation, built the same way).
        validation_transforms = Compose([
            getattr(sys.modules[__name__], k)(**v)
            for k, v in config['validation_transforms'].items()
        ])
        self.confidence_threshold = config['validation_confidence_threshold']
        validation_datasets = []
        validation_dataset_name = config['validation_dataset_name']
        logger.info('Validation dataset #{}: {}'.format(
            len(validation_datasets) + 1, validation_dataset_name))
        validation_datasets.append(
            get_dataset(validation_dataset_name,
                        False,
                        False,
                        validation_transforms,
                        alphabet_decoder=alphabet_decoder))
        logger.info(validation_datasets[-1])

        self.validation_data_loaders = []
        for validation_dataset in validation_datasets:
            self.validation_data_loaders.append(
                torch.utils.data.DataLoader(validation_dataset,
                                            batch_size=1,
                                            num_workers=8,
                                            shuffle=False,
                                            drop_last=False,
                                            collate_fn=collate))
        self.validate_every = config['training_details']['validate_every']

        # Training and validation must agree on the label space.
        for validation_dataset in validation_datasets:
            assert training_dataset.classes_num == validation_dataset.classes_num

        # Model and optimizer.
        logger.info('Model:')

        self.model = model_class(cls_num=training_dataset.classes_num,
                                 shape=config['shape'],
                                 num_chars=len(alphabet_decoder.alphabet))

        logger.info(self.model)

        self.training_iterations_num = config['training_details'][
            'training_iterations_num']
        lr_scheduler_milestones = config['training_details'][
            'lr_scheduler_milestones']
        base_lr = config['training_details']['base_lr']
        weight_decay = config['training_details']['weight_decay']
        logger.info('Optimizer:')
        self.optimizer = torch.optim.SGD(self.setup_optimizer(
            self.model, base_lr, weight_decay),
                                         lr=base_lr,
                                         weight_decay=weight_decay,
                                         momentum=0.9)
        logger.info(self.optimizer)
        logger.info('Learning Rate scheduler:')
        # Linear warm-up for the first 1000 iterations, then step decay.
        self.lr_scheduler = MultiStepLRWithWarmUp(
            self.optimizer,
            milestones=lr_scheduler_milestones,
            warmup_iters=1000,
            warmup_method='linear',
            warmup_factor_base=0.333,
            gamma=0.1,
            last_epoch=0)
        logger.info(self.lr_scheduler)

        self.start_step = 0
        # Backbone-only weights (skipping the text recognition head).
        if 'backbone_checkpoint' in config and config['backbone_checkpoint']:
            checkpoint_file_path = osp.join(self.root_directory,
                                            config['backbone_checkpoint'])
            if not osp.exists(checkpoint_file_path):
                raise IOError(
                    'Initial checkpoint file "{}" does not exist. '
                    'Please fetch pre-trained backbone networks using '
                    'tools/download_pretrained_weights.py script first.'.
                    format(checkpoint_file_path))
            logger.info(
                'Loading weights from "{}"'.format(checkpoint_file_path))
            load_checkpoint(self.model.backbone,
                            checkpoint_file_path,
                            verbose=True,
                            skip_prefix='text_recogn')

        # Full-model checkpoint (loaded after the backbone, so it wins).
        if 'checkpoint' in config and config['checkpoint']:
            checkpoint_file_path = osp.join(self.root_directory,
                                            config['checkpoint'])
            if not osp.exists(checkpoint_file_path):
                raise IOError('Checkpoint file "{}" does not exist. '.format(
                    checkpoint_file_path))
            logger.info(
                'Loading weights from "{}"'.format(checkpoint_file_path))
            load_checkpoint(self.model, checkpoint_file_path, verbose=True)

        # Loggers and misc. stuff.
        self.loggers = [
            TextLogger(logger),
            TensorboardLogger(self.run_directory)
        ]
        self.log_every = 50

        self.checkpoint_every = config['training_details']['checkpoint_every']