Code Example #1
# Imports implied by the calls below (Keras and NumPy are assumed).
import pickle

import numpy as np
from keras.models import load_model
from keras.preprocessing.sequence import pad_sequences


def main(options):
    BLSTM_CHECKPOINT_NAME = options.model_path
    TOKENIZER_PATH = options.vectorizer_path
    MAX_SEQUENCE_LENGTH = options.sequence_length
    INPUT = options.input
    blstm_model = load_model(BLSTM_CHECKPOINT_NAME)

    with open(TOKENIZER_PATH, 'rb') as handle:
        tokenizer = pickle.load(handle)

    label_index = {0: 0, 1: 1, 2: 2, 3: 3}
    if not INPUT:
        str_input = 'this is a string of text with no punctuation this is a new sentence'
        #str_input = 'halloween is officially behind us which means for the next two months or so it is going to be all christmas all the time but before you get sick of the overplayed music and the excessive gift buying why not take advantage and celebrate a little canadas wonderland is launching a brand new winter festival at the end of this month and it just might be the perfect place to geek out and enjoy some holiday cheer wonderland announced the new festival in the summer of 2018 and they revealed the official launch date about a month ago but as new details about the festival continue to emerge it becomes more and more clear that it is sure to be a can not miss event this holiday season'
    else:
        str_input = options.input

    str_split = str_input.split()
    str_chunk = [
        str_split[i:i + MAX_SEQUENCE_LENGTH]
        for i in range(0, len(str_split), MAX_SEQUENCE_LENGTH)
    ]
    # Keep this as a plain (possibly ragged) list; pad_sequences does the
    # padding, and np.array would choke on ragged chunks.
    str_numeric = tokenizer.texts_to_sequences(str_chunk)
    str_pad = pad_sequences(str_numeric, MAX_SEQUENCE_LENGTH, padding='post')
    blstm_str_pred = blstm_model.predict(str_pad, batch_size=64, verbose=1)
    blstm_str_trans = Transform(blstm_str_pred, label_index)

    result = []
    for row, chunk in enumerate(str_chunk):
        for col, word in enumerate(chunk):
            label = blstm_str_trans[row][col]
            result.append(word)
            if label == 1:
                result.append('<comma>')
            elif label == 2:
                result.append('<period>')
            elif label == 3:
                result.append('<question_mark>')
    print(' '.join(result))
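The Transform helper used in this example is not part of the excerpt. A minimal sketch consistent with its usage here (argmax over each token's class probabilities, mapped through label_index) might look like the following; the project's actual implementation may differ:

import numpy as np

def Transform(predictions, label_index):
    # predictions: (n_chunks, seq_len, n_classes) softmax output of the BLSTM.
    # Returns, per token, the label id with the highest probability.
    return [[label_index[int(np.argmax(token))] for token in chunk]
            for chunk in predictions]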
Code Example #2
File: train_multi.py  Project: zhuyiche/chainercv
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model',
                        choices=('ssd300', 'ssd512'),
                        default='ssd300')
    parser.add_argument('--batchsize', type=int, default=32)
    parser.add_argument('--test-batchsize', type=int, default=16)
    parser.add_argument('--iteration', type=int, default=120000)
    parser.add_argument('--step', type=int, nargs='*', default=[80000, 100000])
    parser.add_argument('--out', default='result')
    parser.add_argument('--resume')
    args = parser.parse_args()

    comm = chainermn.create_communicator()
    device = comm.intra_rank

    if args.model == 'ssd300':
        model = SSD300(n_fg_class=len(voc_bbox_label_names),
                       pretrained_model='imagenet')
    elif args.model == 'ssd512':
        model = SSD512(n_fg_class=len(voc_bbox_label_names),
                       pretrained_model='imagenet')

    model.use_preset('evaluate')
    train_chain = MultiboxTrainChain(model)
    chainer.cuda.get_device_from_id(device).use()
    model.to_gpu()

    train = TransformDataset(
        ConcatenatedDataset(VOCBboxDataset(year='2007', split='trainval'),
                            VOCBboxDataset(year='2012', split='trainval')),
        ('img', 'mb_loc', 'mb_label'),
        Transform(model.coder, model.insize, model.mean))

    if comm.rank == 0:
        indices = np.arange(len(train))
    else:
        indices = None
    indices = chainermn.scatter_dataset(indices, comm, shuffle=True)
    train = train.slice[indices]

    # http://chainermn.readthedocs.io/en/latest/tutorial/tips_faqs.html#using-multiprocessiterator
    if hasattr(multiprocessing, 'set_start_method'):
        multiprocessing.set_start_method('forkserver')
    train_iter = chainer.iterators.MultiprocessIterator(train,
                                                        args.batchsize //
                                                        comm.size,
                                                        n_processes=2)

    if comm.rank == 0:
        test = VOCBboxDataset(year='2007',
                              split='test',
                              use_difficult=True,
                              return_difficult=True)
        test_iter = chainer.iterators.SerialIterator(test,
                                                     args.test_batchsize,
                                                     repeat=False,
                                                     shuffle=False)

    # initial lr is set to 1e-3 by ExponentialShift
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(), comm)
    optimizer.setup(train_chain)
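    # SSD convention: biases get twice the effective learning rate via
    # GradientScaling(2) and are exempt from weight decay.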
    for param in train_chain.params():
        if param.name == 'b':
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))

    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                device=device)
    trainer = training.Trainer(updater, (args.iteration, 'iteration'),
                               args.out)
    trainer.extend(extensions.ExponentialShift('lr', 0.1, init=1e-3),
                   trigger=triggers.ManualScheduleTrigger(
                       args.step, 'iteration'))

    if comm.rank == 0:
        trainer.extend(DetectionVOCEvaluator(test_iter,
                                             model,
                                             use_07_metric=True,
                                             label_names=voc_bbox_label_names),
                       trigger=triggers.ManualScheduleTrigger(
                           args.step + [args.iteration], 'iteration'))

        log_interval = 10, 'iteration'
        trainer.extend(extensions.LogReport(trigger=log_interval))
        trainer.extend(extensions.observe_lr(), trigger=log_interval)
        trainer.extend(extensions.PrintReport([
            'epoch', 'iteration', 'lr', 'main/loss', 'main/loss/loc',
            'main/loss/conf', 'validation/main/map'
        ]),
                       trigger=log_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))

        trainer.extend(extensions.snapshot(),
                       trigger=triggers.ManualScheduleTrigger(
                           args.step + [args.iteration], 'iteration'))
        trainer.extend(extensions.snapshot_object(
            model, 'model_iter_{.updater.iteration}'),
                       trigger=(args.iteration, 'iteration'))

    if args.resume:
        serializers.load_npz(args.resume, trainer)

    trainer.run()
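The rank-0/scatter pattern above is the standard ChainerMN idiom for sharding a dataset: only the root process builds the index array, chainermn.scatter_dataset shuffles and splits it across workers, and each worker trains on its own slice. A rough single-process illustration of the scatter step (assuming two workers; the real call communicates over MPI):

import numpy as np

# Conceptually, scatter_dataset(indices, comm, shuffle=True) does this:
n_examples, comm_size = 10, 2
indices = np.random.permutation(n_examples)  # root shuffles once
shards = np.array_split(indices, comm_size)  # one shard per rank
# rank r then takes its local subset: train = train.slice[shards[r]]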
Code Example #3
def main():
    parser = argparse.ArgumentParser(
        description='ChainerCV training example: FCIS')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Output directory')
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument(
        '--lr',
        '-l',
        type=float,
        default=0.0005,
        help='Default value is for 1 GPU.\n'
        'The learning rate will be multiplied by the number of GPUs')
    parser.add_argument('--lr-cooldown-factor',
                        '-lcf',
                        type=float,
                        default=0.1)
    parser.add_argument('--epoch', '-e', type=int, default=42)
    # type=list would split a command-line string into characters;
    # use type=int with nargs to accept a list of epochs.
    parser.add_argument('--cooldown-epoch', '-ce',
                        type=int, nargs='*', default=[28, 31])
    args = parser.parse_args()

    # chainermn
    comm = chainermn.create_communicator()
    device = comm.intra_rank

    np.random.seed(args.seed)

    # model
    fcis = FCISPSROIAlignResNet101(
        n_fg_class=len(sbd_instance_segmentation_label_names),
        pretrained_model='imagenet',
        iter2=False)
    fcis.use_preset('evaluate')
    model = FCISTrainChain(fcis)
    chainer.cuda.get_device_from_id(device).use()
    model.to_gpu()

    # dataset
    train_dataset = TransformDataset(
        SBDInstanceSegmentationDataset(split='train'),
        ('img', 'mask', 'label', 'bbox', 'scale'), Transform(model.fcis))
    if comm.rank == 0:
        indices = np.arange(len(train_dataset))
    else:
        indices = None
    indices = chainermn.scatter_dataset(indices, comm, shuffle=True)
    train_dataset = train_dataset.slice[indices]
    train_iter = chainer.iterators.SerialIterator(train_dataset, batch_size=1)

    if comm.rank == 0:
        test_dataset = SBDInstanceSegmentationDataset(split='val')
        test_iter = chainer.iterators.SerialIterator(test_dataset,
                                                     batch_size=1,
                                                     repeat=False,
                                                     shuffle=False)

    # optimizer
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(lr=args.lr * comm.size, momentum=0.9),
        comm)
    optimizer.setup(model)

    model.fcis.head.conv1.W.update_rule.add_hook(GradientScaling(3.0))
    model.fcis.head.conv1.b.update_rule.add_hook(GradientScaling(3.0))
    optimizer.add_hook(chainer.optimizer.WeightDecay(rate=0.0005))

    for param in model.params():
        if param.name in ['beta', 'gamma']:
            param.update_rule.enabled = False
    model.fcis.extractor.conv1.disable_update()
    model.fcis.extractor.res2.disable_update()

    updater = chainer.training.updater.StandardUpdater(
        train_iter, optimizer, converter=concat_examples, device=device)

    trainer = chainer.training.Trainer(updater, (args.epoch, 'epoch'),
                                       out=args.out)

    # lr scheduler
    trainer.extend(chainer.training.extensions.ExponentialShift(
        'lr', args.lr_cooldown_factor, init=args.lr * comm.size),
                   trigger=ManualScheduleTrigger(args.cooldown_epoch, 'epoch'))

    if comm.rank == 0:
        # interval
        log_interval = 100, 'iteration'
        plot_interval = 3000, 'iteration'
        print_interval = 20, 'iteration'

        # training extensions
        model_name = model.fcis.__class__.__name__

        trainer.extend(extensions.snapshot_object(
            model.fcis,
            filename='%s_model_iter_{.updater.iteration}.npz' % model_name),
                       trigger=(1, 'epoch'))
        trainer.extend(extensions.observe_lr(), trigger=log_interval)
        trainer.extend(
            extensions.LogReport(log_name='log.json', trigger=log_interval))
        trainer.extend(extensions.PrintReport([
            'iteration',
            'epoch',
            'elapsed_time',
            'lr',
            'main/loss',
            'main/rpn_loc_loss',
            'main/rpn_cls_loss',
            'main/roi_loc_loss',
            'main/roi_cls_loss',
            'main/roi_mask_loss',
            'validation/main/map',
        ]),
                       trigger=print_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))

        if extensions.PlotReport.available():
            trainer.extend(extensions.PlotReport(['main/loss'],
                                                 file_name='loss.png',
                                                 trigger=plot_interval),
                           trigger=plot_interval)

        trainer.extend(InstanceSegmentationVOCEvaluator(
            test_iter,
            model.fcis,
            iou_thresh=0.5,
            use_07_metric=True,
            label_names=sbd_instance_segmentation_label_names),
                       trigger=ManualScheduleTrigger(args.cooldown_epoch,
                                                     'epoch'))

        trainer.extend(extensions.dump_graph('main/loss'))

    trainer.run()
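For reference, ExponentialShift multiplies 'lr' by lr_cooldown_factor each time the ManualScheduleTrigger fires, so with the defaults above on a single GPU (comm.size == 1) the effective schedule works out to:

# epochs  0-27: lr = 0.0005
# epochs 28-30: lr = 0.0005 * 0.1    = 5e-05
# epochs 31-41: lr = 0.0005 * 0.1**2 = 5e-06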
Code Example #4
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument("model", help="model file in a log dir")
    parser.add_argument("--gpu", type=int, default=0, help="gpu id")
    parser.add_argument("--save", action="store_true", help="save")
    args = parser.parse_args()

    args_file = path.Path(args.model).parent / "args"
    with open(args_file) as f:
        args_data = json.load(f)
    pprint.pprint(args_data)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()

    model = singleview_3d.models.Model(
        n_fg_class=len(args_data["class_names"][1:]),
        pretrained_resnet18=args_data["pretrained_resnet18"],
        with_occupancy=args_data["with_occupancy"],
        loss=args_data["loss"],
        loss_scale=args_data["loss_scale"],
    )
    if args.gpu >= 0:
        model.to_gpu()

    print(f"==> Loading trained model: {args.model}")
    chainer.serializers.load_npz(args.model, model)
    print("==> Done model loading")

    split = "val"
    dataset = morefusion.datasets.YCBVideoRGBDPoseEstimationDataset(
        split=split)
    dataset_reindexed = morefusion.datasets.YCBVideoRGBDPoseEstimationDatasetReIndexed(  # NOQA
        split=split,
        class_ids=args_data["class_ids"],
    )
    transform = Transform(
        train=False,
        with_occupancy=args_data["with_occupancy"],
    )

    pprint.pprint(args.__dict__)

    # -------------------------------------------------------------------------

    depth2rgb = imgviz.Depth2RGB()
    for index in range(len(dataset)):
        frame = dataset.get_frame(index)

        image_id = dataset._ids[index]
        indices = dataset_reindexed.get_indices_from_image_id(image_id)
        examples = dataset_reindexed[indices]
        examples = [transform(example) for example in examples]

        if not examples:
            continue
        inputs = chainer.dataset.concat_examples(examples, device=args.gpu)

        # A comma (not 'and') is required to enter both context managers;
        # 'A and B' evaluates to B alone, which leaves backprop enabled.
        with chainer.no_backprop_mode(), chainer.using_config("train", False):
            quaternion_pred, translation_pred, confidence_pred = model.predict(
                class_id=inputs["class_id"],
                rgb=inputs["rgb"],
                pcd=inputs["pcd"],
                pitch=inputs.get("pitch"),
                origin=inputs.get("origin"),
                grid_nontarget_empty=inputs.get("grid_nontarget_empty"),
            )

            indices = model.xp.argmax(confidence_pred.array, axis=1)
            quaternion_pred = quaternion_pred[
                model.xp.arange(quaternion_pred.shape[0]), indices]
            translation_pred = translation_pred[
                model.xp.arange(translation_pred.shape[0]), indices]

            reporter = chainer.Reporter()
            reporter.add_observer("main", model)
            observation = {}
            with reporter.scope(observation):
                model.evaluate(
                    class_id=inputs["class_id"],
                    quaternion_true=inputs["quaternion_true"],
                    translation_true=inputs["translation_true"],
                    quaternion_pred=quaternion_pred,
                    translation_pred=translation_pred,
                )

        # TODO(wkentaro)
        observation_new = {}
        for k, v in observation.items():
            if re.match(r"main/add_or_add_s/[0-9]+/.+", k):
                k_new = "/".join(k.split("/")[:-1])
                observation_new[k_new] = v
        observation = observation_new

        print(f"[{index:08d}] {observation}")

        # ---------------------------------------------------------------------

        K = frame["intrinsic_matrix"]
        height, width = frame["rgb"].shape[:2]
        fovy = trimesh.scene.Camera(resolution=(width, height),
                                    focal=(K[0, 0], K[1, 1])).fov[1]

        batch_size = len(inputs["class_id"])
        class_ids = cuda.to_cpu(inputs["class_id"])
        quaternion_pred = cuda.to_cpu(quaternion_pred.array)
        translation_pred = cuda.to_cpu(translation_pred.array)
        quaternion_true = cuda.to_cpu(inputs["quaternion_true"])
        translation_true = cuda.to_cpu(inputs["translation_true"])

        Ts_pred = []
        Ts_true = []
        for i in range(batch_size):
            # T_cad2cam
            T_pred = tf.quaternion_matrix(quaternion_pred[i])
            T_pred[:3, 3] = translation_pred[i]
            T_true = tf.quaternion_matrix(quaternion_true[i])
            T_true[:3, 3] = translation_true[i]
            Ts_pred.append(T_pred)
            Ts_true.append(T_true)

        Ts = dict(true=Ts_true, pred=Ts_pred)

        vizs = []
        depth_viz = depth2rgb(frame["depth"])
        for which in ["true", "pred"]:
            pybullet.connect(pybullet.DIRECT)
            for i, T in enumerate(Ts[which]):
                cad_file = morefusion.datasets.YCBVideoModels().get_cad_file(
                    class_id=class_ids[i])
                morefusion.extra.pybullet.add_model(
                    cad_file,
                    position=tf.translation_from_matrix(T),
                    orientation=tf.quaternion_from_matrix(T)[[1, 2, 3, 0]],
                )
            (
                rgb_rend,
                depth_rend,
                segm_rend,
            ) = morefusion.extra.pybullet.render_camera(
                np.eye(4), fovy, height, width)
            pybullet.disconnect()

            segm_rend = imgviz.label2rgb(segm_rend + 1,
                                         img=frame["rgb"],
                                         alpha=0.7)
            depth_rend = depth2rgb(depth_rend)
            rgb_input = imgviz.tile(cuda.to_cpu(inputs["rgb"]),
                                    border=(255, 255, 255))
            viz = imgviz.tile(
                [
                    frame["rgb"],
                    depth_viz,
                    rgb_input,
                    segm_rend,
                    rgb_rend,
                    depth_rend,
                ],
                (1, 6),
                border=(255, 255, 255),
            )
            viz = imgviz.resize(viz, width=1800)

            if which == "pred":
                text = []
                for class_id in np.unique(class_ids):
                    add = observation[f"main/add_or_add_s/{class_id:04d}"]
                    text.append(f"[{which}] [{class_id:04d}]: "
                                f"add/add_s={add * 100:.1f}cm")
                text = "\n".join(text)
            else:
                text = f"[{which}]"
            viz = imgviz.draw.text_in_rectangle(
                viz,
                loc="lt",
                text=text,
                size=20,
                background=(0, 255, 0),
                color=(0, 0, 0),
            )
            if which == "true":
                viz = imgviz.draw.text_in_rectangle(
                    viz,
                    loc="rt",
                    text="singleview_3d",
                    size=20,
                    background=(255, 0, 0),
                    color=(0, 0, 0),
                )
            vizs.append(viz)
        viz = imgviz.tile(vizs, (2, 1), border=(255, 255, 255))

        if args.save:
            out_file = path.Path(args.model).parent / f"video/{index:08d}.jpg"
            out_file.parent.makedirs_p()
            imgviz.io.imsave(out_file, viz)

        yield viz
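Note that main() is a generator: it yields one tiled visualization per frame, so a caller has to drive it. A minimal consumer, assuming the pyglet-based viewer bundled with imgviz (the original entry point is not shown in this excerpt):

import imgviz

# Step through the visualizations interactively; pyglet_imshow accepts
# a generator and advances it frame by frame.
imgviz.io.pyglet_imshow(main())
imgviz.io.pyglet_run()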
Code Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model',
                        choices=('ssd300', 'ssd512'),
                        default='ssd300')
    parser.add_argument('--batchsize', type=int, default=32)
    parser.add_argument('--np', type=int, default=8)
    parser.add_argument('--test-batchsize', type=int, default=16)
    parser.add_argument('--iteration', type=int, default=120000)
    parser.add_argument('--step', type=int, nargs='*', default=[80000, 100000])
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--out', default='result')
    parser.add_argument('--resume')
    parser.add_argument('--dtype',
                        type=str,
                        choices=dtypes.keys(),
                        default='float32',
                        help='Select the data type of the model')
    parser.add_argument('--model-dir',
                        default=None,
                        type=str,
                        help='Where to store models')
    parser.add_argument('--dataset-dir',
                        default=None,
                        type=str,
                        help='Where to store datasets')
    parser.add_argument('--dynamic-interval',
                        default=None,
                        type=int,
                        help='Interval for dynamic loss scaling')
    parser.add_argument('--init-scale',
                        default=1,
                        type=float,
                        help='Initial scale for ada loss')
    parser.add_argument('--loss-scale-method',
                        default='approx_range',
                        type=str,
                        help='Method for adaptive loss scaling')
    parser.add_argument('--scale-upper-bound',
                        default=16,
                        type=float,
                        help='Hard upper bound for each scale factor')
    parser.add_argument('--accum-upper-bound',
                        default=1024,
                        type=float,
                        help='Accumulated upper bound for all scale factors')
    parser.add_argument('--update-per-n-iteration',
                        default=1,
                        type=int,
                        help='Update the loss scale value per n iteration')
    parser.add_argument('--snapshot-per-n-iteration',
                        default=10000,
                        type=int,
                        help='The frequency of taking snapshots')
    parser.add_argument('--n-uf', default=1e-3, type=float)
    parser.add_argument('--nosanity-check', default=False, action='store_true')
    parser.add_argument('--nouse-fp32-update',
                        default=False,
                        action='store_true')
    parser.add_argument('--profiling', default=False, action='store_true')
    parser.add_argument('--verbose',
                        action='store_true',
                        default=False,
                        help='Verbose output')
    args = parser.parse_args()

    # https://docs.chainer.org/en/stable/chainermn/tutorial/tips_faqs.html#using-multiprocessiterator
    if hasattr(multiprocessing, 'set_start_method'):
        multiprocessing.set_start_method('forkserver')
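        # Start a dummy process so the forkserver is launched before CUDA
        # is initialized (see the FAQ linked above).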
        p = multiprocessing.Process()
        p.start()
        p.join()

    comm = chainermn.create_communicator('pure_nccl')
    device = comm.intra_rank

    # Set up workspace: reserve 16 GB of GPU RAM for cuDNN
    chainer.cuda.set_max_workspace_size(16 * 1024 * 1024 * 1024)
    chainer.global_config.cv_resize_backend = 'cv2'

    # Set up the data type. When the models below are initialized, their
    # parameters will be cast to this dtype. cuDNN has to be disabled for
    # non-float32 dtypes.
    if args.dtype != 'float32':
        chainer.global_config.use_cudnn = 'never'
    chainer.global_config.dtype = dtypes[args.dtype]
    print('==> Setting the data type to {}'.format(args.dtype))

    if args.model_dir is not None:
        chainer.dataset.set_dataset_root(args.model_dir)
    if args.model == 'ssd300':
        model = SSD300(n_fg_class=len(voc_bbox_label_names),
                       pretrained_model='imagenet')
    elif args.model == 'ssd512':
        model = SSD512(n_fg_class=len(voc_bbox_label_names),
                       pretrained_model='imagenet')

    model.use_preset('evaluate')

    ######################################
    # Setup model
    ######################################
    # Apply ada loss transform
    recorder = AdaLossRecorder(sample_per_n_iter=100)
    profiler = Profiler()
    sanity_checker = SanityChecker(
        check_per_n_iter=100) if not args.nosanity_check else None
    # Update the model to support AdaLoss
    # TODO: refactor
    model_ = AdaLossScaled(
        model,
        init_scale=args.init_scale,
        cfg={
            'loss_scale_method': args.loss_scale_method,
            'scale_upper_bound': args.scale_upper_bound,
            'accum_upper_bound': args.accum_upper_bound,
            'update_per_n_iteration': args.update_per_n_iteration,
            'recorder': recorder,
            'profiler': profiler,
            'sanity_checker': sanity_checker,
            'n_uf_threshold': args.n_uf,
            # 'power_of_two': False,
        },
        transforms=[
            AdaLossTransformLinear(),
            AdaLossTransformConvolution2D(),
        ],
        verbose=args.verbose)

    if comm.rank == 0:
        print(model)

    train_chain = MultiboxTrainChain(model_, comm=comm)
    chainer.cuda.get_device_from_id(device).use()

    # to GPU
    model.coder.to_gpu()
    model.extractor.to_gpu()
    model.multibox.to_gpu()

    shared_mem = 100 * 1000 * 1000 * 4

    if args.dataset_dir is not None:
        chainer.dataset.set_dataset_root(args.dataset_dir)
    train = TransformDataset(
        ConcatenatedDataset(VOCBboxDataset(year='2007', split='trainval'),
                            VOCBboxDataset(year='2012', split='trainval')),
        ('img', 'mb_loc', 'mb_label'),
        Transform(model.coder,
                  model.insize,
                  model.mean,
                  dtype=dtypes[args.dtype]))

    if comm.rank == 0:
        indices = np.arange(len(train))
    else:
        indices = None
    indices = chainermn.scatter_dataset(indices, comm, shuffle=True)
    train = train.slice[indices]

    train_iter = chainer.iterators.MultiprocessIterator(train,
                                                        args.batchsize //
                                                        comm.size,
                                                        n_processes=8,
                                                        n_prefetch=2,
                                                        shared_mem=shared_mem)

    if comm.rank == 0:  # NOTE: only performed on the first device
        test = VOCBboxDataset(year='2007',
                              split='test',
                              use_difficult=True,
                              return_difficult=True)
        test_iter = chainer.iterators.SerialIterator(test,
                                                     args.test_batchsize,
                                                     repeat=False,
                                                     shuffle=False)

    # the lr is configured by the ExponentialShift variants below
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(), comm)
    if args.dtype == 'mixed16':
        if not args.nouse_fp32_update:
            print('==> Using FP32 update for dtype=mixed16')
            optimizer.use_fp32_update()  # by default use fp32 update

        # HACK: support skipping update by existing loss scaling functionality
        if args.dynamic_interval is not None:
            optimizer.loss_scaling(interval=args.dynamic_interval, scale=None)
        else:
            optimizer.loss_scaling(interval=float('inf'), scale=None)
            optimizer._loss_scale_max = 1.0  # to prevent actual loss scaling

    optimizer.setup(train_chain)
    for param in train_chain.params():
        if param.name == 'b':
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))

    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                device=device)
    # if args.dtype == 'mixed16':
    #     updater.loss_scale = 8
    iteration_interval = (args.iteration, 'iteration')

    trainer = training.Trainer(updater, iteration_interval, args.out)
    # trainer.extend(extensions.ExponentialShift('lr', 0.1, init=args.lr),
    #                trigger=triggers.ManualScheduleTrigger(
    #                    args.step, 'iteration'))
    if args.batchsize != 32:
        warmup_attr_ratio = 0.1
        # NOTE: despite its name, warmup_n_epoch is counted in iterations.
        warmup_n_epoch = 1000
        lr_shift = chainerlp.extensions.ExponentialShift(
            'lr',
            0.1,
            init=args.lr * warmup_attr_ratio,
            warmup_attr_ratio=warmup_attr_ratio,
            warmup_n_epoch=warmup_n_epoch,
            schedule=args.step)
        trainer.extend(lr_shift, trigger=(1, 'iteration'))

    if comm.rank == 0:
        if not args.profiling:
            trainer.extend(DetectionVOCEvaluator(
                test_iter,
                model,
                use_07_metric=True,
                label_names=voc_bbox_label_names),
                           trigger=triggers.ManualScheduleTrigger(
                               args.step + [args.iteration], 'iteration'))

        log_interval = 10, 'iteration'
        trainer.extend(extensions.LogReport(trigger=log_interval))
        trainer.extend(extensions.observe_lr(), trigger=log_interval)
        trainer.extend(extensions.observe_value(
            'loss_scale',
            lambda trainer: trainer.updater.get_optimizer('main')._loss_scale),
                       trigger=log_interval)

        metrics = [
            'epoch', 'iteration', 'lr', 'main/loss', 'main/loss/loc',
            'main/loss/conf', 'validation/main/map'
        ]
        if args.dynamic_interval is not None:
            metrics.insert(2, 'loss_scale')

        trainer.extend(extensions.PrintReport(metrics), trigger=log_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))

        trainer.extend(extensions.snapshot(),
                       trigger=(args.snapshot_per_n_iteration, 'iteration'))
        trainer.extend(extensions.snapshot_object(
            model, 'model_iter_{.updater.iteration}'),
                       trigger=(args.iteration, 'iteration'))

    if args.resume:
        serializers.load_npz(args.resume, trainer)

    hook = AdaLossMonitor(sample_per_n_iter=100,
                          verbose=args.verbose,
                          includes=['Grad', 'Deconvolution'])
    recorder.trainer = trainer
    hook.trainer = trainer

    with ExitStack() as stack:
        if comm.rank == 0:
            stack.enter_context(hook)
        trainer.run()

    # store recorded results
    if comm.rank == 0:  # NOTE: only export in the first rank
        recorder.export().to_csv(os.path.join(args.out, 'loss_scale.csv'))
        profiler.export().to_csv(os.path.join(args.out, 'profile.csv'))
        if sanity_checker:
            sanity_checker.export().to_csv(
                os.path.join(args.out, 'sanity_check.csv'))
        hook.export_history().to_csv(os.path.join(args.out, 'grad_stats.csv'))
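After training, the recorder, profiler, and sanity-checker dumps land in args.out as CSV files. Assuming the export() calls return pandas DataFrames (as the to_csv usage suggests), the recorded loss-scale trajectory can be inspected afterwards with something like:

import os
import pandas as pd

out_dir = 'result'  # the --out default above
loss_scale = pd.read_csv(os.path.join(out_dir, 'loss_scale.csv'))
print(loss_scale.head())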