Ejemplo n.º 1
0
def main():
    """Evaluate an image-classification network on the test split.

    Builds ``models.<config.MODEL.NAME>.get_cls_net(config)``, logs a model
    summary, then loads weights either from ``config.TEST.MODEL_FILE`` (a
    checkpoint dict whose ``'state_dict'`` entry holds the weights) or from
    ``final_state.pth.tar`` in the experiment output directory. The model is
    wrapped in ``DataParallel`` on GPUs 0 and 1 and scored with ``validate``
    using a cross-entropy criterion.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    # NOTE(review): eval() on a config-supplied name will execute arbitrary
    # code from an untrusted config file; consider getattr on `models`.
    model = eval('models.' + config.MODEL.NAME + '.get_cls_net')(config)

    dump_input = torch.rand(
        (1, 3, config.MODEL.IMAGE_SIZE[1], config.MODEL.IMAGE_SIZE[0]))
    logger.info(get_model_summary(model, dump_input))

    if config.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(config.TEST.MODEL_FILE))
        # Checkpoint keys may carry a leading 'module.' prefix (as produced
        # when a DataParallel-wrapped model is saved). Strip only that leading
        # prefix; str.replace would also mangle keys containing 'module.'
        # anywhere else in their name.
        prefix = 'module.'
        state_dict = torch.load(config.TEST.MODEL_FILE)['state_dict']
        model.load_state_dict({
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        })
    else:
        model_state_file = os.path.join(final_output_dir,
                                        'final_state.pth.tar')
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    gpus = [0, 1]
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Data loading code (the dataset object applies its own preprocessing)
    test_dataset = datasets(dataset_root='./Data/',
                            split='test',
                            size=config.MODEL.IMAGE_SIZE[0])

    # Evaluation does not need shuffling; a fixed order keeps any per-sample
    # outputs produced by validate reproducible across runs.
    # NOTE(review): batch size is taken from TRAIN, not TEST — confirm intended.
    valid_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    # evaluate on validation set
    validate(config, valid_loader, model, criterion, final_output_dir,
             tb_log_dir, None)
Ejemplo n.º 2
0
def main():
    """Evaluate a pose-estimation network on the configured test set.

    Merges CLI arguments into ``cfg`` and builds
    ``models.<cfg.MODEL.NAME>.get_pose_net`` with ``is_train=False``. Weights
    are loaded from ``cfg.TEST.MODEL_FILE`` (non-strict) or from
    ``final_state.pth`` in the output directory. The model is wrapped in
    ``DataParallel`` on ``cfg.GPUS`` and evaluated with ``validate`` using a
    ``JointsMSELoss`` criterion.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=False
    )

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        # strict=False: keys that do not match the model are ignored silently,
        # so a mismatched checkpoint will not raise here.
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
    else:
        model_state_file = os.path.join(
            final_output_dir, 'final_state.pth'
        )
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # define loss function (criterion)
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
    ).cuda()

    # Data loading code
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=True
    )

    # evaluate on validation set
    validate(cfg, valid_loader, valid_dataset, model, criterion,
             final_output_dir, tb_log_dir)
Ejemplo n.º 3
0
def main():
    """Evaluate a multiview pose network on the configured test subset.

    A single-view backbone (``models.<config.BACKBONE_MODEL>.get_pose_net``)
    is wrapped by ``models.<config.MODEL>.get_multiview_pose_net``. Weights
    come from ``config.TEST.MODEL_FILE`` when set. Otherwise they come from
    ``model_best.pth.tar`` (when ``config.TEST.STATE == 'best'``) or
    ``final_state.pth.tar`` in the output directory. ``config.GPUS`` is a
    comma-separated string of device ids.
    """
    args = parse_args()
    reset_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'valid')
    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related settings
    cudnn_cfg = config.CUDNN
    cudnn.benchmark = cudnn_cfg.BENCHMARK
    torch.backends.cudnn.deterministic = cudnn_cfg.DETERMINISTIC
    torch.backends.cudnn.enabled = cudnn_cfg.ENABLED

    build_backbone = eval('models.' + config.BACKBONE_MODEL + '.get_pose_net')
    backbone_model = build_backbone(config, is_train=False)
    build_multiview = eval('models.' + config.MODEL + '.get_multiview_pose_net')
    model = build_multiview(backbone_model, config)

    # Resolve which weights file to use, then load it in one place.
    if config.TEST.MODEL_FILE:
        weights_path = config.TEST.MODEL_FILE
    else:
        if config.TEST.STATE == 'best':
            checkpoint_name = 'model_best.pth.tar'
        else:
            checkpoint_name = 'final_state.pth.tar'
        weights_path = os.path.join(final_output_dir, checkpoint_name)
    logger.info('=> loading model from {}'.format(weights_path))
    model.load_state_dict(torch.load(weights_path))

    gpus = list(map(int, config.GPUS.split(',')))
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()

    # ImageNet-statistics normalization after tensor conversion.
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    dataset_cls = eval('dataset.' + config.DATASET.TEST_DATASET)
    valid_dataset = dataset_cls(
        config, config.DATASET.TEST_SUBSET, False, to_normalized_tensor)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    validate(config, valid_loader, valid_dataset, model, criterion,
             final_output_dir, tb_log_dir)
Ejemplo n.º 4
0
def main():
    """Evaluate a set-prediction pose model (matcher + SetCriterion) on the test set.

    Weights are read from ``cfg.TEST.MODEL_FILE``, unwrapping a nested
    ``'best_state_dict'`` entry when present. Otherwise they come from
    ``final_state.pth`` in the output directory. Either way the state dict
    goes through ``model_key_helper`` before loading.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')
    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    pose_net_factory = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')
    model = pose_net_factory(cfg, is_train=False)

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        weights = torch.load(cfg.TEST.MODEL_FILE)
        # Full checkpoints nest the weights under 'best_state_dict'.
        if 'best_state_dict' in weights.keys():
            weights = weights['best_state_dict']
    else:
        model_state_file = os.path.join(final_output_dir, 'final_state.pth')
        logger.info('=> loading model from {}'.format(model_state_file))
        weights = torch.load(model_state_file)
    model.load_state_dict(model_key_helper(weights))

    # Loss: Hungarian-style matcher feeding a set criterion. num_classes is
    # read here, before DataParallel wraps the model.
    matcher = build_matcher(cfg.MODEL.NUM_JOINTS)
    loss_weights = {'loss_ce': 1, 'loss_kpts': cfg.MODEL.EXTRA.KPT_LOSS_COEF}
    loss_terms = ['labels', 'kpts', 'cardinality']
    criterion = SetCriterion(model.num_classes, matcher, loss_weights,
                             cfg.MODEL.EXTRA.EOS_COEF, loss_terms).cuda()

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    dataset_cls = eval('dataset.' + cfg.DATASET.DATASET)
    valid_dataset = dataset_cls(cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET,
                                False, preprocess)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=True)

    validate(cfg, valid_loader, valid_dataset, model, criterion,
             final_output_dir, tb_log_dir)
Ejemplo n.º 5
0
def validate_hrnet():
    """Evaluate an HRNet classifier on an ImageFolder test set, on cuda:0.

    Prints a verbose model summary and loads weights from
    ``config.TEST.MODEL_FILE`` (or ``./final_state.pth.tar``). Images are
    resized to ``IMAGE_SIZE[0] / 0.875`` and then center-cropped to
    ``IMAGE_SIZE[0]``. ``validate`` is called with ``'./log'`` / ``'/output'``
    as its directory arguments.
    """
    # cudnn related setting
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = config.CUDNN.BENCHMARK
    cudnn_backend.deterministic = config.CUDNN.DETERMINISTIC
    cudnn_backend.enabled = config.CUDNN.ENABLED

    model = eval('models.' + config.MODEL.NAME + '.get_cls_net')(config)

    gpu = torch.device("cuda:0")
    width, height = config.MODEL.IMAGE_SIZE[0], config.MODEL.IMAGE_SIZE[1]
    dump_input = torch.rand((1, 3, height, width))
    summary_text = get_model_summary(model.cuda(gpu),
                                     dump_input.cuda(gpu),
                                     verbose=True)
    print(summary_text)

    weights_path = (config.TEST.MODEL_FILE
                    or os.path.join("./", 'final_state.pth.tar'))
    print('=> loading model from {}'.format(weights_path))
    model.load_state_dict(torch.load(weights_path))

    model = torch.nn.Sequential(model).cuda(gpu)
    criterion = torch.nn.CrossEntropyLoss().cuda(gpu)

    valdir = os.path.join(config.DATASET.ROOT, config.DATASET.TEST_SET)
    crop_size = config.MODEL.IMAGE_SIZE[0]
    preprocess = transforms.Compose([
        transforms.Resize(int(crop_size / 0.875)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    dataset = datasets.ImageFolder(valdir, preprocess)

    print(f'number of images: {len(dataset.imgs)}')

    valid_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    validate(config, valid_loader, model, criterion, './log', '/output', None)
Ejemplo n.º 6
0
def test(model, cfg, final_output_dir, tb_log_dir):
    """Run ``validate`` for ``model`` on ``cfg.DATASET.TEST_SET`` and return its result.

    The model is used exactly as passed in. It is not wrapped in DataParallel
    and ``model.eval()`` is not called here (presumably ``validate`` switches
    modes — verify). The loss is ``JointsMSELoss``.

    Returns:
        Whatever ``validate`` returns (used by callers as a performance indicator).
    """
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    dataset_cls = eval('dataset.' + cfg.DATASET.DATASET)
    valid_dataset = dataset_cls(cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET,
                                False, preprocess)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=True)

    return validate(cfg, valid_loader, valid_dataset, model, criterion,
                    final_output_dir, tb_log_dir)
Ejemplo n.º 7
0
    def eval(self):  # TODO: Pass writer_dict for tensorboard when training
        """Restore the eval checkpoint and score the model on ``self.dataset_eval``.

        Runs ``self.eval_op`` ``ceil(num_images / FLAGS.batch_size)`` times and
        collects per-step logits, labels and ids. These are handed to
        ``validate``, and then the mean loss and the AP indicator are logged.

        Returns:
            The performance indicator returned by ``validate``.
        """
        ckpt_path = self.__restore_model(self.saver_eval, self.sess_eval)
        tf.logging.info('restore from %s' % (ckpt_path))
        print("Starting evaluation process")

        num_steps = int(
            np.ceil(float(self.dataset_eval.num_images) / FLAGS.batch_size))
        step_losses = np.zeros((num_steps, 1))
        all_logits, all_targets, all_ids = [], [], []
        for step in range(num_steps):
            logits, labels, ids, loss = self.sess_eval.run(self.eval_op)
            step_losses[step] = loss
            # One list per step; each holds that step's per-sample entries.
            all_logits.append(list(logits))
            all_targets.append(list(labels))
            all_ids.append(list(ids))

        _, perf_indicator = validate(self.hrnet.cfg,
                                     self.dataset_eval,
                                     outputs=all_logits,
                                     targets=all_targets,
                                     ids=all_ids,
                                     output_dir=self.log_path,
                                     writer_dict=None)

        # TODO: Extend in case of more metrics added beyond loss
        tf.logging.info('%s = %.4e' % ('loss', np.mean(step_losses)))
        tf.logging.info('%s = %.4e' % ('AP', perf_indicator))
        return perf_indicator
Ejemplo n.º 8
0
def main():
    """Evaluate the CPNet depth network with weights from ``./nyuv2.pth``.

    The model is wrapped in DataParallel on GPU 0 and scored with ``validate``
    on the loader returned by ``NYUlabel13.getTestingData()``. The value
    returned by ``validate`` is not used further.
    """
    args = parse_args()
    update_config(cfg, args)

    net = CPNet.get_depth_net(cfg)
    net.load_state_dict(torch.load('./nyuv2.pth'))
    net = torch.nn.DataParallel(net, device_ids=[0]).cuda()

    test_loader = NYUlabel13.getTestingData()
    validate(test_loader, net)
Ejemplo n.º 9
0
def main():
    """Evaluate the best saved MLP checkpoint on the validation split.

    Uses the module-level settings ``model_name``, ``device`` and
    ``num_workers``. Train and validation feature lists are loaded. Then
    ``HaptMlpModel`` is restored from ``./checkpoints/best_mlp.pth.tar`` and
    quantized with ``quantize_model`` (which receives the training dataset).
    Validation accuracy and F1 are printed.

    Raises:
        NotImplementedError: if ``model_name == 'cnn'``. The raw CNN dataset is
            never split into train/validation here, so there is nothing to
            evaluate on (previously this path crashed with a NameError).
        ValueError: for any other ``model_name`` besides ``'mlp'`` (e.g.
            ``'knn'``, for which no model or config was ever built).
    """
    if model_name == 'cnn':
        raise NotImplementedError(
            "evaluation for model_name 'cnn' is not implemented: "
            "HaptRawDataset has no validation split")
    if model_name != 'mlp':
        raise ValueError('unsupported model_name: {!r}'.format(model_name))

    train_dataset = DanceDataset(get_collected_data(val=False))
    val_dataset = DanceDataset(get_collected_data(val=True))

    config = {
        'input_size': train_dataset.input_size(),
        'output_size': 12,
        'print_freq': 100
    }
    model = HaptMlpModel(config)
    load_checkpoint(model, './checkpoints', 'best_mlp.pth.tar')
    # quantize_model is given the training data (presumably for calibration —
    # confirm against its definition).
    model = quantize_model(model, train_dataset)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=True)
    summary = SummaryWriter()
    try:
        val_acc, val_f1 = validate(config,
                                   val_loader,
                                   model,
                                   criterion,
                                   0,
                                   summary,
                                   print_output=False,
                                   device=device)
    finally:
        # Flush and release the event file even if validation fails.
        summary.close()
    print('Val acc:{}'.format(val_acc))
    print('Val f1:{}'.format(val_f1))
Ejemplo n.º 10
0
def train_mlp(train_dataset, val_dataset):
    """Train ``HaptMlpModel`` for ``total_epoch`` epochs on the GPU.

    Uses the module-level ``learning_rate``, ``batch_size``, ``num_workers``
    and ``total_epoch`` settings. After every epoch the model is saved to
    ``./checkpoints/checkpoint_mlp.pth.tar``. It is also saved to
    ``best_mlp.pth.tar`` whenever validation F1 strictly improves on the best
    so far (which starts at 0).

    Returns:
        ``(best_val_acc, best_val_f1)`` — the validation accuracy and F1 of
        the epoch with the highest F1.
    """
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True
    input_size = train_dataset.input_size()
    config = {'input_size': input_size, 'output_size': 12, 'print_freq': 100}
    model = HaptMlpModel(config).cuda()

    # Class weights: the last two classes are weighted 10x, all others 1x.
    weight_tensor = [1 for _ in range(config['output_size'] - 2)] + [10, 10]
    criterion = nn.CrossEntropyLoss(torch.Tensor(weight_tensor)).cuda()
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=learning_rate)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    # Validation metrics are computed over the whole set, so shuffling only
    # wastes work and makes any per-sample output order non-reproducible.
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=True)
    summary = SummaryWriter()
    best_val_acc = 0
    best_val_f1 = 0
    for epoch in range(total_epoch):
        train(config, train_loader, model, criterion, optimizer, epoch,
              summary)
        val_acc, val_f1 = validate(config, val_loader, model, criterion, epoch,
                                   summary)
        save_checkpoint(model, epoch, optimizer, './checkpoints',
                        'checkpoint_mlp.pth.tar')
        if val_f1 > best_val_f1:
            save_checkpoint(model, epoch, optimizer, './checkpoints',
                            'best_mlp.pth.tar')
            best_val_acc = val_acc
            best_val_f1 = val_f1
    return best_val_acc, best_val_f1
Ejemplo n.º 11
0
def _build_dataset(cfg, image_set, is_train, mean, std):
    """Instantiate ``dataset.<cfg.DATASET.DATASET>.Dataset`` for one image set.

    Each image set is normalized with its own per-channel ``mean`` / ``std``
    after ``ToTensor``.
    """
    dataset_cls = importlib.import_module(
        'dataset.' + cfg.DATASET.DATASET).Dataset
    return dataset_cls(
        cfg, cfg.DATASET.ROOT, image_set, is_train,
        transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]))


def _checkpoint_state(cfg, model, optimizer, epoch, metric):
    """Build the dict handed to ``save_checkpoint`` for a DataParallel ``model``.

    ``'epoch'`` stores ``epoch + 1`` (the next epoch to run).
    ``'best_state_dict'`` holds the unwrapped module's weights.
    """
    return {
        'epoch': epoch + 1,
        'model': cfg.MODEL.NAME,
        'state_dict': model.state_dict(),
        'best_state_dict': model.module.state_dict(),
        'metric': metric,
        'optimizer': optimizer.state_dict(),
    }


def main():
    """Train the fovea network (``models.<cfg.MODEL.NAME>.get_fovea_net``).

    Training data is the concatenation of ``TRAIN_SET_1`` and ``TRAIN_SET_2``
    (plus ``TEST_SET`` when ``cfg.TRAIN.FULL_DATA is True``). Each set has its
    own normalization statistics. Validation runs on ``TEST_SET`` (plus both
    train sets in full-data mode). After every epoch the smallest of the three
    metrics returned by ``validate`` is compared with the best so far, and a
    checkpoint is written on improvement. Periodic checkpoints are written
    every 60 epochs. The final weights and a final checkpoint are saved at
    the end.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model_builder = importlib.import_module("models." +
                                            cfg.MODEL.NAME).get_fovea_net
    model = model_builder(cfg, is_train=True)

    # Optional warm start. strict=False silently ignores non-matching keys.
    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # define loss function (criterion)
    criterion = HybridLoss(roi_weight=cfg.LOSS.ROI_WEIGHT,
                           regress_weight=cfg.LOSS.REGRESS_WEIGHT,
                           use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT,
                           hrnet_only=cfg.TRAIN.HRNET_ONLY).cuda()

    # Per-image-set normalization statistics: (mean, std).
    stats_set_1 = ([0.282, 0.168, 0.084], [0.189, 0.110, 0.062])
    stats_set_2 = ([0.409, 0.270, 0.215], [0.288, 0.203, 0.160])
    stats_test = ([0.404, 0.271, 0.222], [0.284, 0.202, 0.163])

    final_full_test = cfg.TRAIN.FULL_DATA

    db_trains = [
        _build_dataset(cfg, cfg.DATASET.TRAIN_SET_1, True, *stats_set_1),
        _build_dataset(cfg, cfg.DATASET.TRAIN_SET_2, True, *stats_set_2),
    ]
    if final_full_test is True:
        # Full-data mode: the test set is also used for training.
        db_trains.append(
            _build_dataset(cfg, cfg.DATASET.TEST_SET, True, *stats_test))

    train_dataset = ConcatDataset(db_trains)
    logger.info("Combined Dataset: Total {} images".format(len(train_dataset)))

    train_batch_size = cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=train_batch_size,
                                               shuffle=cfg.TRAIN.SHUFFLE,
                                               num_workers=cfg.WORKERS,
                                               pin_memory=cfg.PIN_MEMORY)

    db_vals = [_build_dataset(cfg, cfg.DATASET.TEST_SET, False, *stats_test)]
    if final_full_test is True:
        db_vals.append(
            _build_dataset(cfg, cfg.DATASET.TRAIN_SET_1, False, *stats_set_1))
        db_vals.append(
            _build_dataset(cfg, cfg.DATASET.TRAIN_SET_2, False, *stats_set_2))

    valid_dataset = ConcatDataset(db_vals)

    logger.info("Val Dataset: Total {} images".format(len(valid_dataset)))

    test_batch_size = cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    logger.info("Train len: {}, batch_size: {}; Test len: {}, batch_size: {}"
                .format(len(train_loader), train_batch_size,
                        len(valid_loader), test_batch_size))

    best_metric = 1e6
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH

    if cfg.TEST.MODEL_FILE:
        checkpoint_file = cfg.TEST.MODEL_FILE
    else:
        checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        # Epoch counting deliberately restarts at 0 on resume, while the
        # scheduler below still receives the checkpoint's epoch.
        begin_epoch = 0
        best_metric = checkpoint['metric']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    if cfg.TRAIN.LR_EXP:
        # lr = base_lr * gamma ** epoch
        # NOTE(review): unlike MultiStepLR this ignores the resumed
        # last_epoch — confirm intended.
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                              cfg.TRAIN.GAMMA1,
                                                              last_epoch=-1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            cfg.TRAIN.LR_STEP,
            cfg.TRAIN.LR_FACTOR,
            last_epoch=last_epoch)

    # Defaults so the final save below is well-defined even when the loop
    # body never runs (begin_epoch >= END_EPOCH); previously that raised
    # NameError on `epoch` / `final_metric`.
    epoch = begin_epoch - 1
    final_metric = best_metric
    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        start_time = timer()

        # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after
        # optimizer.step(); stepping first skips the initial LR value. Kept
        # to preserve the existing schedule.
        lr_scheduler.step()

        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        print("epoch %d train spent time:" % (epoch))
        train_time = timer(start_time)

        # evaluate on validation set
        lr_metric, hr_metric, final_metric = validate(
            cfg, valid_loader, valid_dataset, model, criterion,
            final_output_dir, tb_log_dir, writer_dict, db_vals)

        print("validation spent time:")
        val_time = timer(train_time)

        # Lower is better; the best of the three heads decides "best model".
        # NOTE(review): checkpoints store final_metric, not min_metric, so a
        # resumed best_metric may differ from the value compared here.
        min_metric = min(lr_metric, hr_metric, final_metric)
        if min_metric <= best_metric:
            best_metric = min_metric
            best_model = True
            logger.info('=> epoch [{}] best model result: {}'.format(
                epoch, best_metric))
        else:
            best_model = False

        if best_model:
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            # Move the model to CPU before saving to work around
            # github.com/pytorch/pytorch/issues/10577.
            model = model.cpu()
            save_checkpoint(
                _checkpoint_state(cfg, model, optimizer, epoch, final_metric),
                best_model, final_output_dir)
            model = model.cuda()

            print("saving spent time:")
            timer(val_time)
        elif epoch % 60 == 0 and epoch != 0:
            logger.info('=> saving epoch {} checkpoint to {}'.format(
                epoch, final_output_dir))
            time_str = time.strftime('%Y-%m-%d-%H-%M')
            variant = 'HRNET' if cfg.TRAIN.HRNET_ONLY else 'Hybrid'
            checkpoint_filename = 'checkpoint_%s_epoch%d_%s.pth' % (
                variant, epoch, time_str)
            # See github.com/pytorch/pytorch/issues/10577 (save from CPU).
            model = model.cpu()
            save_checkpoint(
                _checkpoint_state(cfg, model, optimizer, epoch, final_metric),
                best_model, final_output_dir, checkpoint_filename)
            model = model.cuda()

    time_str = time.strftime('%Y-%m-%d-%H-%M')
    variant = 'HRNET' if cfg.TRAIN.HRNET_ONLY else 'Hybrid'
    model_name = 'final_state_%s_%s.pth' % (variant, time_str)

    final_model_state_file = os.path.join(final_output_dir, model_name)
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()

    # save a final checkpoint
    model = model.cpu()
    save_checkpoint(
        _checkpoint_state(cfg, model, optimizer, epoch, final_metric),
        best_model, final_output_dir, "checkpoint_final_state.pth")
Ejemplo n.º 12
0
def main():
    """Train a multiview pose network, validating and checkpointing each epoch.

    The configuration is read via ``parse_args``/``reset_config``. After every
    training epoch the model is validated; a checkpoint is written with
    ``best_model=True`` whenever the validation indicator exceeds the best
    value seen so far in this run. The final weights of the wrapped module
    are saved to ``final_state.pth.tar`` in the experiment output directory.
    """
    args = parse_args()
    reset_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    # Model factories are resolved by name from the config.
    backbone_model = eval('models.' + config.BACKBONE_MODEL + '.get_pose_net')(
        config, is_train=True)

    model = eval('models.' + config.MODEL + '.get_multiview_pose_net')(
        backbone_model, config)
    print(model)

    # Keep a copy of the model source and the config next to the outputs.
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../../lib/models', config.MODEL + '.py'),
        final_output_dir)
    shutil.copy2(args.cfg, final_output_dir)
    logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    gpus = [int(i) for i in config.GPUS.split(',')]
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()

    optimizer = get_optimizer(config, model)
    start_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        start_epoch, model, optimizer = load_checkpoint(model, optimizer,
                                                        final_output_dir)

    # NOTE(review): the scheduler is created fresh (last_epoch=-1) even after
    # resuming, so LR_STEP milestones are counted from start_epoch rather
    # than from epoch 0 -- confirm this is intended.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)

    # Data loading code
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + config.DATASET.TRAIN_DATASET)(
        config, config.DATASET.TRAIN_SUBSET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + config.DATASET.TEST_DATASET)(
        config, config.DATASET.TEST_SUBSET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    best_perf = 0.0
    best_model = False
    for epoch in range(start_epoch, config.TRAIN.END_EPOCH):
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, writer_dict)
        # Since PyTorch 1.1 the scheduler must step *after* the epoch's
        # optimizer updates; stepping first applies every LR_STEP milestone
        # one epoch early.
        lr_scheduler.step()

        perf_indicator = validate(config, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, writer_dict)

        # Higher indicator is treated as better.
        best_model = perf_indicator > best_perf
        if best_model:
            best_perf = perf_indicator

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint({
            'epoch': epoch + 1,
            'model': get_model_name(config),
            'state_dict': model.module.state_dict(),
            'perf': perf_indicator,
            'optimizer': optimizer.state_dict(),
        }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info('saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Ejemplo n.º 13
0
def main():
    """Train a pose model with a hold-out validation split and early stopping.

    Loads the config, seeds the RNGs, builds the model and albumentations
    pipelines, splits the training annotations into train/val (optionally
    holding out a further 10% test split when ``cfg.test_option`` evaluates
    truthy), and trains for ``cfg.TRAIN.END_EPOCH`` epochs. When
    ``cfg.AUTO_RESUME`` is set and ``checkpoint.pth`` exists in the output
    directory, training resumes from the saved epoch. Final weights go to
    ``final_state.pth`` and per-epoch losses to ``result.pkl``.
    """
    # Main paths.
    data_path = './data'
    train_dir = Path(data_path, 'images/train_imgs')

    # Load the config file.
    args = parse_args()
    update_config(cfg, args)

    lr = cfg.TRAIN.LR
    lamb = cfg.LAMB
    # NOTE(review): eval() on a config string -- only safe for trusted cfg files.
    test_option = eval(cfg.test_option)

    input_w = cfg.MODEL.IMAGE_SIZE[1]
    input_h = cfg.MODEL.IMAGE_SIZE[0]

    # Reduce sources of randomness as much as possible.
    RANDOM_SEED = int(cfg.RANDOMSEED)
    np.random.seed(RANDOM_SEED)  # NumPy
    torch.manual_seed(RANDOM_SEED)  # torch (CPU)
    random.seed(RANDOM_SEED)  # Python
    # Setting PYTHONHASHSEED at runtime only affects child processes.
    os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed(RANDOM_SEED)
    torch.cuda.manual_seed_all(RANDOM_SEED)  # all GPUs

    # Create the logger and the final output directory.
    logger, final_output_dir, tb_log_dir = create_logger(cfg, args.cfg, f'lr_{str(lr)}', 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    # NOTE(review): this overrides the benchmark=False set above for
    # reproducibility -- confirm which one is intended.
    cudnn.benchmark = cfg.CUDNN.BENCHMARK

    # Build the annotation file once if it does not exist yet.
    if not os.path.isfile(data_path + '/annotations/train_annotation.pkl'):
        make_annotations(data_path)

    # Load the model selected by the config.
    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
        cfg, is_train=True
    )

    # Adjust the model head and initialize it.
    model = initialize_model(model, cfg)

    # Copy the model source and train.py next to the outputs.
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)

    shutil.copy2(
        os.path.join(this_dir, '../tools', 'train.py'),
        final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Loss, placed on the same device as the model.
    criterion = nn.MSELoss().to(device)

    # Data augmentation pipelines.
    A_transforms = {

        'val':
            A.Compose([
                A.Resize(input_h, input_w, always_apply=True),
                A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ToTensorV2()
            ], bbox_params=A.BboxParams(format="coco", min_visibility=0.05, label_fields=['class_labels'])),

        'test':
            A.Compose([
                A.Resize(input_h, input_w, always_apply=True),
                A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ToTensorV2()
            ])
        }

    if input_h == input_w:
        # Square inputs: 90-degree rotations keep the shape, so allow them.
        A_transforms['train'] = A.Compose([
                A.Resize(input_h, input_w, always_apply=True),
                A.OneOf([A.HorizontalFlip(p=1),
                         A.VerticalFlip(p=1),
                         A.Rotate(p=1),
                         A.RandomRotate90(p=1)
                ], p=0.5),
                A.OneOf([A.MotionBlur(p=1),
                         A.GaussNoise(p=1),
                         A.ColorJitter(p=1)
                ], p=0.5),

                A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ToTensorV2()
            ], bbox_params=A.BboxParams(format="coco", min_visibility=0.05, label_fields=['class_labels']))

    else:
        A_transforms['train'] = A.Compose([
                A.Resize(input_h, input_w, always_apply=True),
                A.OneOf([A.HorizontalFlip(p=1),
                         A.VerticalFlip(p=1),
                         A.Rotate(p=1),
                ], p=0.5),
                A.OneOf([A.MotionBlur(p=1),
                         A.GaussNoise(p=1)

                ], p=0.5),
                A.OneOf([A.CropAndPad(percent=0.1, p=1),
                         A.CropAndPad(percent=0.2, p=1),
                         A.CropAndPad(percent=0.3, p=1)
                ], p=0.5),

                A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ToTensorV2()
            ], bbox_params=A.BboxParams(format="coco", min_visibility=0.05, label_fields=['class_labels']))

    # Training parameters.
    batch_size = int(cfg.TRAIN.BATCH_SIZE_PER_GPU)
    test_ratio = float(cfg.TEST_RATIO)
    num_epochs = cfg.TRAIN.END_EPOCH

    # Early-stopping patience (epochs without improvement); equal to the
    # epoch budget, so early stopping is effectively disabled.
    num_earlystop = num_epochs

    # Build the raw data used by the torch datasets.
    imgs, bbox, class_labels = make_train_data(data_path)

    since = time.time()

    # test_option decides whether a test split is held out before the
    # train/valid split:
    #   * truthy: hold out 10% of the data as a test set (also pickled).
    #   * falsy:  no test set is held out.
    if test_option == True:
        X_train, X_test, y_train, y_test = train_test_split(imgs, bbox, test_size=0.1, random_state=RANDOM_SEED)
        test_dataset = [X_test, y_test]
        with open(final_output_dir + '/test_dataset.pkl', 'wb') as f:
            pickle.dump(test_dataset, f)
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=test_ratio, random_state=RANDOM_SEED)
        test_data = Dataset(train_dir, X_test, y_test, data_transforms=A_transforms, class_labels=class_labels, phase='val')
        test_loader = data_utils.DataLoader(test_data, batch_size=batch_size, shuffle=False)

    else:
        X_train, X_val, y_train, y_val = train_test_split(imgs, bbox, test_size=test_ratio, random_state=RANDOM_SEED)

    train_data = Dataset(train_dir, X_train, y_train, data_transforms=A_transforms, class_labels=class_labels, phase='train')

    val_data = Dataset(train_dir, X_val, y_val, data_transforms=A_transforms, class_labels=class_labels, phase='val')
    train_loader = data_utils.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = data_utils.DataLoader(val_data, batch_size=batch_size, shuffle=False)

    # Tracking state for the best (lowest) validation loss.
    best_perf = 10000000000
    test_loss = None
    best_model = False

    optimizer = optim.Adam(
        model.parameters(),
        lr=lr
    )

    # Resume from an intermediate checkpoint when one exists.
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(
        final_output_dir, 'checkpoint.pth'
    )

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        # num_epochs intentionally stays cfg.TRAIN.END_EPOCH: overwriting it
        # with the checkpoint epoch made the training range empty.
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    # NOTE(review): created with last_epoch=-1 even when resuming, so LR_STEP
    # milestones are counted from begin_epoch -- confirm this is intended.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
        last_epoch=-1
    )

    # Consecutive epochs without validation improvement.
    count = 0
    val_losses = []
    train_losses = []

    # Training loop.
    for epoch in range(begin_epoch, num_epochs):
        epoch_since = time.time()

        # train for one epoch
        train_loss = train(cfg, device, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict, lamb=lamb)
        # Step after the epoch's optimizer updates (PyTorch >= 1.1 ordering);
        # stepping first applies every milestone one epoch early.
        lr_scheduler.step()

        # evaluate on validation set
        perf_indicator = validate(
            cfg, device, val_loader, val_data, model, criterion,
            final_output_dir, tb_log_dir, writer_dict, lamb=lamb
        )

        # Lower validation loss is better.
        if perf_indicator <= best_perf:
            best_perf = perf_indicator
            best_model = True
            count = 0
        else:
            best_model = False
            count += 1

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint({
            'epoch': epoch + 1,
            'model': cfg.MODEL.NAME,
            'state_dict': model.state_dict(),
            'best_state_dict': model.state_dict(),
            'perf': perf_indicator,
            'optimizer': optimizer.state_dict(),
        }, best_model, final_output_dir)

        # Record the losses of this epoch.
        val_losses.append(perf_indicator)
        train_losses.append(train_loss)
        if count == num_earlystop:
            break

        epoch_time_elapsed = time.time() - epoch_since
        print(f'epoch : {epoch}' \
                f' train loss : {round(train_loss,3)}' \
                              f' valid loss : {round(perf_indicator,3)}' \
                              f' Elapsed time: {int(epoch_time_elapsed // 60)}m {int(epoch_time_elapsed % 60)}s')

    # Save the final weights.
    final_model_state_file = os.path.join(
        final_output_dir, 'final_state.pth'
    )
    logger.info('=> saving final model state to {}'.format(
        final_model_state_file)
    )
    torch.save(model.state_dict(), final_model_state_file)
    writer_dict['writer'].close()

    time_elapsed = time.time() - since
    print('Training and Validation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best validation loss: {:4f}\n'.format(best_perf))

    # With a held-out test split, evaluate the best checkpoint on it.
    if test_option == True:
        model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
            cfg, is_train=True)

        model = initialize_model(model, cfg)
        parameters = f'{final_output_dir}/model_best.pth'

        model = model.to(device)
        model.load_state_dict(torch.load(parameters, map_location=device))

        test_loss = validate(
                cfg, device, test_loader, test_data, model, criterion,
                final_output_dir, tb_log_dir, writer_dict, lamb=lamb
            )

    print(f'test loss : {test_loss}')

    # Persist the loss history separately as a pickle file.
    result_dict = {}
    result_dict['val_loss'] = val_losses
    result_dict['train_loss'] = train_losses
    result_dict['best_loss'] = best_perf
    result_dict['test_loss'] = test_loss
    result_dict['lr'] = lr
    with open(final_output_dir + '/result.pkl', 'wb') as f:
        pickle.dump(result_dict, f)
Ejemplo n.º 14
0
def main():
    """Evaluate a pose network on the configured validation split.

    Weights are taken from ``cfg.TEST.MODEL_FILE`` when it is set (the
    literal ``'none'`` skips loading entirely), otherwise from
    ``model_best.pth`` in the experiment output directory. The epoch passed
    to ``validate`` is the loaded file's ``'epoch'`` entry when present,
    else ``cfg.TRAIN.END_EPOCH``.
    """
    args = parse_args()
    update_config(cfg, args)

    if args.prevModelDir and args.modelDir:
        # copy pre models for philly
        copy_prev_models(args.prevModelDir, args.modelDir)

    logger, final_output_dir, tb_log_dir = create_logger(cfg,
                                                         args.cfg,
                                                         'valid',
                                                         dry=True)

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=False)
    # Wrap before loading so 'module.'-prefixed keys match; DataParallel
    # requires the module's parameters on device_ids[0], hence .cuda().
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    if cfg.TEST.MODEL_FILE == 'none':
        t = {}
        logger.info('=> Not reloading any model')
    elif cfg.TEST.MODEL_FILE:
        logger.info('=> loading TEST.MODEL_FILE model from {}'.format(
            cfg.TEST.MODEL_FILE))
        t = torch.load(cfg.TEST.MODEL_FILE)
        model.load_state_dict(t, strict=True)
    else:
        model_state_file = os.path.join(final_output_dir, 'model_best.pth')
        logger.info('=> loading model from {}'.format(model_state_file))
        t = torch.load(model_state_file)
        # Either a full training checkpoint or a bare state dict.
        if 'state_dict' in t:
            model.load_state_dict(t['state_dict'])
        else:
            model.load_state_dict(t)

    if 'epoch' in t:
        epoch = t['epoch']
        # %-style args: the previous call passed `epoch` without a
        # placeholder, which makes logging report a formatting error.
        logger.info('Reloaded epoch %s', epoch)
    else:
        logger.info('No epoch in model, setting to last epoch')
        epoch = cfg.TRAIN.END_EPOCH

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    # evaluate on validation set
    validate(cfg, valid_loader, valid_dataset, model, criterion,
             final_output_dir, tb_log_dir, epoch)
Ejemplo n.º 15
0
def main():
    """Validate a multiview pose network restored from a saved checkpoint.

    The checkpoint is ``config.TEST.MODEL_FILE`` when set; otherwise it is
    ``model_best.pth.tar`` (``config.TEST.STATE == 'best'``) or
    ``final_state.pth.tar`` inside the experiment output directory. Two
    checkpoint layouts are accepted: per-model ``state_dict_<name>`` entries
    ("new" mode) or a bare base-model state dict ("old" mode). Aggregation
    layer weights are dropped when ``config.NETWORK.AGGRE`` is off.
    """
    args = parse_args()
    reset_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'valid')
    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cuDNN behaviour comes straight from the config.
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    make_backbone = eval('models.' + config.BACKBONE_MODEL + '.get_pose_net')
    backbone = make_backbone(config, is_train=False)
    make_multiview = eval('models.' + config.MODEL + '.get_multiview_pose_net')
    base_model = make_multiview(backbone, config)

    model_dict = {'base_model': base_model}

    # Auxiliary losses (and their discriminators) are not used when validating.
    for loss_flag in ('USE_GLOBAL_MI_LOSS', 'USE_LOCAL_MI_LOSS',
                      'USE_FUNDAMENTAL_LOSS'):
        setattr(config.LOSS, loss_flag, False)

    if config.TEST.MODEL_FILE:
        ckpt_path = config.TEST.MODEL_FILE
    else:
        ckpt_name = ('model_best.pth.tar' if config.TEST.STATE == 'best'
                     else 'final_state.pth.tar')
        ckpt_path = os.path.join(final_output_dir, ckpt_name)
    logger.info('=> loading model from {}'.format(ckpt_path))
    state_dict = torch.load(ckpt_path)

    def drop_aggregation_params(params):
        # Remove aggregation-layer entries in place (keeps the mapping object).
        doomed = [name for name in params.keys() if 'aggre_layer' in name]
        for name in doomed:
            params.pop(name)

    if 'state_dict_base_model' not in state_dict:
        logger.info('=> old loading mode')
        if not config.NETWORK.AGGRE:
            drop_aggregation_params(state_dict)
        model_dict['base_model'].load_state_dict(state_dict)
    else:
        logger.info('=> new loading mode')
        for name in model_dict:
            if name == 'base_model' and not config.NETWORK.AGGRE:
                drop_aggregation_params(state_dict['state_dict_base_model'])
            model_dict[name].load_state_dict(state_dict['state_dict_' + name])

    gpus = [int(gpu_id) for gpu_id in config.GPUS.split(',')]
    model_dict = {
        name: torch.nn.DataParallel(net, device_ids=gpus).cuda()
        for name, net in model_dict.items()
    }

    # Loss functions handed to validate().
    criterion_dict = {
        'mse_weights': JointsMSELoss(
            use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda(),
        'mse': torch.nn.MSELoss(reduction='mean').cuda(),
    }

    # Evaluation data: the test subset, loaded with is_train=False.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    valid_dataset = eval('dataset.' + config.DATASET.TEST_DATASET)(
        config, config.DATASET.TEST_SUBSET, False, eval_transform,
        '', config.DATASET.NO_DISTORTION)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    validate(config, valid_loader, valid_dataset, model_dict, criterion_dict,
             final_output_dir, None, rank=0)
Ejemplo n.º 16
0
def main():
    """Train a multiview pose network, logging the git state of the code.

    Records the current commit, branches and working-tree diff, then trains
    from ``config.TRAIN.BEGIN_EPOCH`` (or the resumed epoch) to
    ``config.TRAIN.END_EPOCH``. After every epoch it validates, writes a
    checkpoint (flagged best when the indicator improves) and saves the
    module weights to ``final_state_ep<epoch>.pth.tar``.
    """
    args = parse_args()
    reset_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    # print code version info
    repo = Repo('')
    repo_git = repo.git
    working_tree_diff_head = repo_git.diff('HEAD')
    this_commit_hash = repo.commit()
    cur_branches = repo_git.branch('--list')
    logger.info('Current Code Version is {}'.format(this_commit_hash))
    logger.info('Current Branch Info :\n{}'.format(cur_branches))
    logger.info(
        'Working Tree diff with HEAD: \n{}'.format(working_tree_diff_head))

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    backbone_model = eval('models.' + config.BACKBONE_MODEL + '.get_pose_net')(
        config, is_train=True)
    model = models.multiview_pose_net.get_multiview_pose_net(
        backbone_model, config)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    gpus = [int(i) for i in config.GPUS.split(',')]
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()

    optimizer = get_optimizer(config, model)
    start_epoch = config.TRAIN.BEGIN_EPOCH
    # Default "best so far" for a fresh run. Previously ckpt_perf was only
    # bound when resuming, so a non-resumed run crashed with NameError.
    ckpt_perf = 0.0
    if config.TRAIN.RESUME:
        start_epoch, model, optimizer, ckpt_perf = load_checkpoint(
            model, optimizer, final_output_dir)

    # NOTE(review): created with last_epoch=-1 even when resuming, so LR_STEP
    # milestones are counted from start_epoch -- confirm this is intended.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)

    # Data loading
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + config.DATASET.TRAIN_DATASET)(
        config, config.DATASET.TRAIN_SUBSET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + config.DATASET.TEST_DATASET)(
        config, config.DATASET.TEST_SUBSET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        collate_fn=totalcapture_collate,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        collate_fn=totalcapture_collate,
        pin_memory=True)

    best_perf = ckpt_perf
    best_epoch = -1
    best_model = False
    # Extra keyword arguments forwarded to train()/validate(); currently none.
    extra_param = dict()
    for epoch in range(start_epoch, config.TRAIN.END_EPOCH):
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, writer_dict, **extra_param)
        # Step after the epoch's optimizer updates (PyTorch >= 1.1 ordering);
        # stepping first applies every milestone one epoch early.
        lr_scheduler.step()

        perf_indicator = validate(config, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, writer_dict,
                                  **extra_param)

        logger.info(
            '=> perf indicator at epoch {} is {}. old best is {} '.format(
                epoch, perf_indicator, best_perf))

        # Higher indicator is treated as better.
        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
            best_epoch = epoch
            logger.info(
                '====> find new best model at end of epoch {}. (start from 0)'.
                format(epoch))
        else:
            best_model = False
        logger.info(
            'epoch of best validation results is {}'.format(best_epoch))

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': get_model_name(config),
                'state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

        # save final state at every epoch
        final_model_state_file = os.path.join(
            final_output_dir, 'final_state_ep{}.pth.tar'.format(epoch))
        logger.info(
            'saving final model state to {}'.format(final_model_state_file))
        torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Ejemplo n.º 17
0
def main_worker(rank, args, config, num_gpus):
    """Per-process entry point for distributed multiview pose training.

    One process runs per GPU: ``rank`` is both the process rank in the NCCL
    group and the CUDA device index. The base model and every discriminator
    enabled by a ``config.LOSS.USE_*`` flag are built, wrapped in
    DistributedDataParallel, optionally resumed, and trained for epochs
    ``start_epoch .. config.TRAIN.END_EPOCH - 1``. Only rank 0 logs, owns the
    TensorBoard writer and writes checkpoints.

    Args:
        rank: index of this process within the group (also its GPU index).
        args: parsed command-line namespace; ``args.cfg`` and
            ``args.iteration`` are read.
        config: experiment configuration.
        num_gpus: world size (number of participating processes).
    """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend='nccl', rank=rank, world_size=num_gpus)
    print('Rank: {} finished initializing, PID: {}'.format(rank, os.getpid()))

    if rank == 0:
        logger, final_output_dir, tb_log_dir = create_logger(
            config, args.cfg, 'train')
        logger.info(pprint.pformat(args))
        logger.info(pprint.pformat(config))
    else:
        # Non-zero ranks neither log nor write files; `logger` stays unbound
        # there, so every use of it below is guarded by `rank == 0`.
        final_output_dir = None
        tb_log_dir = None

    # Gracefully kill all subprocesses by command <'kill subprocess 0'>
    signal.signal(signal.SIGTERM, signal_handler)
    if rank == 0:
        logger.info('Rank {} has registerred signal handler'.format(rank))

    # device in current process
    device = torch.device('cuda', rank)

    backbone_model = eval('models.' + config.BACKBONE_MODEL + '.get_pose_net')(
        config, is_train=True)
    base_model = eval('models.' + config.MODEL + '.get_multiview_pose_net')(
        backbone_model, config)

    # (enabled flag, model_dict key, class name in models.discriminator).
    # The list order fixes the order of construction, logging and optimizer
    # creation, so it must stay stable for checkpoints to line up.
    discriminator_specs = [
        (config.LOSS.USE_GLOBAL_MI_LOSS, 'global_discriminator', 'GlobalDiscriminator'),
        (config.LOSS.USE_LOCAL_MI_LOSS, 'local_discriminator', 'LocalDiscriminator'),
        (config.LOSS.USE_DOMAIN_TRANSFER_LOSS, 'domain_discriminator', 'DomainDiscriminator'),
        (config.LOSS.USE_VIEW_MI_LOSS, 'view_discriminator', 'ViewDiscriminator'),
        (config.LOSS.USE_JOINTS_MI_LOSS, 'joints_discriminator', 'JointsDiscriminator'),
        (config.LOSS.USE_HEATMAP_MI_LOSS, 'heatmap_discriminator', 'HeatmapDiscriminator'),
    ]
    use_any_discriminator = any(enabled for enabled, _, _ in discriminator_specs)

    model_dict = OrderedDict()
    model_dict['base_model'] = base_model.to(device)
    for enabled, key, class_name in discriminator_specs:
        if enabled:
            discriminator_cls = getattr(models.discriminator, class_name)
            model_dict[key] = discriminator_cls(config).to(device)

    # copy model files and print model config
    if rank == 0:
        this_dir = os.path.dirname(__file__)
        shutil.copy2(
            os.path.join(this_dir, '../../lib/models', config.MODEL + '.py'),
            final_output_dir)
        shutil.copy2(args.cfg, final_output_dir)
        # base model first, then the enabled discriminators in spec order
        for module in model_dict.values():
            logger.info(pprint.pformat(module))
        if use_any_discriminator:
            shutil.copy2(
                os.path.join(this_dir, '../../lib/models', 'discriminator.py'),
                final_output_dir)

    # tensorboard writer (rank 0 only)
    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    } if rank == 0 else None

    # first resume, then parallel
    for key in model_dict.keys():
        model_dict[key] = torch.nn.parallel.DistributedDataParallel(
            model_dict[key], device_ids=[rank], output_device=rank)
        # wrap the models one by one
        dist.barrier()

    # one optimizer per model, keyed identically to model_dict
    optimizer_dict = {'base_model': get_optimizer(config, model_dict['base_model'])}
    for key, module in model_dict.items():
        if key != 'base_model':
            optimizer_dict[key] = get_optimizer(config, module, is_discriminator=True)

    # resume
    if config.TRAIN.RESUME:
        assert config.TRAIN.RESUME_PATH != '', 'You must designate a path for config.TRAIN.RESUME_PATH, rank: {}'.format(rank)
        if rank == 0:
            logger.info('=> loading model from {}'.format(config.TRAIN.RESUME_PATH))
        # !!! map_location must be cpu, otherwise a lot memory will be allocated on gpu:0.
        state_dict = torch.load(config.TRAIN.RESUME_PATH, map_location=torch.device('cpu'))
        if 'state_dict_base_model' in state_dict:
            if rank == 0:
                logger.info('=> new loading mode')
            for key in model_dict.keys():
                # delete params of the aggregation layer
                if key == 'base_model' and not config.NETWORK.AGGRE:
                    for param_key in list(state_dict['state_dict_base_model'].keys()):
                        if 'aggre_layer' in param_key:
                            state_dict['state_dict_base_model'].pop(param_key)
                model_dict[key].module.load_state_dict(state_dict['state_dict_' + key])
        else:
            if rank == 0:
                logger.info('=> old loading mode')
            # delete params of the aggregation layer
            if not config.NETWORK.AGGRE:
                for param_key in list(state_dict.keys()):
                    if 'aggre_layer' in param_key:
                        state_dict.pop(param_key)
            model_dict['base_model'].module.load_state_dict(state_dict)

    # Training on a server cluster: resume automatically when interrupted.
    start_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.ON_SERVER_CLUSTER:
        # NOTE(review): final_output_dir is None on ranks != 0; confirm that
        # load_checkpoint copes with that on every rank.
        start_epoch, model_dict, optimizer_dict, loaded_iteration = load_checkpoint(
            model_dict, optimizer_dict, final_output_dir)
        if args.iteration < loaded_iteration:
            # this training process should be skipped
            if rank == 0:
                logger.info('=> Skipping training iteration #{}'.format(args.iteration))
            return

    # lr schedulers have different starting points yet share same decay strategy.
    lr_scheduler_dict = {}
    for key in optimizer_dict.keys():
        lr_scheduler_dict[key] = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_dict[key], config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)

    criterion_dict = {}
    criterion_dict['mse_weights'] = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).to(device)
    criterion_dict['mse'] = torch.nn.MSELoss(reduction='mean').to(device)

    if config.LOSS.USE_FUNDAMENTAL_LOSS:
        criterion_dict['fundamental'] = FundamentalLoss(config)

    if config.LOSS.USE_GLOBAL_MI_LOSS or config.LOSS.USE_LOCAL_MI_LOSS:
        criterion_dict['mutual_info'] = MILoss(config, model_dict)

    if config.LOSS.USE_DOMAIN_TRANSFER_LOSS:
        criterion_dict['bce'] = torch.nn.BCELoss().to(device)

    if config.LOSS.USE_VIEW_MI_LOSS:
        criterion_dict['view_mi'] = ViewMILoss(config, model_dict)

    if config.LOSS.USE_JOINTS_MI_LOSS:
        criterion_dict['joints_mi'] = JointsMILoss(config, model_dict)

    if config.LOSS.USE_HEATMAP_MI_LOSS:
        criterion_dict['heatmap_mi'] = HeatmapMILoss(config, model_dict)

    # Data loading code
    if rank == 0:
        logger.info('=> loading dataset')
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + config.DATASET.TRAIN_DATASET)(
        config, config.DATASET.TRAIN_SUBSET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
        config.DATASET.PSEUDO_LABEL_PATH,
        config.DATASET.NO_DISTORTION)
    valid_dataset = eval('dataset.' + config.DATASET.TEST_DATASET)(
        config, config.DATASET.TEST_SUBSET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
        '',
        config.DATASET.NO_DISTORTION)

    train_loader, train_sampler = get_training_loader(train_dataset, config)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE,  # no need to multiply len(gpus)
        shuffle=False,
        num_workers=int(config.WORKERS / num_gpus),
        pin_memory=False)

    best_perf = 0
    best_model = False

    dist.barrier()

    for epoch in range(start_epoch, config.TRAIN.END_EPOCH):
        for lr_scheduler in lr_scheduler_dict.values():
            lr_scheduler.step()

        train_sampler.set_epoch(epoch)

        train(config, train_loader, model_dict, criterion_dict, optimizer_dict, epoch,
              final_output_dir, writer_dict, rank)
        perf_indicator = validate(config, valid_loader, valid_dataset, model_dict,
                                  criterion_dict, final_output_dir, writer_dict, rank)

        if rank == 0:
            if perf_indicator > best_perf:
                best_perf = perf_indicator
                best_model = True
            else:
                best_model = False

            logger.info('=> saving checkpoint to {}'.format(final_output_dir))

            save_dict = {
                'epoch': epoch + 1,
                'model': get_model_name(config),
                'perf': perf_indicator,
                'iteration': args.iteration
            }
            # unwrapped (non-DDP) weights and optimizer states, one per key
            for key, model in model_dict.items():
                save_dict['state_dict_' + key] = model.module.state_dict()
            for key, optimizer in optimizer_dict.items():
                save_dict['optimizer_' + key] = optimizer.state_dict()
            save_checkpoint(save_dict, best_model, final_output_dir)
        dist.barrier()

    if rank == 0:
        final_model_state_file = os.path.join(final_output_dir,
                                              'final_state.pth.tar')
        logger.info('saving final model state to {}'.format(final_model_state_file))
        # Rebuilt here instead of reusing a variable from the loop body: the
        # loop never runs when start_epoch >= END_EPOCH (e.g. a resumed run
        # that had already finished), which previously raised NameError.
        final_state = {'state_dict_' + key: model.module.state_dict()
                       for key, model in model_dict.items()}
        torch.save(final_state, final_model_state_file)
        writer_dict['writer'].close()

    print('Rank {} exit'.format(rank))
Ejemplo n.º 18
0
def main():
    """Train the face X-ray model and save its final weights.

    Builds the model from ``models.<MODEL.NAME>.get_nnb``, wraps it in
    DataParallel, and runs train/validate for epochs
    ``TRAIN.BEGIN_EPOCH .. TRAIN.END_EPOCH - 1``. The LR scheduler is stepped
    at the start of every epoch. The unwrapped weights are written to
    ``./output/BI_dataset/faceXray.pth`` at the end.
    """
    import os  # function-scope so this block does not depend on file imports

    # The returned namespace is not used below; the call is kept for its
    # command-line handling (presumably it also updates `config` — confirm).
    parse_args()

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval('models.' + config.MODEL.NAME + '.get_nnb')(config)

    writer_dict = {
        'writer': SummaryWriter(log_dir='./output/facexray'),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    model = torch.nn.DataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = Loss()

    optimizer = get_optimizer(config, model)

    last_epoch = config.TRAIN.BEGIN_EPOCH

    # A list of milestones means multi-step decay; a scalar means a fixed step.
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       config.TRAIN.LR_STEP,
                                                       config.TRAIN.LR_FACTOR,
                                                       last_epoch - 1)

    # Data loading code
    # list_name is not specified separately in the .yaml file.
    # The transform does not yet handle other input sizes; inputs should be
    # [256, 256, 3].
    train_dataset = eval('dataset.' + config.DATASET.DATASET + '.' +
                         config.DATASET.DATASET)(
                             config.DATASET.ROOT, config.DATASET.TRAIN_SET,
                             None, transforms.Compose([transforms.ToTensor()]))

    valid_dataset = eval('dataset.' + config.DATASET.DATASET + '.' +
                         config.DATASET.DATASET)(config.DATASET.ROOT,
                                                 config.DATASET.TEST_SET, None,
                                                 transforms.Compose(
                                                     [transforms.ToTensor()]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY)

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # Intended: freeze the original HRNet layers for the first 50000
        # iterations, then train all parameters.
        # NOTE(review): the check below compares the *epoch* counter with
        # 150000, which disagrees with the "50000 iterations" intent;
        # confirm the intended unit and value.
        if epoch == 150000:
            for _, param in model.named_parameters():
                param.requires_grad = True

        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              writer_dict)
        # evaluate on validation set
        validate(config, valid_loader, model, criterion, writer_dict)

    final_weights_path = './output/BI_dataset/faceXray.pth'
    # torch.save does not create parent directories; ensure the target
    # exists so a completed training run is not lost at the last step.
    os.makedirs(os.path.dirname(final_weights_path), exist_ok=True)
    torch.save(model.module.state_dict(), final_weights_path)
    writer_dict['writer'].close()
Ejemplo n.º 19
0
def main():
    """Train an HRNet-style semantic segmentation model.

    Builds the model from ``models.<MODEL.NAME>.get_seg_model``, the train,
    optional extra-train and test loaders, and an (Ohem)CrossEntropy
    criterion wrapped with the model in ``FullModel``. Training runs
    ``TRAIN.END_EPOCH`` epochs on the train set, followed by
    ``TRAIN.EXTRA_EPOCH`` epochs on the extra train set.

    The process with ``args.local_rank == 0`` writes:
      * ``checkpoint.pth.tar`` every epoch;
      * ``best.pth`` whenever mean IoU improves;
      * ``final_state.pth`` after the last epoch.

    Raises:
        ValueError: if ``TRAIN.EXTRA_EPOCH > 0`` but no
            ``DATASET.EXTRA_TRAIN_SET`` is configured, or if
            ``TRAIN.OPTIMIZER`` is not ``'sgd'``.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(config)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)
    distributed = len(gpus) > 1
    device = torch.device('cuda:{}'.format(args.local_rank))

    # build model
    model = eval('models.' + config.MODEL.NAME + '.get_seg_model')(config)

    if args.local_rank == 0:
        # provide the summary of model
        dump_input = torch.rand(
            (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]))
        logger.info(get_model_summary(model.to(device), dump_input.to(device)))

        # copy model file
        this_dir = os.path.dirname(__file__)
        models_dst_dir = os.path.join(final_output_dir, 'models')
        if os.path.exists(models_dst_dir):
            shutil.rmtree(models_dst_dir)
        shutil.copytree(os.path.join(this_dir, '../lib/models'),
                        models_dst_dir)

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
        )

    # prepare data
    dataset_cls = eval('datasets.' + config.DATASET.DATASET)
    crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    # Augmentation settings shared by the train and extra-train sets.
    train_kwargs = dict(
        root=config.DATASET.ROOT,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=config.TRAIN.MULTI_SCALE,
        flip=config.TRAIN.FLIP,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TRAIN.BASE_SIZE,
        crop_size=crop_size,
        downsample_rate=config.TRAIN.DOWNSAMPLERATE,
        scale_factor=config.TRAIN.SCALE_FACTOR)

    def make_train_loader(train_set):
        # Build a (loader, sampler) pair; sampler is None when not distributed.
        sampler = DistributedSampler(train_set) if distributed else None
        loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
            # DataLoader forbids shuffle=True together with a sampler.
            shuffle=config.TRAIN.SHUFFLE and sampler is None,
            num_workers=config.WORKERS,
            pin_memory=True,
            drop_last=True,
            sampler=sampler)
        return loader, sampler

    train_dataset = dataset_cls(list_path=config.DATASET.TRAIN_SET,
                                **train_kwargs)
    trainloader, train_sampler = make_train_loader(train_dataset)

    extra_trainloader = None
    extra_train_sampler = None
    if config.DATASET.EXTRA_TRAIN_SET:
        extra_train_dataset = dataset_cls(
            list_path=config.DATASET.EXTRA_TRAIN_SET, **train_kwargs)
        extra_trainloader, extra_train_sampler = make_train_loader(
            extra_train_dataset)
    elif config.TRAIN.EXTRA_EPOCH > 0:
        # Fail now rather than after END_EPOCH epochs of training.
        raise ValueError(
            'TRAIN.EXTRA_EPOCH > 0 requires DATASET.EXTRA_TRAIN_SET to be set')

    test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
    test_dataset = dataset_cls(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=config.TEST.NUM_SAMPLES,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=test_size,
        center_crop_test=config.TEST.CENTER_CROP_TEST,
        downsample_rate=1)

    test_sampler = DistributedSampler(test_dataset) if distributed else None

    testloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True,
        sampler=test_sampler)

    # criterion
    if config.LOSS.USE_OHEM:
        criterion = OhemCrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                     thres=config.LOSS.OHEMTHRES,
                                     min_kept=config.LOSS.OHEMKEEP,
                                     weight=train_dataset.class_weights)
    else:
        criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                 weight=train_dataset.class_weights)

    # NOTE(review): SyncBatchNorm and DistributedDataParallel require an
    # initialised process group, which is only created when `distributed`
    # is true above; single-GPU runs presumably are not supported — confirm.
    model = FullModel(model, criterion)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(device)
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[args.local_rank],
                                                output_device=args.local_rank)

    # optimizer
    if config.TRAIN.OPTIMIZER == 'sgd':
        optimizer = torch.optim.SGD(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 model.parameters()),
                'lr': config.TRAIN.LR
            }],
            lr=config.TRAIN.LR,
            momentum=config.TRAIN.MOMENTUM,
            weight_decay=config.TRAIN.WD,
            nesterov=config.TRAIN.NESTEROV,
        )
    else:
        raise ValueError('Only Support SGD optimizer')

    # `np.int` was removed in NumPy 1.24; the builtin truncates identically.
    epoch_iters = int(len(train_dataset) /
                      config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
    best_mIoU = 0
    last_epoch = 0
    if config.TRAIN.RESUME:
        if os.path.isfile(checkpoint_file):
            checkpoint = torch.load(checkpoint_file,
                                    map_location=lambda storage, loc: storage)
            best_mIoU = checkpoint['best_mIoU']
            last_epoch = checkpoint['epoch']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

    start = timeit.default_timer()
    end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters
    extra_iters = config.TRAIN.EXTRA_EPOCH * epoch_iters

    for epoch in range(last_epoch, end_epoch):
        in_extra_phase = epoch >= config.TRAIN.END_EPOCH
        if distributed:
            # Reseed whichever sampler is active so each epoch reshuffles.
            active_sampler = extra_train_sampler if in_extra_phase else train_sampler
            active_sampler.set_epoch(epoch)
        if in_extra_phase:
            train(config, epoch - config.TRAIN.END_EPOCH,
                  config.TRAIN.EXTRA_EPOCH, epoch_iters, config.TRAIN.EXTRA_LR,
                  extra_iters, extra_trainloader, optimizer, model,
                  writer_dict, device)
        else:
            train(config, epoch, config.TRAIN.END_EPOCH, epoch_iters,
                  config.TRAIN.LR, num_iters, trainloader, optimizer, model,
                  writer_dict, device)

        valid_loss, mean_IoU, IoU_array = validate(config, testloader, model,
                                                   writer_dict, device)

        if args.local_rank == 0:
            # Update the best score *before* checkpointing so that a resumed
            # run does not overwrite best.pth with a worse model.
            if mean_IoU > best_mIoU:
                best_mIoU = mean_IoU
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'best.pth'))

            logger.info('=> saving checkpoint to {}'.format(checkpoint_file))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'best_mIoU': best_mIoU,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, checkpoint_file)

            msg = 'Loss: {:.3f}, MeanIU: {: 4.4f}, Best_mIoU: {: 4.4f}'.format(
                valid_loss, mean_IoU, best_mIoU)
            logger.info(msg)
            logger.info(IoU_array)

            if epoch == end_epoch - 1:
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'final_state.pth'))

                writer_dict['writer'].close()
                end = timeit.default_timer()
                logger.info('Hours: %d' % int((end - start) / 3600))
                logger.info('Done')
Ejemplo n.º 20
0
def main():
    """Train a single-person pose network with DataParallel.

    The config is refreshed from the command line via ``reset_config``. The
    model is built from ``models.<MODEL.NAME>.get_pose_net`` and its graph is
    logged to TensorBoard. Train/validate then runs for epochs
    ``TRAIN.BEGIN_EPOCH .. TRAIN.END_EPOCH - 1``.

    A checkpoint is saved after every epoch and flagged as best when the
    validation score improves. The unwrapped weights go to
    ``final_state.pth.tar`` at the end.
    """
    cli_args = parse_args()
    reset_config(config, cli_args)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, cli_args.cfg, "train")
    logger.info(pprint.pformat(cli_args))
    logger.info(pprint.pformat(config))

    # cuDNN behaviour is taken verbatim from the config.
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    net_factory = eval("models." + config.MODEL.NAME + ".get_pose_net")
    model = net_factory(config, is_train=True)

    # Keep a copy of the model definition next to the experiment outputs.
    model_source = os.path.join(os.path.dirname(__file__), "../lib/models",
                                config.MODEL.NAME + ".py")
    shutil.copy2(model_source, final_output_dir)

    tb_writer = SummaryWriter(log_dir=tb_log_dir)
    writer_dict = {
        "writer": tb_writer,
        "train_global_steps": 0,
        "valid_global_steps": 0,
    }

    # Trace the unwrapped model once with a random batch for TensorBoard.
    graph_input = torch.rand((config.TRAIN.BATCH_SIZE, 3,
                              config.MODEL.IMAGE_SIZE[1],
                              config.MODEL.IMAGE_SIZE[0]))
    tb_writer.add_graph(model, (graph_input, ), verbose=False)

    gpus = list(map(int, config.GPUS.split(",")))
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()
    optimizer = get_optimizer(config, model)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    def build_dataset(image_set, is_train):
        # Each dataset gets its own ToTensor+Normalize pipeline.
        pipeline = transforms.Compose([transforms.ToTensor(), normalize])
        return eval("dataset." + config.DATASET.DATASET)(
            config, config.DATASET.ROOT, image_set, is_train, pipeline)

    train_dataset = build_dataset(config.DATASET.TRAIN_SET, True)
    valid_dataset = build_dataset(config.DATASET.TEST_SET, False)

    # DataParallel splits each batch across GPUs, so scale by the GPU count.
    gpu_count = len(gpus)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE * gpu_count,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE * gpu_count,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True,
    )

    best_perf = 0.0
    for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH):
        # NOTE(review): stepping before training follows the pre-1.1 PyTorch
        # convention; newer PyTorch warns and shifts the schedule by one.
        lr_scheduler.step()

        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        perf_indicator = validate(config, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        is_best = bool(perf_indicator > best_perf)
        if is_best:
            best_perf = perf_indicator

        logger.info("=> saving checkpoint to {}".format(final_output_dir))
        checkpoint = {
            "epoch": epoch + 1,
            "model": get_model_name(config),
            # NOTE(review): DataParallel-wrapped keys ("module." prefix),
            # unlike the unwrapped final state saved below.
            "state_dict": model.state_dict(),
            "perf": perf_indicator,
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint, is_best, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir,
                                          "final_state.pth.tar")
    logger.info(
        "saving final model state to {}".format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    tb_writer.close()
Ejemplo n.º 21
0
def main_worker(gpu, ngpus_per_node, args, final_output_dir, tb_log_dir):
    """Per-GPU worker for DistributedDataParallel pose training.

    Joins the process group, builds the pose model (optionally with
    SyncBatchNorm), wraps it in DDP on ``gpu``, optionally auto-resumes from
    ``checkpoint.pth`` and trains for epochs
    ``begin_epoch .. cfg.TRAIN.END_EPOCH - 1``.

    Args:
        gpu: local CUDA device index used by this process.
        ngpus_per_node: number of worker processes on this node.
        args: parsed CLI namespace. ``args.rank`` holds the node rank on
            entry and is overwritten with this process's global rank;
            ``args.gpu`` is set to ``gpu``.
        final_output_dir: directory for checkpoints and copied sources.
        tb_log_dir: TensorBoard log directory.
    """
    args.gpu = gpu
    args.rank = args.rank * ngpus_per_node + gpu
    print('Init process group: dist_url: {}, world_size: {}, rank: {}'.format(cfg.DIST_URL, args.world_size, args.rank))
    dist.init_process_group(backend=cfg.DIST_BACKEND, init_method=cfg.DIST_URL, world_size=args.world_size, rank=args.rank)

    update_config(cfg, args)

    # First process on each node (or the only process when not launched via
    # multiprocessing) does one-off file / TensorBoard-graph work; only the
    # global rank 0 writes checkpoints.
    is_node_master = (not cfg.MULTIPROCESSING_DISTRIBUTED
                      or args.rank % ngpus_per_node == 0)
    is_global_master = (not cfg.MULTIPROCESSING_DISTRIBUTED
                        or args.rank == 0)

    # setup logger
    logger, _ = setup_logger(final_output_dir, args.rank, 'train')

    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(cfg, is_train=True)
    logger.info(get_model_summary(model, torch.zeros(1, 3, *cfg.MODEL.IMAGE_SIZE)))

    # copy model file
    if is_node_master:
        this_dir = os.path.dirname(__file__)
        shutil.copy2(os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'), final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    if is_node_master:
        dump_input = torch.rand((1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
        writer_dict['writer'].add_graph(model, (dump_input, ))

    if cfg.MODEL.SYNC_BN:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    torch.cuda.set_device(args.gpu)
    model.cuda(args.gpu)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda(args.gpu)

    # Data loading code
    train_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    )
    valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    # NOTE(review): each process already gets its own shard via the sampler,
    # so multiplying the per-GPU batch by len(cfg.GPUS) enlarges the
    # effective batch by that factor — confirm this is intended.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=(train_sampler is None),
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY,
        sampler=train_sampler
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY
    )
    logger.info(train_loader.dataset)

    best_perf = -1
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')
    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        # Load onto the CPU: without map_location every rank would
        # materialise the tensors on the GPU that saved them (rank 0's).
        # load_state_dict then copies weights / optimizer state onto the
        # parameters' own devices.
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
        last_epoch=last_epoch)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        # DistributedSampler derives its shuffle from the epoch; without
        # this call every epoch replays the same sample order.
        train_sampler.set_epoch(epoch)

        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        # In PyTorch 1.1.0 and later, you should call `lr_scheduler.step()` after `optimizer.step()`.
        lr_scheduler.step()

        # evaluate on validation set
        perf_indicator = validate(
            args, cfg, valid_loader, valid_dataset, model, criterion,
            final_output_dir, tb_log_dir, writer_dict
        )

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        if is_global_master:
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint({
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(
        final_output_dir, 'final_state{}.pth.tar'.format(gpu)
    )

    logger.info('saving final model state to {}'.format(
        final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
def main():
    """Run pose inference over an image set and write keypoint visualisations.

    Parses the command line, merges it into ``cfg`` and builds the network
    returned by ``models.<cfg.MODEL.NAME>.get_pose_net``. Weights come either
    from ``cfg.TEST.MODEL_FILE`` (loaded non-strictly) or from
    ``final_state.pth`` in the output directory.

    For every image id in the ``img_coco.IMGCOCO`` set, the per-person inputs
    from ``generate_pose_input`` are run through the model. The decoded
    keypoints are drawn with ``draw_kpts`` and the result is written to
    ``results/<imname>``.

    The model is never moved to a GPU here, so inference runs wherever the
    model's parameters live (the CPU).
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
        cfg, is_train=False
    )

    # The model stays on the CPU, so map checkpoint tensors there. This also
    # lets GPU-saved checkpoints load on CPU-only hosts.
    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        incompatible = model.load_state_dict(
            torch.load(cfg.TEST.MODEL_FILE, map_location='cpu'),
            strict=False)
        # strict=False silently skips mismatched keys; report them so a
        # wrong or partially matching checkpoint does not go unnoticed.
        if incompatible.missing_keys:
            logger.warning('=> missing keys in checkpoint: {}'.format(
                incompatible.missing_keys))
        if incompatible.unexpected_keys:
            logger.warning('=> unexpected keys in checkpoint: {}'.format(
                incompatible.unexpected_keys))
    else:
        model_state_file = os.path.join(
            final_output_dir, 'final_state.pth'
        )
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file, map_location='cpu'))
    model.eval()

    # Data loading code
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    print(cfg.DATASET.DATASET)
    print(cfg.DATASET.ROOT)
    print(cfg.DATASET.TEST_SET)
    img_sets = img_coco.IMGCOCO(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([transforms.ToTensor(), normalize]))

    # cv2.imwrite returns False instead of raising when the target directory
    # is missing, so create it up front.
    out_dir = 'results'
    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        for imid in img_sets.image_set:
            persons, all_bbs, all_scales, ori_img, imname = \
                img_sets.generate_pose_input(imid)
            all_pts = []
            for person in persons:
                outputs = model(person)
                preds, _ = get_final_preds(
                    cfg, outputs.clone().cpu().numpy(), [], [])
                # NOTE(review): the factor 4 presumably maps heatmap
                # coordinates back to input resolution -- confirm against
                # the model's output stride.
                all_pts.append(preds[0, :] * 4)
            vis_img = draw_kpts(ori_img, all_bbs, all_pts, all_scales)
            out_path = os.path.join(out_dir, imname)
            if not cv2.imwrite(out_path, vis_img):
                logger.warning('=> failed to write {}'.format(out_path))
# Ejemplo n.º 23
# 0
def main():
    """Train a pose-estimation network described by ``cfg``.

    Parses the command line, merges it into ``cfg`` and builds the model
    returned by ``models.<cfg.MODEL.NAME>.get_pose_net``. It then trains with
    ``JointsMSELoss`` and validates after every epoch.

    A checkpoint is written to the output directory each epoch;
    ``save_checkpoint`` is told whether it is the best so far. The final
    weights are saved to ``final_state.pth``. When ``cfg.AUTO_RESUME`` is set
    and ``checkpoint.pth`` exists, training resumes from it.
    """
    # Parse command-line arguments and merge them into cfg.
    args = parse_args()
    update_config(cfg, args)

    # Logger used to record everything printed during training.
    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cuDNN / GPU related settings.
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # Build the network named in the config (e.g. models.pose_hrnet or
    # models.pose_resnet); get_pose_net returns the network.
    print('models.' + cfg.MODEL.NAME + '.get_pose_net')
    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)

    # Copy lib/models/<MODEL.NAME>.py into the output directory so the run
    # keeps a copy of the model definition it was trained with.
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)

    # TensorBoard writer plus global step counters for train/valid logging.
    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # Trace the model once with a dummy input for the TensorBoard graph and
    # the logged model summary.
    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ))

    logger.info(get_model_summary(model, dump_input))

    # Wrap the model for multi-GPU training.
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # Define the loss function (criterion).
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading: normalise inputs with the standard ImageNet mean/std.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Build the training and validation datasets and their loaders.
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    # Resume state and optimisation schedule.
    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)

    # Training loop.
    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        # Train for one epoch.
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        # Since PyTorch 1.1.0 the scheduler must be stepped after the
        # optimizer. Stepping it before train() skipped the first LR value
        # and moved every milestone one epoch early.
        lr_scheduler.step()

        # Evaluate on the validation set.
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    # Save the final model weights (unwrapped from DataParallel).
    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
# Ejemplo n.º 24
# 0
def main():
    """Train ``ResModel`` with a class-weighted multi-label BCE loss.

    The training loss is ``BCEWithLogitsLoss`` with a fixed ``pos_weight``
    vector; validation uses an unweighted ``BCEWithLogitsLoss``. When
    ``config.TRAIN.RESUME`` is set, the model/optimizer state is restored from
    ``latest.pth`` in the output directory.

    Every 5th epoch the bare model state dict is saved as
    ``checkpoint_<epoch>.pth``. The final weights go to ``final_state.pth``.
    """
    args = parse_args()
    # Pin the process to one physical GPU. This is hard-coded and only takes
    # effect because it runs before CUDA is initialised.
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'train')

    model = ResModel(config)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # NOTE(review): only one device is visible (see CUDA_VISIBLE_DEVICES
    # above), so config.GPUS presumably has to be (0,) -- confirm.
    gpus = list(config.GPUS)
    model = nn.DataParallel(model, device_ids=gpus).cuda()

    # Loss. pos_weight holds 12 values, presumably one per output label.
    # BCEWithLogitsLoss registers it as a buffer, so .cuda() moves it too.
    pos_weight = torch.tensor([2.6, 3.4, 3.0, 1.2, 1.1, 1.0, 1.1, 1.2, 3.4, 1.7, 3.6, 3.8], dtype=torch.float32)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).cuda()
    criterion_val = torch.nn.BCEWithLogitsLoss().cuda()

    optimizer = utils.get_optimizer(config, model)

    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir,
                                        'latest.pth')
        # Accept both a symlink (with an existing target) and a regular
        # file. islink() alone ignored plain checkpoint files and crashed
        # on dangling links.
        if os.path.exists(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint (epoch {})"
                  .format(checkpoint['epoch']))
        else:
            print("=> no checkpoint found")

    # The schedulers are built with last_epoch - 1; their constructor takes
    # one step, so the schedule starts at last_epoch.
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch-1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch-1
        )
    dataset_type = get_dataset(config)

    train_loader = DataLoader(
        dataset=dataset_type(config,
                             is_train=True),
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU*len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY)

    val_loader = DataLoader(
        dataset=dataset_type(config,
                             is_train=False),
        batch_size=config.TEST.BATCH_SIZE_PER_GPU*len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY
    )

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        function.train(config, train_loader, model, criterion,
                       optimizer, epoch, writer_dict)
        # Step the scheduler after the optimizer (PyTorch >= 1.1.0);
        # stepping first skipped the first LR value of the schedule.
        lr_scheduler.step()

        # evaluate
        function.validate(config, val_loader, model,
                          criterion_val, epoch, writer_dict)

        if epoch % 5 == 0:
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            torch.save(model.module.state_dict(), os.path.join(final_output_dir, 'checkpoint_{}.pth'.format(epoch)))

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth')
    logger.info('saving final model state to {}'.format(
        final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
def main():
    """Train a student pose network, optionally with FPD knowledge distillation.

    The teacher checkpoint (``cfg.KD.TEACHER``) and the train type
    (``cfg.KD.TRAIN_TYPE``) are read from the student config and resolved by
    ``get_train_type``.

    In ``'FPD'`` mode a teacher network is built from ``args.tcfg``, loaded
    from the teacher checkpoint and passed to ``fpd_train``. Otherwise the
    student is trained with plain ``train``. The student is validated every
    epoch, a checkpoint is written each epoch, and the final weights go to
    ``final_state.pth``.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # Both KD settings are configured in the *student* config file.
    t_checkpoints = cfg.KD.TEACHER
    train_type = get_train_type(cfg.KD.TRAIN_TYPE, t_checkpoints)
    logger.info('=> train type is {} '.format(train_type))
    is_fpd = train_type == 'FPD'

    if is_fpd:
        cfg_name = 'student_' + os.path.basename(args.cfg).split('.')[0]
    else:
        cfg_name = os.path.basename(args.cfg).split('.')[0]
    save_yaml_file(cfg_name, cfg, final_output_dir)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)

    # FPD method (default is NORMAL): build and load the teacher.
    if is_fpd:
        tcfg = cfg.clone()
        tcfg.defrost()
        tcfg.merge_from_file(args.tcfg)
        tcfg.freeze()
        tcfg_name = 'teacher_' + os.path.basename(args.tcfg).split('.')[0]
        save_yaml_file(tcfg_name, tcfg, final_output_dir)
        # teacher model
        tmodel = eval('models.' + tcfg.MODEL.NAME + '.get_pose_net')(
            tcfg, is_train=False)

        load_checkpoint(t_checkpoints,
                        tmodel,
                        strict=True,
                        model_info='teacher_' + tcfg.MODEL.NAME)

        tmodel = torch.nn.DataParallel(tmodel, device_ids=cfg.GPUS).cuda()
        # Distillation loss between student and teacher outputs.
        kd_pose_criterion = JointsMSELoss(
            use_target_weight=tcfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ))

    logger.info(get_model_summary(model, dump_input))

    if cfg.TRAIN.CHECKPOINT:
        load_checkpoint(cfg.TRAIN.CHECKPOINT,
                        model,
                        strict=True,
                        model_info='student_' + cfg.MODEL.NAME)
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # The pose_loss and kd_pose_loss types can be swapped (e.g. mse, kl,
    # ohkm loss, etc.).
    # Define the pose loss function (criterion).
    pose_criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)

    # Baseline evaluation before training. The teacher only exists in FPD
    # mode; validating it unconditionally raised NameError for NORMAL runs.
    if is_fpd:
        validate(cfg, valid_loader, valid_dataset, tmodel, pose_criterion,
                 final_output_dir, tb_log_dir, writer_dict)
    validate(cfg, valid_loader, valid_dataset, model, pose_criterion,
             final_output_dir, tb_log_dir, writer_dict)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        # Train for one epoch (FPD distillation or plain training).
        if is_fpd:
            fpd_train(cfg, train_loader, model, tmodel, pose_criterion,
                      kd_pose_criterion, optimizer, epoch, final_output_dir,
                      tb_log_dir, writer_dict)
        else:
            train(cfg, train_loader, model, pose_criterion, optimizer, epoch,
                  final_output_dir, tb_log_dir, writer_dict)
        # Step the scheduler after the optimizer (PyTorch >= 1.1.0);
        # stepping first skipped the first LR value of the schedule.
        lr_scheduler.step()

        # evaluate on validation set
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  pose_criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
# Ejemplo n.º 26
# 0
def main():
    """Train an image classifier with DistributedDataParallel.

    Each process binds to ``args.local_rank`` and wraps the model built by
    ``models.<config.MODEL.NAME>.get_cls_net`` in DDP. It trains on a
    ``TSVInstance`` dataset sharded by ``DistributedSampler``, and every rank
    validates after each epoch.

    Only rank 0 writes ``checkpoint.pth.tar`` each epoch and
    ``final_state.pth.tar`` at the end. With ``config.TRAIN.RESUME``, state
    is restored from ``checkpoint.pth.tar`` when it exists.
    """
    args = parse_args()

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url)
    # Only one process writes checkpoints; concurrent writes of the same
    # file from every rank can corrupt it.
    is_main_process = dist.get_rank() == 0

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval('models.' + config.MODEL.NAME + '.get_cls_net')(config)

    dump_input = torch.rand(
        (1, 3, config.MODEL.IMAGE_SIZE[1], config.MODEL.IMAGE_SIZE[0]))
    logger.info(get_model_summary(model, dump_input))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    gpus = list(config.GPUS)
    # DDP (one process per device) instead of DataParallel.
    torch.cuda.set_device(args.local_rank)
    model = model.to(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    optimizer = get_optimizer(config, model)

    best_perf = 0.0
    best_model = False
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            # Load onto the CPU first. Without map_location every rank would
            # allocate the tensors on the device they were saved from.
            # load_state_dict then copies them to each rank's parameters.
            checkpoint = torch.load(model_state_file, map_location='cpu')
            last_epoch = checkpoint['epoch']
            best_perf = checkpoint['perf']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
            best_model = True

    # Built with last_epoch - 1; the constructor takes one step, so the
    # schedule starts at last_epoch.
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       config.TRAIN.LR_STEP,
                                                       config.TRAIN.LR_FACTOR,
                                                       last_epoch - 1)

    # Data loading code
    traindir = os.path.join(config.DATASET.ROOT, config.DATASET.TRAIN_SET)
    valdir = os.path.join(config.DATASET.ROOT, config.DATASET.TEST_SET)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # TSV dataset instead of an ImageFolder.
    train_dataset = TSVInstance(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    # DDP requires DistributedSampler
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=(train_sampler is None),
        num_workers=config.WORKERS,
        pin_memory=True,
        sampler=train_sampler)

    valid_loader = torch.utils.data.DataLoader(
        TSVInstance(
            valdir,
            transforms.Compose([
                transforms.Resize(int(config.MODEL.IMAGE_SIZE[0] / 0.875)),
                transforms.CenterCrop(config.MODEL.IMAGE_SIZE[0]),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        # DistributedSampler seeds its shuffle with the epoch. Without
        # set_epoch every epoch would reuse the same ordering.
        train_sampler.set_epoch(epoch)
        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        # Step the scheduler after the optimizer (PyTorch >= 1.1.0);
        # stepping first skipped the first LR value of the schedule.
        lr_scheduler.step()
        # evaluate on validation set
        perf_indicator = validate(config, valid_loader, model, criterion,
                                  final_output_dir, tb_log_dir, writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        if is_main_process:
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': config.MODEL.NAME,
                    'state_dict': model.module.state_dict(),
                    'perf': perf_indicator,
                    'optimizer': optimizer.state_dict(),
                },
                best_model,
                final_output_dir,
                filename='checkpoint.pth.tar')

    if is_main_process:
        final_model_state_file = os.path.join(final_output_dir,
                                              'final_state.pth.tar')
        logger.info(
            'saving final model state to {}'.format(final_model_state_file))
        torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
# Ejemplo n.º 27
# 0
def main():
    """Train a pose network using the legacy ``config`` / ``reset_config`` API.

    Builds ``models.<config.MODEL.NAME>.get_pose_net`` and wraps it in
    DataParallel on the GPUs listed (comma-separated) in ``config.GPUS``. It
    trains with ``JointsMSELoss`` and a ``MultiStepLR`` schedule and validates
    every epoch.

    A checkpoint is written each epoch and the final weights are saved to
    ``final_state.pth.tar``.
    """
    # Hard-coded device selection; it must run before CUDA is initialised.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    args = parse_args()
    print('out')
    print(args)

    reset_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval('models.' + config.MODEL.NAME + '.get_pose_net')(
        config, is_train=True)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', config.MODEL.NAME + '.py'),
        final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (config.TRAIN.BATCH_SIZE, 3, config.MODEL.IMAGE_SIZE[1],
         config.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ), verbose=False)

    gpus = [int(i) for i in config.GPUS.split(',')]
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()

    optimizer = get_optimizer(config, model)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = eval('dataset.' + config.DATASET.DATASET)(
        config, config.DATASET.ROOT, config.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + config.DATASET.DATASET)(
        config, config.DATASET.ROOT, config.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    best_perf = 0.0
    best_model = False

    for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH):
        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        # Step the scheduler after the optimizer (PyTorch >= 1.1.0);
        # stepping first skipped the first LR value of the schedule.
        lr_scheduler.step()

        # evaluate on validation set
        perf_indicator = validate(config, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': get_model_name(config),
                'state_dict': model.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
# Ejemplo n.º 28
# 0
def main():
    """Train a segmentation model, validating and checkpointing every epoch.

    Configuration comes from the module-level ``config`` object and the
    arguments returned by ``parse_args()``. When ``args.local_rank >= 0`` the
    run uses ``DistributedDataParallel`` (one process per GPU, NCCL backend,
    ``env://`` initialisation); otherwise it wraps the model in
    ``nn.DataParallel`` over ``config.GPUS``.

    Training runs ``config.TRAIN.END_EPOCH`` epochs on ``DATASET.TRAIN_SET``
    followed by ``config.TRAIN.EXTRA_EPOCH`` epochs on
    ``DATASET.EXTRA_TRAIN_SET``. The extra phase is skipped, with a log
    message, when no extra training set is configured.

    Rank 0 (or the single non-distributed process) writes these files under
    the output directory returned by ``create_logger``:

    - ``checkpoint.pth.tar`` after every epoch;
    - ``best.pth`` whenever the validation mean IoU improves;
    - ``final_state.pth`` at the end.
    """
    args = parse_args()

    if args.seed > 0:
        import random
        print('Seeding with', args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(config)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)
    # A non-negative local_rank means this process was started by a
    # distributed launcher and drives exactly one GPU.
    distributed = args.local_rank >= 0
    if distributed:
        device = torch.device('cuda:{}'.format(args.local_rank))
        print(device)
        torch.cuda.set_device(device)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
        )

    # build model
    model = eval('models.' + config.MODEL.NAME + '.get_seg_model')(config)

    # Snapshot the model sources next to the outputs (once, from rank 0).
    if distributed and args.local_rank == 0:
        this_dir = os.path.dirname(__file__)
        models_dst_dir = os.path.join(final_output_dir, 'models')
        if os.path.exists(models_dst_dir):
            shutil.rmtree(models_dst_dir)
        shutil.copytree(os.path.join(this_dir, '../lib/models'),
                        models_dst_dir)

    # Each distributed process loads its own per-GPU batch, whereas
    # nn.DataParallel (used below otherwise) splits one loader batch across
    # all GPUs.
    if distributed:
        batch_size = config.TRAIN.BATCH_SIZE_PER_GPU
    else:
        batch_size = config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus)

    # prepare data
    crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    train_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TRAIN_SET,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=config.TRAIN.MULTI_SCALE,
        flip=config.TRAIN.FLIP,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TRAIN.BASE_SIZE,
        crop_size=crop_size,
        downsample_rate=config.TRAIN.DOWNSAMPLERATE,
        scale_factor=config.TRAIN.SCALE_FACTOR)

    train_sampler = get_sampler(train_dataset)
    # A sampler and shuffle=True are mutually exclusive in DataLoader.
    trainloader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=batch_size,
                                              shuffle=config.TRAIN.SHUFFLE
                                              and train_sampler is None,
                                              num_workers=config.WORKERS,
                                              pin_memory=True,
                                              drop_last=True,
                                              sampler=train_sampler)

    # Optional second training phase on an extra dataset.
    extra_trainloader = None
    extra_epoch_iters = 0
    if config.DATASET.EXTRA_TRAIN_SET:
        extra_train_dataset = eval('datasets.' + config.DATASET.DATASET)(
            root=config.DATASET.ROOT,
            list_path=config.DATASET.EXTRA_TRAIN_SET,
            num_samples=None,
            num_classes=config.DATASET.NUM_CLASSES,
            multi_scale=config.TRAIN.MULTI_SCALE,
            flip=config.TRAIN.FLIP,
            ignore_label=config.TRAIN.IGNORE_LABEL,
            base_size=config.TRAIN.BASE_SIZE,
            crop_size=crop_size,
            downsample_rate=config.TRAIN.DOWNSAMPLERATE,
            scale_factor=config.TRAIN.SCALE_FACTOR)
        extra_train_sampler = get_sampler(extra_train_dataset)
        extra_trainloader = torch.utils.data.DataLoader(
            extra_train_dataset,
            batch_size=batch_size,
            shuffle=config.TRAIN.SHUFFLE and extra_train_sampler is None,
            num_workers=config.WORKERS,
            pin_memory=True,
            drop_last=True,
            sampler=extra_train_sampler)
        # Builtin int(): the np.int alias was removed in NumPy 1.24.
        extra_epoch_iters = int(len(extra_train_dataset) /
                                config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))

    test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
    test_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=config.TEST.NUM_SAMPLES,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=test_size,
        downsample_rate=1)

    test_sampler = get_sampler(test_dataset)
    testloader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=config.WORKERS,
                                             pin_memory=True,
                                             sampler=test_sampler)

    # criterion
    if config.LOSS.USE_OHEM:
        criterion = OhemCrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                     thres=config.LOSS.OHEMTHRES,
                                     min_kept=config.LOSS.OHEMKEEP,
                                     weight=train_dataset.class_weights)
    else:
        criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                 weight=train_dataset.class_weights)

    # FullModel bundles the network with its loss; the network itself is
    # reachable as model.module.model after wrapping.
    model = FullModel(model, criterion)
    if distributed:
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            find_unused_parameters=True,
            device_ids=[args.local_rank],
            output_device=args.local_rank)
    else:
        model = nn.DataParallel(model, device_ids=gpus).cuda()

    # optimizer
    if config.TRAIN.OPTIMIZER == 'sgd':

        params_dict = dict(model.named_parameters())
        if config.TRAIN.NONBACKBONE_KEYWORDS:
            # Parameters whose names contain any keyword get the learning
            # rate scaled by NONBACKBONE_MULT; all others use TRAIN.LR.
            bb_lr = []
            nbb_lr = []
            nbb_keys = set()
            for k, param in params_dict.items():
                if any(part in k
                       for part in config.TRAIN.NONBACKBONE_KEYWORDS):
                    nbb_lr.append(param)
                    nbb_keys.add(k)
                else:
                    bb_lr.append(param)
            print(nbb_keys)
            params = [{
                'params': bb_lr,
                'lr': config.TRAIN.LR
            }, {
                'params': nbb_lr,
                'lr': config.TRAIN.LR * config.TRAIN.NONBACKBONE_MULT
            }]
        else:
            params = [{
                'params': list(params_dict.values()),
                'lr': config.TRAIN.LR
            }]

        optimizer = torch.optim.SGD(
            params,
            lr=config.TRAIN.LR,
            momentum=config.TRAIN.MOMENTUM,
            weight_decay=config.TRAIN.WD,
            nesterov=config.TRAIN.NESTEROV,
        )
    else:
        raise ValueError('Only Support SGD optimizer')

    epoch_iters = int(len(train_dataset) /
                      config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))

    best_mIoU = 0
    last_epoch = 0
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file,
                                    map_location={'cuda:0': 'cpu'})
            best_mIoU = checkpoint['best_mIoU']
            last_epoch = checkpoint['epoch']

            # The checkpoint holds FullModel's state dict. Keep only the
            # wrapped network's entries and strip just the leading "model."
            # (str.replace would also mangle any inner "model." substring).
            prefix = 'model.'
            model.module.model.load_state_dict({
                k[len(prefix):]: v
                for k, v in checkpoint['state_dict'].items()
                if k.startswith(prefix)
            })
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        if distributed:
            torch.distributed.barrier()

    # Without an extra training set there is nothing to iterate in the extra
    # phase (it would fail on the missing loader), so skip it explicitly.
    extra_epochs = config.TRAIN.EXTRA_EPOCH
    if extra_epochs and extra_trainloader is None:
        logger.info('=> TRAIN.EXTRA_EPOCH is set but DATASET.EXTRA_TRAIN_SET '
                    'is empty; skipping the extra epochs')
        extra_epochs = 0

    start = timeit.default_timer()
    end_epoch = config.TRAIN.END_EPOCH + extra_epochs
    num_iters = config.TRAIN.END_EPOCH * epoch_iters
    extra_iters = extra_epochs * extra_epoch_iters

    for epoch in range(last_epoch, end_epoch):
        in_extra_phase = epoch >= config.TRAIN.END_EPOCH
        current_trainloader = (extra_trainloader
                               if in_extra_phase else trainloader)
        # Distributed samplers must be told the epoch to reshuffle.
        sampler = current_trainloader.sampler
        if sampler is not None and hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)

        if in_extra_phase:
            train(config, epoch - config.TRAIN.END_EPOCH,
                  config.TRAIN.EXTRA_EPOCH, extra_epoch_iters,
                  config.TRAIN.EXTRA_LR, extra_iters, extra_trainloader,
                  optimizer, model, writer_dict)
        else:
            train(config, epoch, config.TRAIN.END_EPOCH, epoch_iters,
                  config.TRAIN.LR, num_iters, trainloader, optimizer, model,
                  writer_dict)

        valid_loss, mean_IoU, IoU_array = validate(config, testloader, model,
                                                   writer_dict)

        # Only rank 0 (or the single non-distributed process) writes files.
        if args.local_rank <= 0:
            # Record the new best *before* saving the checkpoint so a resumed
            # run restores the true best score.
            is_best = mean_IoU > best_mIoU
            if is_best:
                best_mIoU = mean_IoU

            checkpoint_path = os.path.join(final_output_dir,
                                           'checkpoint.pth.tar')
            logger.info('=> saving checkpoint to {}'.format(checkpoint_path))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'best_mIoU': best_mIoU,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, checkpoint_path)
            if is_best:
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'best.pth'))
            msg = 'Loss: {:.3f}, MeanIU: {: 4.4f}, Best_mIoU: {: 4.4f}'.format(
                valid_loss, mean_IoU, best_mIoU)
            logging.info(msg)
            logging.info(IoU_array)

    if args.local_rank <= 0:

        torch.save(model.module.state_dict(),
                   os.path.join(final_output_dir, 'final_state.pth'))

        writer_dict['writer'].close()
        end = timeit.default_timer()
        logger.info('Hours: %d' % int((end - start) / 3600))
        logger.info('Done')
def main():
    """Train a pose-estimation model with ``JointsMSELoss`` and ``RegLoss``.

    Arguments from ``parse_args()`` are merged into the module-level ``cfg``
    by ``update_config``. Each epoch trains, validates, and writes a
    checkpoint via ``save_checkpoint``. The checkpoint is flagged as best
    when the validation score is >= the best seen so far.

    If ``cfg.AUTO_RESUME`` is set and ``checkpoint.pth`` exists in the output
    directory, the model, optimizer, epoch and best score are restored from
    it. The final weights, taken from inside the DataParallel wrapper
    (``model.module``), are saved to ``final_state.pth``.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    # benchmark speeds up training; deterministic avoids the nondeterminism
    # that benchmark-selected algorithms can introduce
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
        cfg, is_train=True)  # eval() resolves the model factory named in cfg.MODEL.NAME

    # copy model file
    this_dir = os.path.dirname(__file__)  # directory containing this script
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # Random dummy batch of one image (C=3, H=IMAGE_SIZE[1], W=IMAGE_SIZE[0])
    # used to trace the graph for TensorBoard and build the model summary.
    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ))

    logger.info(get_model_summary(model, dump_input))  # log the model summary

    # Multi-GPU training: wrap the model in DataParallel over cfg.GPUS.
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    #model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()
    regress_loss = RegLoss(use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()
    # Data loading code
    normalize = transforms.Normalize(
        # normalize with the ImageNet per-channel mean and std
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
    # The boolean positional argument distinguishes the training set (True)
    # from the validation set (False).
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))  # image preprocessing: to tensor, then normalize

    # Loader batch = per-GPU batch x number of GPUs, because DataParallel
    # splits each batch across cfg.GPUS.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY,
    )

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        # 'state_dict' was saved from the DataParallel-wrapped model (see the
        # save_checkpoint call below), so it loads into the wrapped model.
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        # NOTE(review): since PyTorch 1.1 the documented order is
        # optimizer.step() before lr_scheduler.step(); stepping first skips
        # the first value of the schedule. Left as-is to keep the existing
        # LR schedule -- confirm intent.
        lr_scheduler.step()

        # train for one epoch
        train(cfg, train_loader, model, criterion, regress_loss, optimizer,
              epoch, final_output_dir, tb_log_dir, writer_dict)

        # evaluate on validation set
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  criterion, regress_loss, final_output_dir,
                                  tb_log_dir, writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        # 'state_dict' is the wrapped (DataParallel) model, used for resuming;
        # 'best_state_dict' is the unwrapped model.module weights.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Ejemplo n.º 30
0
def main():
    """Train a pose model with class-balanced sampling, checkpointing each epoch.

    Arguments from ``parse_args()`` are merged into the module-level ``cfg``
    by ``update_config``. The model is wrapped in ``nn.DataParallel`` over
    all visible GPUs, and loader batch sizes are the per-GPU sizes times the
    GPU count. Training samples are drawn with ``WeightedRandomSampler``,
    using the per-class weight ``1 / (class_proportion + 0.02)`` to
    counteract class imbalance.

    Each epoch writes ``epoch-<n>.pth`` (unwrapped weights) and a checkpoint
    via ``save_checkpoint``; ``final_state.pth`` is written at the end. With
    ``cfg.AUTO_RESUME``, an existing ``checkpoint.pth`` restores the model,
    optimizer, epoch and best score.

    Raises:
        RuntimeError: if no CUDA device is available.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # DataParallel splits each batch across every visible GPU, so loader
    # batch sizes are per-GPU sizes times the GPU count. Without a GPU the
    # batch sizes would be undefined (and .cuda() below would fail anyway),
    # so fail early with a clear message.
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    if num_gpus < 1:
        raise RuntimeError(
            'No CUDA device available; this training script needs a GPU')
    train_batch_size = cfg.TRAIN.BATCH_SIZE_PER_GPU * num_gpus
    test_batch_size = cfg.TEST.BATCH_SIZE_PER_GPU * num_gpus
    logger.info("Let's use %d GPUs!" % num_gpus)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
        cfg, is_train=True).cuda()

    # Keep a copy of the model definition alongside the outputs.
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # Use all visible GPUs.
    model = torch.nn.DataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        cfg=cfg,
        target_type=cfg.MODEL.TARGET_TYPE,
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading code: ImageNet mean/std normalisation.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    # The training set is class-imbalanced: weight each sample by
    # 1 / (proportion of its class + 0.02). The +0.02 bounds every weight
    # at 1 / 0.02 = 50 so very rare classes do not dominate sampling.
    cls_prop = train_dataset.cls_stat / train_dataset.cls_stat.sum()
    cls_weights = 1 / (cls_prop + 0.02)
    cls_indices = range(len(cls_prop))
    str_index = 'Class idx  ' + ''.join('| %5d ' % i for i in cls_indices)
    str_prop = 'Proportion ' + ''.join(
        '| %5.2f ' % cls_prop[i] for i in cls_indices)
    str_weigh = 'Weights    ' + ''.join(
        '| %5.2f ' % cls_weights[i] for i in cls_indices)
    logger.info('Training Data Analysis:')
    logger.info(str_index)
    logger.info(str_prop)
    logger.info(str_weigh)
    # One weight per training sample, looked up from its class index.
    sample_list_of_weights = [
        cls_weights[cls_idx] for cls_idx in train_dataset.sample_list_of_cls
    ]
    train_sampler = torch.utils.data.WeightedRandomSampler(
        sample_list_of_weights, len(train_dataset))

    # The weighted sampler replaces shuffling (the two are mutually
    # exclusive in DataLoader).
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        sampler=train_sampler,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        # Saved from the DataParallel-wrapped model, so load into it as-is.
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)
    logger.info("=> Start training...")

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        # NOTE(review): PyTorch >= 1.1 documents optimizer.step() before
        # lr_scheduler.step(); stepping first skips the first schedule value.
        # Kept to preserve the existing LR schedule -- confirm intent.
        lr_scheduler.step()

        # train for one epoch
        train(cfg, train_loader, train_dataset, model, criterion, optimizer,
              epoch, final_output_dir, tb_log_dir, writer_dict)

        # Per-epoch snapshot of the unwrapped weights.
        torch.save(model.module.state_dict(),
                   os.path.join(final_output_dir, 'epoch-%d.pth' % epoch))
        # evaluate on validation set
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)
        logger.info('# Best AP {}'.format(best_perf))

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()