コード例 #1
0
def extract_image_feature(args):
    """Extract per-image features and predictions for the query and gallery sets.

    The network ``nets[args.net]`` is instantiated on ``args.device`` and its weights
    are restored through ``Checkpointer``. When ``args.checkpoint_path`` is None, the
    latest checkpoint is used (``use_latest=True``). Every image yielded by the
    loaders built from ``args.query_data`` and ``args.gallery_data`` is run through
    the model. For each image two files are written to ``<args.out_dir>/feat``:

    * ``<name>.npy``     -- the feature vector, optionally transformed by ``args.pcaw``
    * ``<name>.prd.npy`` -- the argmax over axis 1 of the model's second output

    Args:
        args: namespace providing ``net``, ``device``, ``query_data``,
            ``gallery_data``, ``batch_size``, ``out_dir``, ``checkpoint_path`` and
            ``pcaw`` (a callable or None).

    Returns:
        None. Returns early, after logging an error, if the model cannot be
        initialised or the output directory cannot be created.
    """
    try:
        model = nets[args.net]()
        model.to(args.device)
    except Exception as e:
        logger.error("Initialize {} error: {}".format(args.net, e))
        return
    logger.info("Extracting {} feature.".format(args.net))

    # Every image goes to the same directory, so create it once here instead of
    # checking/creating it inside the per-image loop (exist_ok also avoids the
    # exists()/makedirs() race).
    out_dir = os.path.join(args.out_dir, 'feat')
    try:
        os.makedirs(out_dir, exist_ok=True)
    except OSError as e:
        logger.error("can not create feature directory {}: {}".format(out_dir, e))
        return

    query_dataloader = get_loader(args.query_data, args.batch_size)
    gallery_dataloader = get_loader(args.gallery_data, args.batch_size)

    checkpointer = Checkpointer(model, save_dir=args.out_dir)
    # Without an explicit path, fall back to the latest checkpoint.
    _ = checkpointer.load(args.checkpoint_path, use_latest=args.checkpoint_path is None)

    model.eval()
    with torch.no_grad():
        for dataloader in (query_dataloader, gallery_dataloader):
            for batch_imgs, batch_filenames in tqdm(dataloader):
                batch_imgs = batch_imgs.to(args.device)
                batch_features, batch_predicts = model(batch_imgs)
                batch_features = batch_features.cpu().detach().numpy()
                if args.pcaw is not None:
                    # pcaw receives the transposed feature matrix; presumably
                    # transpose=True returns it one row per image again -- TODO confirm.
                    batch_features = args.pcaw(batch_features.T, transpose=True)
                batch_predicts = np.argmax(batch_predicts.cpu().detach().numpy(), axis=1)
                for feature_per_image, predict_per_image, name_per_image in zip(
                        batch_features, batch_predicts, batch_filenames):
                    try:
                        np.save(os.path.join(out_dir, name_per_image + '.npy'), feature_per_image)
                        np.save(os.path.join(out_dir, name_per_image + '.prd.npy'), predict_per_image)
                    except OSError:
                        # A failed write only affects this image; report it and continue.
                        logger.warning("can not write feature with {}.".format(name_per_image))
コード例 #2
0
def test(cfg, local_rank, distributed, logger=None):
    """Evaluate a segmentation model on ``cfg.DATASETS.TEST`` and log its metrics.

    Args:
        cfg: experiment config (uses ``MODEL.ARCHITECTURE``, ``MODEL.WEIGHT``,
            ``LOGS.DIR`` and ``DATASETS.TEST``).
        local_rank: CUDA device index used when ``distributed`` is True.
        distributed: wrap the model in ``DistributedDataParallel`` and build a
            distributed data loader.
        logger: logger to write to. When None, ``logging.getLogger(__name__)`` is
            used, because the function logs immediately and a None logger would
            raise AttributeError.
    """
    if logger is None:
        import logging
        logger = logging.getLogger(__name__)

    device = torch.device('cuda')

    # create model
    logger.info("Creating model \"{}\"".format(cfg.MODEL.ARCHITECTURE))
    model = build_model(cfg).to(device)
    # Label 255 is excluded from the loss (ignore_index).
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255).to(device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=True,
        )

    # checkpoint
    checkpointer = Checkpointer(model, save_dir=cfg.LOGS.DIR, logger=logger)
    _ = checkpointer.load(f=cfg.MODEL.WEIGHT)

    # data_loader: the split ("val", "test", ...) is the last "_"-separated token
    # of the dataset name.
    logger.info('Loading dataset "{}"'.format(cfg.DATASETS.TEST))
    stage = cfg.DATASETS.TEST.split('_')[-1]
    data_loader = make_data_loader(cfg, stage, distributed)
    dataset_name = cfg.DATASETS.TEST

    metrics = inference(model, criterion, data_loader, dataset_name, True)

    # Only the main process reports, so metrics are not logged once per rank.
    if is_main_process():
        logger.info("Metrics:")
        for k, v in metrics.items():
            logger.info("{}: {}".format(k, v))
コード例 #3
0
 def __init__(self, cfg):
     """Build a ``Derender`` for ``cfg``, load ``cfg.MODEL.WEIGHTS`` into it, and move it to CUDA in eval mode.

     Sets ``self.cfg`` (a clone of ``cfg``), ``self.attributes`` and ``self.derenderer``.
     """
     self.cfg = cfg.clone()
     self.attributes = DerenderAttributes(cfg)
     net = Derender(self.cfg, self.attributes)
     # Only the restored weights matter here, so the checkpointer is not kept around.
     Checkpointer(net, logger=log).load(cfg.MODEL.WEIGHTS)
     self.derenderer = net.to("cuda").eval()
def train(cfg, local_rank, distributed):
    """Build an EfficientDet detector and train it with ``do_train``.

    The number of classes comes from the COCO training annotations. apex AMP is
    initialised at opt level ``O1`` when ``cfg.dtype == "float16"`` and ``O0``
    otherwise. Training resumes from ``cfg.model.resume`` via ``Checkpointer``, and
    a validation loader is built only when ``cfg.test.test_period > 0``.

    Returns:
        The trained model (wrapped in ``DistributedDataParallel`` when ``distributed``).
    """
    n_classes = COCODataset(cfg.data.train[0], cfg.data.train[1]).num_classes
    model = EfficientDet(num_classes=n_classes, model_name=cfg.model.name)
    input_size = model.config['inp_size']
    device = torch.device(cfg.device)
    model.to(device)

    optimizer = build_optimizer(model, **optimizer_kwargs(cfg))
    lr_scheduler = build_lr_scheduler(optimizer, **lr_scheduler_kwargs(cfg))

    opt_level = 'O1' if cfg.dtype == "float16" else 'O0'
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

    # Only rank 0 writes checkpoints to disk.
    checkpointer = Checkpointer(model, optimizer, lr_scheduler, cfg.output_dir,
                                comm.get_rank() == 0)
    arguments = {"iteration": 0}
    arguments.update(checkpointer.load(cfg.model.resume))

    train_dataloader = build_dataloader(cfg,
                                        input_size,
                                        is_train=True,
                                        distributed=distributed,
                                        start_iter=arguments["iteration"])

    test_period = cfg.test.test_period
    val_dataloader = (
        build_dataloader(cfg, input_size, is_train=False, distributed=distributed)
        if test_period > 0 else None
    )

    do_train(cfg, model, train_dataloader, val_dataloader, optimizer,
             lr_scheduler, checkpointer, device, cfg.solver.checkpoint_period,
             test_period, cfg.solver.log_period, arguments)
    return model
コード例 #5
0
def train(cfg,
          local_rank,
          distributed,
          logger=None,
          tblogger=None,
          transfer_weight=False,
          change_lr=False):
    """Build a segmentation model and train it, optionally starting from ``cfg.MODEL.WEIGHT``.

    Args:
        cfg: experiment config.
        local_rank: CUDA device index used when ``distributed`` is True.
        distributed: wrap the model in ``DistributedDataParallel`` and build
            distributed data loaders.
        logger: logger to write to. When None, ``logging.getLogger(__name__)`` is
            used, because the function logs immediately and a None logger would
            raise AttributeError.
        tblogger: forwarded unchanged to ``do_train``.
        transfer_weight: forwarded to ``Checkpointer.load`` as ``model_weight_only``.
        change_lr: forwarded to ``Checkpointer.load`` as ``change_scheduler``.
    """
    if logger is None:
        import logging
        logger = logging.getLogger(__name__)

    device = torch.device('cuda')

    # create model
    logger.info('Creating model "{}"'.format(cfg.MODEL.ARCHITECTURE))
    model = build_model(cfg).to(device)
    # Label 255 is excluded from the loss (ignore_index).
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255).to(device)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=True,
        )

    # Only rank 0 writes checkpoints to disk.
    save_to_disk = get_rank() == 0

    # checkpoint; values restored from the checkpoint override these defaults.
    arguments = {'iteration': 0, 'best_iou': 0}
    checkpointer = Checkpointer(model, optimizer, scheduler, cfg.LOGS.DIR,
                                save_to_disk, logger)
    extra_checkpoint_data = checkpointer.load(
        f=cfg.MODEL.WEIGHT,
        model_weight_only=transfer_weight,
        change_scheduler=change_lr)
    arguments.update(extra_checkpoint_data)

    # data_loader
    logger.info('Loading dataset "{}"'.format(cfg.DATASETS.TRAIN))
    data_loader = make_data_loader(cfg, 'train', distributed)
    data_loader_val = make_data_loader(cfg, 'val', distributed)

    do_train(cfg,
             model=model,
             data_loader=data_loader,
             optimizer=optimizer,
             scheduler=scheduler,
             criterion=criterion,
             checkpointer=checkpointer,
             device=device,
             arguments=arguments,
             tblogger=tblogger,
             data_loader_val=data_loader_val,
             distributed=distributed)
コード例 #6
0
ファイル: main.py プロジェクト: nickline2020/hello
def train(cfg):
    """Build the model, loss, optimizer, LR scheduler and data, then run ``do_train``.

    A ``Checkpointer`` rooted at ``cfg.SAVE_DIR`` is passed to ``do_train``, but no
    checkpoint is loaded here, so training starts at iteration 0.
    """
    logger = setup_logger(name='Train', level=cfg.LOGGER.LEVEL)
    logger.info(cfg)

    model = build_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    criterion = build_loss(cfg)
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    train_loader = build_data(cfg, is_train=True)
    val_loader = build_data(cfg, is_train=False)
    for loader in (train_loader, val_loader):
        logger.info(loader.dataset)

    checkpointer = Checkpointer(model, optimizer, scheduler, cfg.SAVE_DIR)

    do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             criterion, checkpointer, device, cfg.SOLVER.CHECKPOINT_PERIOD,
             {"iteration": 0}, logger)
コード例 #7
0
    def reset(self, output_dir=None, new_params=None, log_flag=True, fixed_seed=True):
        """Rebuild the loaders, derenderer, optimizer, scheduler and checkpointer for a new run.

        Args:
            output_dir: directory for ``cfg.yaml`` and checkpoints. Defaults to
                ``self.cfg.OUTPUT_DIR`` and is created if missing.
            new_params: flat ``[key, value, ...]`` override list merged into
                ``self.cfg`` with ``merge_from_list``. None means no overrides.
            log_flag: unused; kept for interface compatibility.
            fixed_seed: when True, call ``torch.manual_seed(1234)`` before building
                the derenderer so its initial weights are reproducible.

        Side effects:
            Mutates ``self.cfg`` and writes ``<output_dir>/cfg.yaml``. Sets
            ``optimizer``, ``derenderer``, ``scheduler``, ``device``, ``val_loader``,
            ``train_loader``, ``output_dir``, ``checkpointer`` and
            ``start_iteration`` on ``self``.
        """
        if output_dir is None:
            output_dir = self.cfg.OUTPUT_DIR

        # Merge the override list (a None default avoids a shared mutable default).
        self.cfg.merge_from_list(new_params if new_params is not None else [])

        # Save the merged configuration next to the run's outputs; make sure the
        # directory exists first so the open() below cannot fail on a fresh run.
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "cfg.yaml"), 'w') as f:
            f.write(self.cfg.dump(indent=4))

        log.info(f'New training run with configuration:\n{self.cfg}\n\n')

        # DEBUG mode loads data in the main process (no worker processes) for both
        # loaders.
        workers = self.cfg.DATALOADER.NUM_WORKERS if not self.cfg.DEBUG else 0
        train_loader = DataLoader(dataset=self.train_dataset, batch_size=self.cfg.DATALOADER.OBJECTS_PER_BATCH,
                                  num_workers=workers, shuffle=True)
        # NOTE(review): validation is shuffled as in the original; confirm this is
        # intended (e.g. evaluating on a random subset).
        val_loader = DataLoader(dataset=self.val_dataset, batch_size=self.cfg.DATALOADER.VAL_BATCH_SIZE,
                                num_workers=workers, shuffle=True)

        # Instantiate training pipeline components
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if fixed_seed:
            torch.manual_seed(1234)
        derenderer = Derender(self.cfg, self.attributes).to(device)

        optimizer = make_optimizer(self.cfg, derenderer)
        scheduler = make_lr_scheduler(self.cfg, optimizer)
        # The checkpoint is loaded into the bare module, before any DataParallel wrap.
        checkpointer = Checkpointer(derenderer, optimizer, scheduler, output_dir, log)
        self.start_iteration = checkpointer.load()

        # Multi-GPU Support
        if device == 'cuda' and not self.cfg.DEBUG:
            gpu_ids = list(range(torch.cuda.device_count()))
            derenderer = torch.nn.parallel.DataParallel(derenderer, gpu_ids)

        self.optimizer = optimizer
        self.derenderer = derenderer
        self.scheduler = scheduler
        self.device = device
        self.val_loader = val_loader
        self.train_loader = train_loader
        self.output_dir = output_dir
        self.checkpointer = checkpointer
コード例 #8
0
    def __init__(self, cfg,
                 pretrained_ckpt="/all/home/yilundu/repos/cora-derenderer/output/intphys/dynamics/exp_00073/model_0052474.pth"):
        """Set up the dynamics model, its optimizer/scheduler/checkpointer and the data loaders.

        Args:
            cfg: experiment config.
            pretrained_ckpt: checkpoint whose ``'models'`` entry is loaded into the
                model. Defaults to the previously hard-coded path; pass None to
                start from the freshly built model.

        Sets ``train_loader``, ``val_loader``, ``checkpointer``, ``optimizer``,
        ``scheduler``, ``model``, ``output_dir``, ``start_iteration`` and ``cfg``.
        ``val_loader`` is a dict of DataLoaders keyed like ``cfg.ATTRIBUTES_KEYS``,
        with one validation set per key.
        """
        # One validation dataset per attribute source in cfg.ATTRIBUTES_KEYS.
        val_datasets = {k: DynamicsDataset(cfg, cfg.DATASETS.TEST, k)
                        for k in cfg.ATTRIBUTES_KEYS}
        # TODO:(YILUN) if you use any other types of attributes seen in
        # cfg.ATTRIBUTES_KEYS, you will train on outputs of the derender;
        # the object_ids will still be ok.
        train_dataset = DynamicsDataset(cfg, cfg.DATASETS.TRAIN, "attributes")
        # TODO: add torch data parallel or distributed or sth
        # There is no single validation dataset (the undefined name `val_dataset`
        # used to be referenced here and raised NameError), so the model is built
        # from the training set -- TODO confirm build_dynamics_model only needs
        # dataset metadata.
        model = build_dynamics_model(cfg, train_dataset).cuda().train()
        if pretrained_ckpt is not None:
            # Load on the CPU; load_state_dict copies the values into the model's
            # (CUDA) parameters, so the checkpoint's original device does not matter.
            ckpt = torch.load(pretrained_ckpt, map_location="cpu")
            model.load_state_dict(ckpt['models'])
        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)
        checkpointer = Checkpointer(model, optimizer, scheduler, cfg.OUTPUT_DIR, log)

        workers = 0 if cfg.DEBUG else cfg.DATALOADER.NUM_WORKERS
        train_loader = DataLoader(train_dataset, batch_size=cfg.DATALOADER.BATCH_SIZE,
                                  num_workers=workers)
        val_loader = {k: DataLoader(v, batch_size=cfg.DATALOADER.BATCH_SIZE,
                                    num_workers=workers, shuffle=False)
                      for k, v in val_datasets.items()}

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.checkpointer = checkpointer
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.model = model
        self.output_dir = cfg.OUTPUT_DIR
        self.start_iteration = 0
        self.cfg = cfg
コード例 #9
0
# Evaluation criterion (train=False), moved to the GPU; built in every mode.
criterion_eval = get_criterion(cfg, train=False)
criterion_eval.cuda()
# optimizer/scheduler stay None in evaluation-only runs (cfg.EVALUATE).
optimizer = None
scheduler = None
if not cfg.EVALUATE:
    # NOTE(review): `criterion` is only defined on this path; code after this
    # fragment must not use it when cfg.EVALUATE is set.
    criterion = get_criterion(cfg)
    criterion.cuda()
    # `iteration` (defined before this fragment) > 0 is passed as resume=True.
    optimizer = get_opt(cfg, net, resume=iteration > 0)
    scheduler = get_lr_scheduler(cfg, optimizer, last_iter=iteration)

##################### make a checkpoint ############################
best_acc = 0.0
checkpointer = Checkpointer(net,
                            cfg.MODEL.ARCH,
                            best_acc=best_acc,
                            optimizer=optimizer,
                            scheduler=scheduler,
                            save_dir=cfg.OUTPUT_DIR,
                            is_test=cfg.EVALUATE,
                            only_save_last=cfg.ONLY_SAVE_LAST)

# cfg.MODEL.MODEL_PATH is used as-is when it names an existing file; otherwise it
# is treated as relative to cfg.DATA.DATA_DIR.
filepath = cfg.MODEL.MODEL_PATH
if not os.path.isfile(filepath):
    filepath = os.path.join(cfg.DATA.DATA_DIR, cfg.MODEL.MODEL_PATH)
extra_checkpoint_data = checkpointer.load(filepath)

############## tensorboard writers #############################
# TensorBoard logs live under <args.output_dir>/tf_logs.
tb_log_dir = os.path.join(args.output_dir, 'tf_logs')
train_tb_log_dir = os.path.join(tb_log_dir, 'train_logs')
task_names = [
    task_name.replace('.yaml', '').replace('/', '_')
    for task_name in cfg.DATA.TEST
コード例 #10
0
def train(cfg, local_rank, distributed):
    """Build the model, solver, data loaders and summary writer, then run ``do_train``.

    Checkpoints are read from and written to
    ``<cfg.CHECKPOINTER.DIR>/<cfg.CHECKPOINTER.NAME>``. Training resumes from
    ``cfg.CHECKPOINTER.LOAD_NAME``. Only rank 0 writes checkpoints and summaries.
    A validation loader is built only when ``cfg.SOLVER.EVALUATE`` is set.

    Returns:
        The trained model (wrapped in ``DistributedDataParallel`` when ``distributed``).
    """
    logger = logging.getLogger(cfg.NAME)
    # build model
    model = build_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    # build solver
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {"iteration": 0}

    # Computed once and shared by the checkpointer and the summary writer below.
    save_dir = os.path.join(cfg.CHECKPOINTER.DIR, cfg.CHECKPOINTER.NAME)
    save_to_disk = get_rank() == 0

    checkpointer = Checkpointer(
        model=model, optimizer=optimizer, scheduler=scheduler,
        save_dir=save_dir, save_to_disk=save_to_disk, logger=logger
    )
    extra_checkpoint_data = checkpointer.load(cfg.CHECKPOINTER.LOAD_NAME)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    if cfg.SOLVER.EVALUATE:
        # synchronize() presumably is a distributed barrier, keeping all ranks in
        # step around building the validation loader.
        synchronize()
        data_loader_val = make_data_loader(cfg, is_train=False, is_distributed=distributed, is_for_period=True)
        synchronize()
    else:
        data_loader_val = None

    if cfg.SUMMARY_WRITER and save_to_disk:
        summary_writer = make_summary_writer(cfg.SUMMARY_WRITER, save_dir, model_name=cfg.MODEL.NAME)
    else:
        summary_writer = None

    do_train(
        cfg,
        model,
        data_loader,
        data_loader_val,
        optimizer,
        scheduler,
        checkpointer,
        device,
        arguments,
        summary_writer
    )

    return model
コード例 #11
0
    def __init__(self, cfg):
        """Construct the sub-networks and loss for the task named by ``cfg.DATASETS.TASK``.

        Depending on the task and config, sets some of:

        * ``self.reference``, ``self.backbone`` and ``self.liftingnet``
          (usually wrapped in ``nn.DataParallel``);
        * ``self.criterion``;
        * ``self.device`` and ``self.pairwise``, for RPSM triangulation on h36m
          keypoint tasks.

        Several branches initialise backbones from ``model.pth`` in the directory
        returned by ``BackboneCatalog``.

        Raises:
            NotImplementedError: if no branch matches the task / enabled components.
        """
        super(Modelbuilder, self).__init__()
        if cfg.DATASETS.TASK in ['multiview_keypoint']:
            self.reference = registry.BACKBONES[cfg.BACKBONE.BODY](cfg)
            # The catalog entry is needed only to load pretrained weights or to build
            # a separate (non-shared) backbone for the other views.
            if cfg.EPIPOLAR.PRETRAINED or not cfg.EPIPOLAR.SHARE_WEIGHTS:
                backbone, backbone_dir = BackboneCatalog.get(cfg.BACKBONE.BODY)
            if cfg.EPIPOLAR.PRETRAINED:
                checkpointer = Checkpointer(
                    model=self.reference,
                    save_dir=backbone_dir,
                )
                checkpointer.load('model.pth', prefix='backbone.module.')
            # Presumably left unwrapped so FLOPs profiling (cfg.VIS.FLOPS) sees the bare module.
            if not cfg.VIS.FLOPS:
                self.reference = nn.DataParallel(self.reference)
            if cfg.EPIPOLAR.SHARE_WEIGHTS:
                self.backbone = self.reference
            else:
                self.backbone = registry.BACKBONES[backbone](cfg)
                # load fixed weights for the other views
                checkpointer = Checkpointer(
                    model=self.backbone,
                    save_dir=backbone_dir,
                )
                checkpointer.load('model.pth', prefix='backbone.module.')
                self.backbone = nn.DataParallel(self.backbone)
            if cfg.BACKBONE.SYNC_BN:
                # NOTE(review): with SHARE_WEIGHTS, reference and backbone are the same
                # object here but are converted separately; confirm convert_model
                # keeps the weights shared.
                self.reference = convert_model(self.reference)
                self.backbone = convert_model(self.backbone)
            if cfg.KEYPOINT.LOSS == 'joint':
                print('h36m special setting: JointsMSELoss')
                self.criterion = JointsMSELoss()
            elif cfg.KEYPOINT.LOSS == 'smoothmse':
                print('h36m special setting: smoothMSE')
                self.criterion = KeypointsMSESmoothLoss()
            elif cfg.KEYPOINT.LOSS == 'mse':
                _criterion = MaskedMSELoss()
                self.criterion = lambda targets, outputs: compute_stage_loss(
                    _criterion, targets, outputs)
        elif cfg.DATASETS.TASK == 'keypoint':
            self.backbone = build_backbone(cfg)
            self.backbone = nn.DataParallel(self.backbone)
            if 'h36m' in cfg.OUTPUT_DIR:
                print('h36m special setting: JointsMSELoss')
                self.criterion = JointsMSELoss()
            else:
                _criterion = MaskedMSELoss()
                self.criterion = lambda targets, outputs: compute_stage_loss(
                    _criterion, targets, outputs)
        elif cfg.DATASETS.TASK == 'keypoint_lifting_rot':
            self.backbone = build_backbone(cfg)
            self.backbone = nn.DataParallel(self.backbone)
            self.liftingnet = build_liftingnet()
        elif cfg.DATASETS.TASK == 'keypoint_lifting_direct':
            self.backbone = build_backbone(cfg)
            backbone, backbone_dir = BackboneCatalog.get(cfg.BACKBONE.BODY)
            checkpointer = Checkpointer(
                model=self.backbone,
                save_dir=backbone_dir,
            )
            checkpointer.load('model.pth', prefix='backbone.module.')
            self.backbone = nn.DataParallel(self.backbone)
            self.liftingnet = build_liftingnet()
            self.liftingnet = nn.DataParallel(self.liftingnet)
        elif cfg.DATASETS.TASK == 'img_lifting_rot':
            self.backbone = build_backbone(cfg)
            self.liftingnet = build_liftingnet(
                in_channels=self.backbone.out_channels)
            self.backbone = nn.DataParallel(self.backbone)
            self.liftingnet = nn.DataParallel(self.liftingnet)

        elif cfg.DATASETS.TASK == 'multiview_img_lifting_rot':
            self.reference = registry.BACKBONES[cfg.BACKBONE.BODY](cfg)
            backbone, backbone_dir = BackboneCatalog.get(cfg.BACKBONE.BODY)
            # The reference view always starts from the pretrained catalog weights.
            checkpointer = Checkpointer(
                model=self.reference,
                save_dir=backbone_dir,
            )
            checkpointer.load('model.pth', prefix='backbone.module.')
            self.reference = nn.DataParallel(self.reference)
            if cfg.EPIPOLAR.SHARE_WEIGHTS:
                self.backbone = self.reference
            else:
                self.backbone = registry.BACKBONES[backbone](cfg)
                # load fixed weights for the other views
                checkpointer = Checkpointer(
                    model=self.backbone,
                    save_dir=backbone_dir,
                )
                checkpointer.load('model.pth', prefix='backbone.module.')
                self.backbone = nn.DataParallel(self.backbone)

            # self.backbone is an nn.DataParallel wrapper at this point, and
            # nn.Module attribute lookup does not forward plain attributes to the
            # wrapped module, so out_channels must be read from .module.
            self.liftingnet = build_liftingnet(
                in_channels=self.backbone.module.out_channels)
            self.liftingnet = nn.DataParallel(self.liftingnet)

        elif cfg.BACKBONE.ENABLED:
            self.backbone = build_backbone(cfg)
            self.liftingnet = build_liftingnet(
                in_channels=self.backbone.out_channels)
            self.backbone = nn.DataParallel(self.backbone)
            self.liftingnet = nn.DataParallel(self.liftingnet)
        elif cfg.LIFTING.ENABLED:
            self.liftingnet = build_liftingnet()
            self.liftingnet = nn.DataParallel(self.liftingnet)
        else:
            raise NotImplementedError

        if cfg.KEYPOINT.TRIANGULATION == 'rpsm' and 'h36m' in cfg.OUTPUT_DIR and cfg.DATASETS.TASK in [
                'keypoint', 'multiview_keypoint'
        ]:
            import pickle
            self.device = torch.device('cuda:0')
            pairwise_file = cfg.PICT_STRUCT.PAIRWISE_FILE
            # NOTE(security): pickle.load can execute arbitrary code; the path comes
            # from the config, so only point it at trusted files.
            with open(pairwise_file, 'rb') as f:
                self.pairwise = pickle.load(f)['pairwise_constrain']
            for k, v in self.pairwise.items():
                # np.float (an alias of the builtin float, i.e. float64) was removed in
                # NumPy 1.24; np.float64 is the equivalent dtype.
                self.pairwise[k] = torch.as_tensor(v.todense().astype(
                    np.float64),
                                                   device=self.device,
                                                   dtype=torch.float)
コード例 #12
0
def main():
    """Run EfficientDet on images and/or a video and save the annotated results.

    Command-line arguments used:

    * ``config_file``, ``opts``: config overrides;
    * ``ckpt``: weights to load;
    * ``img``: an image file or a directory of images;
    * ``vid``: a video file;
    * ``score_thr``: score threshold passed to ``inference``;
    * ``save``: output directory, created if missing.

    Outputs keep the input's base name. Unreadable images are skipped with a
    message instead of crashing.
    """
    args = parse_args()
    cfg = get_default_cfg()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # The test annotations supply the class count and the label map.
    dataset = COCODataset(cfg.data.test[0], cfg.data.test[1])
    num_classes = dataset.num_classes
    label_map = dataset.labels
    model = EfficientDet(num_classes=num_classes, model_name=cfg.model.name)
    device = torch.device(cfg.device)
    model.to(device)
    model.eval()

    inp_size = model.config['inp_size']
    transforms = build_transforms(False, inp_size=inp_size)

    output_dir = cfg.output_dir
    checkpointer = Checkpointer(model, None, None, output_dir, True)
    checkpointer.load(args.ckpt)

    images = []
    if args.img:
        if osp.isdir(args.img):
            images = [osp.join(args.img, filename)
                      for filename in os.listdir(args.img)
                      if is_valid_file(filename)]
        else:
            images = [args.img]

    # cv2.imwrite / cv2.VideoWriter may not raise when the target directory is
    # missing, so create it up front.
    if args.save and (images or args.vid):
        os.makedirs(args.save, exist_ok=True)

    for img_path in images:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None (rather than raising) for unreadable files.
            print('can not read image {}, skipping.'.format(img_path))
            continue
        img = inference(model,
                        img,
                        label_map,
                        score_thr=args.score_thr,
                        transforms=transforms)
        save_path = osp.join(args.save, osp.basename(img_path))
        cv2.imwrite(save_path, img)

    if args.vid:
        # The branch is selected by args.vid, so the video path is args.vid as well
        # (the former args.v does not match the tested argument).
        vCap = cv2.VideoCapture(args.vid)
        try:
            if not vCap.isOpened():
                print('can not open video {}.'.format(args.vid))
                return
            fps = int(vCap.get(cv2.CAP_PROP_FPS))
            height = int(vCap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            width = int(vCap.get(cv2.CAP_PROP_FRAME_WIDTH))
            size = (width, height)
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            save_path = osp.join(args.save, osp.basename(args.vid))
            vWrt = cv2.VideoWriter(save_path, fourcc, fps, size)
            try:
                while True:
                    flag, frame = vCap.read()
                    if not flag:
                        break
                    frame = inference(model,
                                      frame,
                                      label_map,
                                      score_thr=args.score_thr,
                                      transforms=transforms)
                    vWrt.write(frame)
            finally:
                # Release even if inference raises, so the output file is finalised.
                vWrt.release()
        finally:
            vCap.release()
コード例 #13
0
def test(cfg, model=None):
    """Evaluate a keypoint / lifting model on every dataset in ``cfg.DATASETS.TEST``.

    Args:
        cfg: experiment config.
        model: an already-built model to evaluate. When None, a ``Modelbuilder`` is
            created on the configured device and ``cfg.WEIGHTS`` is loaded into it.

    Depending on the config this also:

    * profiles FLOPs on one batch and exits the process (``cfg.VIS.FLOPS``);
    * runs the test sets in train mode without gradients and saves a
      ``model_bn`` checkpoint (``cfg.TEST.RECOMPUTE_BN``);
    * renders visualisations (``cfg.VIS.VIDEO`` / ``cfg.VIS.VIDEO_GT``);
    * saves sampled predictions and PCK ratios under
      ``<cfg.OUTPUT_DIR>/inference/<dataset>``.

    Returns:
        ``meters.get_all_avg()`` for the per-dataset-prefixed losses and metrics.
        The model is left in train mode.
    """
    torch.cuda.empty_cache()  # TODO check if it helps
    cpu_device = torch.device("cpu")
    if cfg.VIS.FLOPS:
        device = torch.device("cuda:0")
    else:
        device = torch.device(cfg.DEVICE)
    if model is None:
        # load model from outputs
        model = Modelbuilder(cfg)
        model.to(device)
        checkpointer = Checkpointer(model, save_dir=cfg.OUTPUT_DIR)
        _ = checkpointer.load(cfg.WEIGHTS)
    else:
        # The RECOMPUTE_BN branch below saves through a checkpointer; without this
        # it would raise NameError for a caller-supplied model.
        checkpointer = Checkpointer(model, save_dir=cfg.OUTPUT_DIR)

    def to_device(batch):
        # Move tensor entries of a batch dict to the evaluation device; keep the rest.
        return {k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}

    def cpu(x):
        # Tensors become numpy arrays on the CPU; anything else passes through.
        return x.to(cpu_device).numpy() if isinstance(x, torch.Tensor) else x

    data_loaders = make_data_loader(cfg, is_train=False)
    if cfg.VIS.FLOPS:
        model.eval()
        from thop import profile
        for batchdata in data_loaders[0]:
            with torch.no_grad():
                flops, params = profile(model, inputs=(to_device(batchdata), False))
                print('flops', flops, 'params', params)
                # Profiling is a one-shot run: stop the whole process after one batch.
                raise SystemExit
    if cfg.TEST.RECOMPUTE_BN:
        # Forward the test sets in train mode under no_grad: no weights change, but
        # BatchNorm running statistics are updated; the result is saved as model_bn.
        tmp_data_loader = make_data_loader(cfg,
                                           is_train=True,
                                           dataset_list=cfg.DATASETS.TEST)
        model.train()
        for batchdata in tqdm(tmp_data_loader):
            with torch.no_grad():
                model(to_device(batchdata), is_train=True)
        checkpointer.save("model_bn")
        model.eval()
    elif cfg.TEST.TRAIN_BN:
        model.train()
    else:
        model.eval()
    dataset_names = cfg.DATASETS.TEST
    meters = MetricLogger()

    logger = setup_logger("tester", cfg.OUTPUT_DIR)
    for data_loader, dataset_name in zip(data_loaders, dataset_names):
        print('Loading ', dataset_name)
        dataset = data_loader.dataset

        logger.info("Start evaluation on {} dataset({} images).".format(
            dataset_name, len(dataset)))
        total_timer = Timer()
        total_timer.tic()

        predictions = []
        err_joints = np.zeros((cfg.TEST.IMS_PER_BATCH, int(cfg.TEST.MAX_TH)))
        total_joints = 0

        for idx, batchdata in enumerate(tqdm(data_loader)):
            if cfg.VIS.VIDEO and 'h36m' not in cfg.OUTPUT_DIR:
                # Keep a single frame (index 16) of every entry that can be sliced
                # as [:, idx]. Frame notes: good 1 2 3 4 5 6 7 8 12 16 30
                # (4: 17.4 vs 16.5, 30: 41.83200 vs 40.17562);
                # bad 0 22 (0: 43.78544 vs 45.24059, 22: 43.01385 vs 43.88636).
                vis_idx = 16
                for k, v in batchdata.items():
                    try:
                        batchdata[k] = v[:, vis_idx, None]
                    except Exception:
                        # entries that cannot be sliced this way are left as-is
                        pass
            if cfg.VIS.VIDEO_GT:
                # Ground-truth visualisation only: keep frames 30..31, draw the GT 2D
                # joints and skip the forward pass for this batch.
                vis_idx = 30
                for k, v in batchdata.items():
                    try:
                        batchdata[k] = v[:, vis_idx:vis_idx + 2]
                    except Exception:
                        pass
                joints = cpu(batchdata['points-2d'].squeeze())[0]
                orig_img = de_transform(
                    cpu(batchdata['img'].squeeze()[None, ...])[0][0])
                ax = display_image_in_actual_size(orig_img.shape[1],
                                                  orig_img.shape[2])
                if 'h36m' in cfg.OUTPUT_DIR:
                    draw_2d_pose(joints, ax)
                    orig_img = orig_img[::-1]
                else:
                    visibility = cpu(batchdata['visibility'].squeeze())[0]
                    plot_two_hand_2d(joints, ax, visibility)
                ax.imshow(orig_img.transpose((1, 2, 0)))
                ax.axis('off')
                output_folder = os.path.join("outs", "video_gt", dataset_name)
                mkdir(output_folder)
                plt.savefig(os.path.join(output_folder, "%08d" % idx),
                            bbox_inches="tight",
                            pad_inches=0)
                plt.cla()
                plt.clf()
                plt.close()
                continue
            batch_size = cfg.TEST.IMS_PER_BATCH
            with torch.no_grad():
                loss_dict, metric_dict, output = model(to_device(batchdata),
                                                       is_train=False)
            meters.update(**prefix_dict(loss_dict, dataset_name))
            meters.update(**prefix_dict(metric_dict, dataset_name))
            if cfg.VIS.VIDEO:
                joints = cpu(output['batch_locs'].squeeze())
                if joints.shape[0] == 1:
                    joints = joints[0]
                try:
                    orig_img = de_transform(
                        cpu(batchdata['img'].squeeze()[None, ...])[0][0])
                except Exception:
                    # presumably for image tensors with one fewer leading dimension
                    orig_img = de_transform(
                        cpu(batchdata['img'].squeeze()[None, ...])[0])
                ax = display_image_in_actual_size(orig_img.shape[1],
                                                  orig_img.shape[2])
                if 'h36m' in cfg.OUTPUT_DIR:
                    draw_2d_pose(joints, ax)
                    orig_img = orig_img[::-1]
                else:
                    visibility = cpu(batchdata['visibility'].squeeze())
                    if visibility.shape[0] == 1:
                        visibility = visibility[0]
                    plot_two_hand_2d(joints, ax, visibility)
                ax.imshow(orig_img.transpose((1, 2, 0)))
                ax.axis('off')
                output_folder = os.path.join(cfg.OUTPUT_DIR, "video",
                                             dataset_name)
                mkdir(output_folder)
                plt.savefig(os.path.join(output_folder, "%08d" % idx),
                            bbox_inches="tight",
                            pad_inches=0)
                plt.cla()
                plt.clf()
                plt.close()

            if cfg.TEST.PCK and cfg.DOTEST:
                # Accumulate each batch once. Adding the same batch batch_size times
                # scaled numerator and denominator equally, so the saved ratio
                # err_joints / total_joints is unchanged.
                err_joints = np.add(err_joints, output['err_joints'])
                total_joints += sum(output['total_joints'])

            if idx % cfg.VIS.SAVE_PRED_FREQ == 0 and (
                    cfg.VIS.SAVE_PRED_LIMIT == -1
                    or idx < cfg.VIS.SAVE_PRED_LIMIT * cfg.VIS.SAVE_PRED_FREQ):
                # The final batch can be smaller than IMS_PER_BATCH; never index past
                # the number of samples actually present.
                n_samples = next((len(v) for v in batchdata.values()
                                  if isinstance(v, torch.Tensor) and v.dim() > 0),
                                 batch_size)
                for i in range(min(batch_size, n_samples)):
                    predictions.append((
                        {
                            k: (cpu(v[i]) if not isinstance(v, int) else v)
                            for k, v in batchdata.items()
                        },
                        {
                            k: (cpu(v[i]) if not isinstance(v, int) else v)
                            for k, v in output.items()
                        },
                    ))
            if cfg.VIS.SAVE_PRED_LIMIT != -1 and idx > cfg.VIS.SAVE_PRED_LIMIT * cfg.VIS.SAVE_PRED_FREQ:
                break

        total_time = total_timer.toc()
        total_time_str = get_time_str(total_time)
        logger.info("Total run time: {} ".format(total_time_str))

        if cfg.OUTPUT_DIR:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                         dataset_name)
            mkdir(output_folder)
            torch.save(predictions,
                       os.path.join(output_folder, cfg.VIS.SAVE_PRED_NAME))
            if cfg.DOTEST and cfg.TEST.PCK:
                print(err_joints.shape)
                torch.save(err_joints * 1.0 / total_joints,
                           os.path.join(output_folder, "pck.pth"))

    logger.info("{}".format(str(meters)))

    model.train()
    return meters.get_all_avg()
コード例 #14
0
def _endless_batches(loader):
    """Yield batches from ``loader`` indefinitely, restarting it after each pass.

    ``itertools.cycle`` saves every element of the first pass and then replays
    those same objects, so a shuffling DataLoader would repeat the first
    epoch's batches (same order, same augmentation) while holding all of them
    in memory. Re-iterating the loader on every pass avoids both problems.

    Stops instead of spinning forever if a full pass yields nothing.
    """
    while True:
        yielded_any = False
        for batch in loader:
            yielded_any = True
            yield batch
        if not yielded_any:
            return


def train(cfg, args):
    """Run iteration-based training with logging, checkpoints and validation.

    Builds train/val loaders from ``DatasetCatalog``. The model is wrapped in
    ``DataParallel`` over all visible GPUs unless ``args.debug`` is set. The
    latest checkpoint is resumed unless ``args.debug`` is set. Training then
    continues up to ``cfg.SOLVER.MAX_ITER`` iterations.

    Args:
        cfg: config node. Reads the DATASETS, SOLVER, DATALOADER, OUTPUT_DIR
            and VISUALIZATION entries used below.
        args: command-line namespace. It is forwarded to ``DatasetCatalog.get``;
            ``args.debug`` disables DataParallel and checkpoint resuming.
    """
    train_set = DatasetCatalog.get(cfg.DATASETS.TRAIN, args)
    val_set = DatasetCatalog.get(cfg.DATASETS.VAL, args)
    train_loader = DataLoader(train_set,
                              cfg.SOLVER.IMS_PER_BATCH,
                              num_workers=cfg.DATALOADER.NUM_WORKERS,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            cfg.SOLVER.IMS_PER_BATCH,
                            num_workers=cfg.DATALOADER.NUM_WORKERS,
                            shuffle=True)

    gpu_ids = list(range(torch.cuda.device_count()))
    model = build_model(cfg)
    model.to("cuda")
    model = torch.nn.parallel.DataParallel(
        model, gpu_ids) if not args.debug else model

    logger = logging.getLogger("train_logger")
    logger.info("Start training")
    train_metrics = MetricLogger(delimiter="  ")
    max_iter = cfg.SOLVER.MAX_ITER
    output_dir = cfg.OUTPUT_DIR

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)
    checkpointer = Checkpointer(model, optimizer, scheduler, output_dir,
                                logger)
    start_iteration = checkpointer.load() if not args.debug else 0

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    validation_period = cfg.SOLVER.VALIDATION_PERIOD
    val_limit = cfg.SOLVER.VALIDATION_LIMIT
    summary_writer = SummaryWriter(log_dir=os.path.join(output_dir, "summary"))
    visualizer = train_set.visualizer(cfg.VISUALIZATION)(summary_writer)

    model.train()
    start_training_time = time.time()
    last_batch_time = time.time()

    # Iterations are numbered start_iteration + 1 .. max_iter. The range comes
    # first in zip so no extra batch is fetched once it is exhausted, and a run
    # resumed at (or past) max_iter performs zero iterations instead of never
    # hitting an `iteration == max_iter` stop condition.
    iteration_numbers = range(start_iteration + 1, max_iter + 1)
    for iteration, inputs in zip(iteration_numbers,
                                 _endless_batches(train_loader)):
        data_time = time.time() - last_batch_time

        inputs = to_cuda(inputs)
        outputs = model(inputs)

        loss_dict = gather_loss_dict(outputs)
        loss = loss_dict["loss"]
        train_metrics.update(**loss_dict)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Since PyTorch 1.1 the LR scheduler must be stepped after the
        # optimizer; stepping it first skips the schedule's initial value.
        scheduler.step()

        batch_time = time.time() - last_batch_time
        last_batch_time = time.time()
        train_metrics.update(time=batch_time, data=data_time)

        eta_seconds = train_metrics.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                train_metrics.delimiter.join([
                    "eta: {eta}", "iter: {iter}", "{meters}", "lr: {lr:.6f}",
                    "max mem: {memory:.0f}"
                ]).format(eta=eta_string,
                          iter=iteration,
                          meters=str(train_metrics),
                          lr=optimizer.param_groups[0]["lr"],
                          memory=torch.cuda.max_memory_allocated() / 1024.0 /
                          1024.0))
            summary_writer.add_scalars("train", train_metrics.mean, iteration)

        if iteration % 100 == 0:
            visualizer.visualize(inputs, outputs, iteration)

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration))

        if iteration % validation_period == 0:
            # NOTE(review): the model stays in train mode here, so BatchNorm
            # running statistics are also updated by validation batches.
            # Switching to model.eval() may change what the model returns
            # (gather_loss_dict expects losses) — confirm before changing.
            with torch.no_grad():
                val_metrics = MetricLogger(delimiter="  ")
                num_val_batches = 0
                for i, inputs in enumerate(val_loader):
                    num_val_batches += 1
                    data_time = time.time() - last_batch_time

                    inputs = to_cuda(inputs)
                    outputs = model(inputs)

                    loss_dict = gather_loss_dict(outputs)
                    val_metrics.update(**loss_dict)

                    batch_time = time.time() - last_batch_time
                    last_batch_time = time.time()
                    val_metrics.update(time=batch_time, data=data_time)

                    if i % 20 == 0 or i == val_limit:
                        logger.info(
                            val_metrics.delimiter.join([
                                "VALIDATION", "eta: {eta}", "iter: {iter}",
                                "{meters}"
                            ]).format(eta=eta_string,
                                      iter=iteration,
                                      meters=str(val_metrics)))

                    if i == val_limit:
                        break
                # Record the validation scalars once per validation pass.
                # They are also written when the loader runs out before index
                # VALIDATION_LIMIT is reached.
                if num_val_batches:
                    summary_writer.add_scalars("val", val_metrics.mean,
                                               iteration)

    checkpointer.save("model_{:07d}".format(max_iter))
    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    # Average over the iterations this run executed (a resumed run starts at
    # start_iteration, not 0); clamp so a zero-iteration run cannot divide by 0.
    num_iters_run = max(max_iter - start_iteration, 1)
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / num_iters_run))
コード例 #15
0
def train(cfg, output_dir=""):
    """Epoch-based training loop with periodic checkpointing.

    Builds the model, loss/metric functions, optimizer, LR scheduler and
    checkpointer from ``cfg``. It loads ``cfg.GLOBAL.TRAIN.WEIGHT`` through the
    checkpointer (with ``resume=cfg.AUTO_RESUME``), then trains from the stored
    epoch up to ``cfg.GLOBAL.MAX_EPOCH``.

    Args:
        cfg: config node.
        output_dir: directory for checkpoints and TensorBoard logs.

    Returns:
        The trained model, wrapped in ``nn.DataParallel``.
    """
    # Logging goes through the module-level `logger`.

    # build model
    set_random_seed(cfg.RNG_SEED)
    model, loss_fn, metric_fn = build_model(cfg)
    logger.info("Build model:\n{}".format(str(model)))
    model = nn.DataParallel(model).cuda()

    # build optimizer
    optimizer = build_optimizer(cfg, model)

    # build lr scheduler
    scheduler = build_scheduler(cfg, optimizer)

    # build checkpointer
    checkpointer = Checkpointer(model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                save_dir=output_dir,
                                logger=logger)

    checkpoint_data = checkpointer.load(cfg.GLOBAL.TRAIN.WEIGHT,
                                        resume=cfg.AUTO_RESUME)
    ckpt_period = cfg.GLOBAL.TRAIN.CHECKPOINT_PERIOD

    # build data loader
    train_data_loader = build_data_loader(cfg,
                                          cfg.GLOBAL.DATASET,
                                          mode="train")
    val_period = cfg.GLOBAL.VAL.VAL_PERIOD
    # val_data_loader = build_data_loader(cfg, mode="val") if val_period > 0 else None

    # build tensorboard logger (optionally by comment)
    tensorboard_logger = TensorboardLogger(output_dir)

    # train
    max_epoch = cfg.GLOBAL.MAX_EPOCH
    start_epoch = checkpoint_data.get("epoch", 0)
    # best_metric_name = "best_{}".format(cfg.TRAIN.VAL_METRIC)
    # best_metric = checkpoint_data.get(best_metric_name, None)
    logger.info("Start training from epoch {}".format(start_epoch))
    for epoch in range(start_epoch, max_epoch):
        cur_epoch = epoch + 1
        start_time = time.time()
        train_meters = train_model(
            model,
            loss_fn,
            metric_fn,
            data_loader=train_data_loader,
            optimizer=optimizer,
            curr_epoch=epoch,
            tensorboard_logger=tensorboard_logger,
            log_period=cfg.GLOBAL.TRAIN.LOG_PERIOD,
            output_dir=output_dir,
        )
        epoch_time = time.time() - start_time
        # Step the scheduler once per epoch *after* the epoch's optimizer
        # updates; PyTorch >= 1.1 expects this order, and stepping first skips
        # the schedule's initial LR. It is stepped before checkpointing so the
        # saved scheduler state is the one the next epoch starts from.
        scheduler.step()
        logger.info("Epoch[{}]-Train {}  total_time: {:.2f}s".format(
            cur_epoch, train_meters.summary_str, epoch_time))

        # checkpoint (a non-positive period disables periodic saves; the
        # final epoch is always saved)
        periodic_save = ckpt_period > 0 and cur_epoch % ckpt_period == 0
        if periodic_save or cur_epoch == max_epoch:
            checkpoint_data["epoch"] = cur_epoch
            # checkpoint_data[best_metric_name] = best_metric
            checkpointer.save("model_{:03d}".format(cur_epoch),
                              **checkpoint_data)
        # Validation is disabled. The string below is inert and kept for
        # reference only; it uses names (val_data_loader, best_metric,
        # best_metric_name) that are commented out above.
        '''
        # validate
        if val_period < 1:
            continue
        if cur_epoch % val_period == 0 or cur_epoch == max_epoch:
            val_meters = validate_model(model,
                                        loss_fn,
                                        metric_fn,
                                        image_scales=cfg.MODEL.VAL.IMG_SCALES,
                                        inter_scales=cfg.MODEL.VAL.INTER_SCALES,
                                        isFlow=(cur_epoch > cfg.SCHEDULER.INIT_EPOCH),
                                        data_loader=val_data_loader,
                                        curr_epoch=epoch,
                                        tensorboard_logger=tensorboard_logger,
                                        log_period=cfg.TEST.LOG_PERIOD,
                                        output_dir=output_dir,
                                        )
            logger.info("Epoch[{}]-Val {}".format(cur_epoch, val_meters.summary_str))

            # best validation
            cur_metric = val_meters.meters[cfg.TRAIN.VAL_METRIC].global_avg
            if best_metric is None or cur_metric > best_metric:
                best_metric = cur_metric
                checkpoint_data["epoch"] = cur_epoch
                checkpoint_data[best_metric_name] = best_metric
                checkpointer.save("model_best", **checkpoint_data)
        '''

    logger.info("Train Finish!")
    # logger.info("Best val-{} = {}".format(cfg.TRAIN.VAL_METRIC, best_metric))

    return model
コード例 #16
0
    chkpt_dir, log_dir, tb_dir = setup_exp(cfg.SYSTEM.SAVE_ROOT,
                                           cfg.SYSTEM.EXP_NAME, args.clean_run)
    print(f'chkpr_dir:{chkpt_dir}, log_dir:{log_dir}, tb_dir:{tb_dir}')

    writer = tensorboard.SummaryWriter(log_dir=tb_dir)
    logger = setup_logger(cfg.SYSTEM.EXP_NAME, log_dir, 0)

    logger.info(f'cfg: {str(cfg)}')

    glimpse_network = GlimpseNetwork(cfg)
    model = CoreNetwork(cfg, glimpse_network)
    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.SOLVER.LR,
                           weight_decay=cfg.SOLVER.WEIGHT_DECAY,
                           betas=cfg.SOLVER.BETAS)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=cfg.SOLVER.STEP_SIZE)
    chkpt = Checkpointer(model,
                         optimizer,
                         scheduler,
                         chkpt_dir,
                         save_to_disk=True,
                         logger=logger)

    train_loader = get_loader(cfg, 'train')
    test_loader = get_loader(cfg, 'test')
    val_loader = get_loader(cfg, 'val')

    chkpt.load()
    loader = [train_loader, val_loader]
    train(cfg, model, loader, optimizer, scheduler, writer)
コード例 #17
0
parser.add_argument('--val_num', default=-1, type=int, help='Number of validation images, -1 for all.')
parser.add_argument('--score_voting', action='store_true', default=False, help='Using score voting.')
parser.add_argument('--improved_coco', action='store_true', default=False, help='Improved COCO-EVAL written by myself.')

args = parser.parse_args()
cfg = get_config(args)

model = PAA(cfg)
model.train().cuda()  # broadcast_buffers is True if BN is used
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, broadcast_buffers=False)

# if cfg.MODEL.USE_SYNCBN:  # TODO: figure this out
#     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

optim = Optimizer(model, cfg)
checkpointer = Checkpointer(cfg, model.module, optim.optimizer)
start_iter = int(cfg.resume.split('_')[-1].split('.')[0]) if cfg.resume else 0
data_loader = make_data_loader(cfg, start_iter=start_iter)
max_iter = len(data_loader) - 1
timer.reset()
main_gpu = dist.get_rank() == 0
num_gpu = dist.get_world_size()

for i, (img_list_batch, box_list_batch) in enumerate(data_loader, start_iter):
    if main_gpu and i == start_iter + 1:
        timer.start()

    optim.update_lr(step=i)

    img_tensor_batch = torch.stack([aa.img for aa in img_list_batch], dim=0).cuda()
    for box_list in box_list_batch:
コード例 #18
0
def train(cfg):
    """Epoch-based training with periodic checkpointing and evaluation.

    Builds the model, optimizer/scheduler and checkpointer, loading
    ``cfg.WEIGHTS``. It trains from the restored epoch up to
    ``cfg.SOLVER.MAX_EPOCHS`` and logs every ``cfg.LOG_FREQ`` iterations. It
    checkpoints every ``cfg.SOLVER.CHECKPOINT_PERIOD`` epochs and saves
    ``model_final`` after the last epoch. It calls ``test(cfg, model)`` every
    ``cfg.EVAL_FREQ`` epochs and after the last one.

    When ``cfg.DATALOADER.BENCHMARK`` is set, no model, optimizer, scheduler or
    checkpointer is created. The loop only iterates the data loader and logs
    timing, and the per-epoch scheduler step, checkpointing and evaluation are
    skipped, since those objects do not exist in that mode.

    Args:
        cfg: config node.
    """
    device = torch.device(cfg.DEVICE)
    benchmark = cfg.DATALOADER.BENCHMARK
    arguments = {"epoch": 0}
    if not benchmark:
        model = Modelbuilder(cfg)
        print(model)
        model.to(device)
        model.float()
        optimizer, scheduler = make_optimizer(cfg, model)
        checkpointer = Checkpointer(model=model,
                                    optimizer=optimizer,
                                    scheduler=scheduler,
                                    save_dir=cfg.OUTPUT_DIR)
        # Extra data returned by the checkpointer is merged into `arguments`.
        # It presumably includes the "epoch" written by checkpointer.save
        # below, which makes training resume from that epoch.
        extra_checkpoint_data = checkpointer.load(
            cfg.WEIGHTS,
            prefix=cfg.WEIGHTS_PREFIX,
            prefix_replace=cfg.WEIGHTS_PREFIX_REPLACE,
            loadoptimizer=cfg.WEIGHTS_LOAD_OPT)
        arguments.update(extra_checkpoint_data)
        model.train()

    logger = setup_logger("trainer", cfg.FOLDER_NAME)
    if cfg.TENSORBOARD.USE:
        writer = SummaryWriter(cfg.FOLDER_NAME)
    else:
        writer = None
    meters = MetricLogger(writer=writer)
    start_training_time = time.time()
    end = time.time()
    start_epoch = arguments["epoch"]
    max_epoch = cfg.SOLVER.MAX_EPOCHS

    if start_epoch >= max_epoch:
        logger.info("Final model exists! No need to train!")
        if not benchmark:
            test(cfg, model)
        return

    data_loader = make_data_loader(
        cfg,
        is_train=True,
    )
    size_epoch = len(data_loader)
    max_iter = size_epoch * max_epoch
    logger.info("Start training {} batches/epoch".format(size_epoch))

    for epoch in range(start_epoch, max_epoch):
        arguments["epoch"] = epoch
        for iteration, batchdata in enumerate(data_loader):
            cur_iter = size_epoch * epoch + iteration
            data_time = time.time() - end

            batchdata = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batchdata.items()
            }

            if not benchmark:
                loss_dict, metric_dict = model(batchdata)
                optimizer.zero_grad()
                loss_dict['loss'].backward()
                optimizer.step()

            batch_time = time.time() - end
            end = time.time()

            meters.update(time=batch_time, data=data_time, iteration=cur_iter)

            if benchmark:
                logger.info(
                    meters.delimiter.join([
                        "iter: {iter}",
                        "{meters}",
                    ]).format(
                        iter=iteration,
                        meters=str(meters),
                    ))
                continue

            eta_seconds = meters.time.global_avg * (max_iter - cur_iter)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % cfg.LOG_FREQ == 0:
                meters.update(iteration=cur_iter, **loss_dict)
                meters.update(iteration=cur_iter, **metric_dict)
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "epoch: {epoch}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        # "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        epoch=epoch,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                        # memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    ))

        if benchmark:
            # Benchmark mode only measures data loading: there is no
            # scheduler to step, nothing to checkpoint and no model to test.
            continue

        # Step once per epoch, after this epoch's optimizer.step() calls;
        # PyTorch >= 1.1 requires this order, otherwise the first LR value of
        # the schedule is skipped.
        scheduler.step()

        if (epoch + 1) % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
            arguments["epoch"] += 1
            checkpointer.save("model_{:03d}".format(epoch), **arguments)
        if epoch == max_epoch - 1:
            arguments['epoch'] = max_epoch
            checkpointer.save("model_final", **arguments)

            total_training_time = time.time() - start_training_time
            total_time_str = str(
                datetime.timedelta(seconds=total_training_time))
            logger.info("Total training time: {} ({:.4f} s / epoch)".format(
                total_time_str,
                total_training_time / (max_epoch - start_epoch)))
        if epoch == max_epoch - 1 or ((epoch + 1) % cfg.EVAL_FREQ == 0):
            results = test(cfg, model)
            meters.update(is_train=False, iteration=cur_iter, **results)
コード例 #19
0
def train(args):
    """Iteration-based training of ``nets[args.net]`` on ``FashionDataset``.

    The dataset holds ``args.iteration_num * args.batch_size`` items, so one
    pass over the loader is the whole run. Training resumes after the ``itr``
    value, if any, returned by ``checkpointer.load``. Checkpoints are saved
    every ``args.checkpoint_period`` iterations and as ``model_final`` at the
    last iteration.

    Args:
        args: parsed command-line namespace (net, margin, omega,
            use_hardtriplet, device, use_amp, out_dir, pretrained_weights,
            batch_size, iteration_num, checkpoint_period, ...).

    Returns:
        None. Returns early, after logging the error, if the model cannot be
        built or moved to ``args.device``.
    """
    try:
        model = nets[args.net](args.margin, args.omega, args.use_hardtriplet)
        model.to(args.device)
    except Exception as e:
        logger.error("Initialize {} error: {}".format(args.net, e))
        return
    logger.info("Training {}.".format(args.net))

    optimizer = make_optimizer(args, model)
    scheduler = make_scheduler(args, optimizer)

    # args.device may be a torch.device or a plain string/int. A string such as
    # "cpu" never compares equal to torch.device("cpu"), so normalise it once.
    # apex AMP (and its scaled backward) is only used on non-CPU devices.
    use_amp_path = torch.device(args.device).type != "cpu"
    if use_amp_path:
        amp_opt_level = 'O1' if args.use_amp else 'O0'
        model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

    arguments = {}
    arguments.update(vars(args))
    arguments["itr"] = 0
    checkpointer = Checkpointer(model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                save_dir=args.out_dir,
                                save_to_disk=True)
    ## load model from pretrained_weights or training break_point.
    extra_checkpoint_data = checkpointer.load(args.pretrained_weights)
    arguments.update(extra_checkpoint_data)

    batch_size = args.batch_size
    fashion = FashionDataset(item_num=args.iteration_num*batch_size)
    dataloader = DataLoader(dataset=fashion, shuffle=True, num_workers=8, batch_size=batch_size)

    model.train()
    meters = MetricLogger(delimiter=", ")
    max_itr = len(dataloader)
    # Iteration numbers continue after the restored `itr`; the loop stops at
    # max_itr, so a resumed run only executes the remaining iterations.
    start_itr = arguments["itr"] + 1
    itr_start_time = time.time()
    training_start_time = time.time()
    for itr, batch_data in enumerate(dataloader, start_itr):
        batch_data = (bd.to(args.device) for bd in batch_data)
        loss_dict = model.loss(*batch_data)
        optimizer.zero_grad()
        if use_amp_path:
            with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_losses:
                scaled_losses.backward()
        else:
            loss_dict["loss"].backward()
        optimizer.step()
        scheduler.step()

        arguments["itr"] = itr
        meters.update(**loss_dict)
        itr_time = time.time() - itr_start_time
        itr_start_time = time.time()
        meters.update(itr_time=itr_time)
        if itr % 50 == 0:
            eta_seconds = meters.itr_time.global_avg * (max_itr - itr)
            eta = str(datetime.timedelta(seconds=int(eta_seconds)))
            logger.info(
                meters.delimiter.join(
                    [
                        "itr: {itr}/{max_itr}",
                        "lr: {lr:.7f}",
                        "{meters}",
                        "eta: {eta}\n",
                    ]
                ).format(
                    itr=itr,
                    lr=optimizer.param_groups[0]["lr"],
                    max_itr=max_itr,
                    meters=str(meters),
                    eta=eta,
                )
            )

        ## save model
        if itr % args.checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(itr), **arguments)
        if itr == max_itr:
            checkpointer.save("model_final", **arguments)
            break

    training_time = time.time() - training_start_time
    training_time = str(datetime.timedelta(seconds=int(training_time)))
    logger.info("total training time: {}".format(training_time))