Example #1
File: main.py Project: yilonghe/A2Net
def main():
    args = parse_args()
    update_config(args.cfg)
    # create output directory
    if cfg.BASIC.CREATE_OUTPUT_DIR:
        out_dir = os.path.join(cfg.BASIC.ROOT_DIR, cfg.TRAIN.MODEL_DIR)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
    # copy config file
    if cfg.BASIC.BACKUP_CODES:
        backup_dir = os.path.join(cfg.BASIC.ROOT_DIR, cfg.TRAIN.MODEL_DIR,
                                  'code')
        backup_codes(cfg.BASIC.ROOT_DIR, backup_dir, cfg.BASIC.BACKUP_LISTS)
    fix_random_seed(cfg.BASIC.SEED)
    if cfg.BASIC.SHOW_CFG:
        pprint.pprint(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    cudnn.enabled = cfg.CUDNN.ENABLE

    # data loader
    train_dset = TALDataset(cfg, cfg.DATASET.TRAIN_SPLIT)
    train_loader = DataLoader(train_dset,
                              batch_size=cfg.TRAIN.BATCH_SIZE,
                              shuffle=True,
                              drop_last=False,
                              num_workers=cfg.BASIC.WORKERS,
                              pin_memory=cfg.DATASET.PIN_MEMORY)
    val_dset = TALDataset(cfg, cfg.DATASET.VAL_SPLIT)
    val_loader = DataLoader(val_dset,
                            batch_size=cfg.TEST.BATCH_SIZE,
                            shuffle=False,
                            drop_last=False,
                            num_workers=cfg.BASIC.WORKERS,
                            pin_memory=cfg.DATASET.PIN_MEMORY)

    model = LocNet(cfg)
    model.apply(weight_init)
    model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=cfg.TRAIN.LR)
    for epoch in range(cfg.TRAIN.BEGIN_EPOCH, cfg.TRAIN.END_EPOCH + 1):
        loss_train = train(cfg, train_loader, model, optimizer)
        print('epoch %d: loss: %f' % (epoch, loss_train))
        with open(os.path.join(cfg.BASIC.ROOT_DIR, cfg.TRAIN.LOG_FILE),
                  'a') as f:
            f.write("epoch %d, loss: %.4f\n" % (epoch, loss_train))

        # decay lr
        if epoch in cfg.TRAIN.LR_DECAY_EPOCHS:
            decay_lr(optimizer, factor=cfg.TRAIN.LR_DECAY_FACTOR)

        if epoch in cfg.TEST.EVAL_INTERVAL:
            save_model(cfg, epoch=epoch, model=model, optimizer=optimizer)
            out_df_ab, out_df_af = evaluation(val_loader, model, epoch, cfg)
            out_df_list = [out_df_ab, out_df_af]
            final_result_process(out_df_list, epoch, cfg, flag=0)
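Example #1 calls a decay_lr helper that is not shown. A minimal sketch of what such a helper typically does, assuming it simply scales each parameter group's learning rate by the given factor (the name and signature come from the call above; the body is an assumption):

import torch.optim as optim


def decay_lr(optimizer: optim.Optimizer, factor: float = 0.1) -> None:
    # Assumed implementation: multiply every param group's lr by `factor`.
    # The project's real helper may differ.
    for param_group in optimizer.param_groups:
        param_group['lr'] *= factor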
Example #2
def train_mlp(train_dataset, val_dataset):
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True
    input_size = train_dataset.input_size()
    config = {'input_size': input_size, 'output_size': 12, 'print_freq': 100}
    model = HaptMlpModel(config).cuda()

    weight_tensor = [1 for _ in range(config['output_size'] - 2)] + [10, 10]
    criterion = nn.CrossEntropyLoss(torch.Tensor(weight_tensor)).cuda()
    # criterion = nn.CrossEntropyLoss(torch.Tensor([0.1, 0.1, 0.1, 0.1])).cuda()
    optimizer = torch.optim.Adam(
        # model.parameters(),
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=learning_rate)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=num_workers,
                                             pin_memory=True)
    summary = SummaryWriter()
    best_val_acc = 0
    best_val_f1 = 0
    for epoch in range(total_epoch):
        train(config, train_loader, model, criterion, optimizer, epoch,
              summary)
        val_acc, val_f1 = validate(config, val_loader, model, criterion, epoch,
                                   summary)
        save_checkpoint(model, epoch, optimizer, './checkpoints',
                        'checkpoint_mlp.pth.tar')
        if val_f1 > best_val_f1:
            save_checkpoint(model, epoch, optimizer, './checkpoints',
                            'best_mlp.pth.tar')
            best_val_acc = val_acc
            best_val_f1 = val_f1
    return best_val_acc, best_val_f1
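Example #2 hands nn.CrossEntropyLoss a per-class weight tensor so the last two classes count ten times as much in the loss. A self-contained illustration of that mechanism (batch and class sizes here are toy values):

import torch
import torch.nn as nn

# 12 classes; the last two are up-weighted 10x, as in train_mlp above.
weights = torch.tensor([1.0] * 10 + [10.0, 10.0])
criterion = nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(4, 12)             # batch of 4, 12 classes
targets = torch.tensor([0, 5, 10, 11])  # two samples hit the rare classes
loss = criterion(logits, targets)       # errors on classes 10/11 cost 10x more
print(loss.item())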
Example #3
def main():
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model_builder = importlib.import_module("models." +
                                            cfg.MODEL.NAME).get_fovea_net
    model = model_builder(cfg, is_train=True)

    # added by xiaofeng to load parameters
    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)

    # copy model file -- commented out by xiaofeng
    # this_dir = os.path.dirname(__file__)
    # shutil.copy2(os.path.join(this_dir, '../models', cfg.MODEL.NAME + '.py'), final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # define loss function (criterion) and optimizer
    criterion = HybridLoss(roi_weight=cfg.LOSS.ROI_WEIGHT,
                           regress_weight=cfg.LOSS.REGRESS_WEIGHT,
                           use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT,
                           hrnet_only=cfg.TRAIN.HRNET_ONLY).cuda()

    # Data loading code
    # normalize = transforms.Normalize(
    #     mean=[0.134, 0.207, 0.330], std=[0.127, 0.160, 0.239]
    # )
    # train_dataset = importlib.import_module('dataset.'+cfg.DATASET.DATASET).Dataset(
    #     cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
    #     transforms.Compose([
    #         transforms.ToTensor(),
    #         normalize,
    #     ])
    # )
    # valid_dataset = importlib.import_module('dataset.'+cfg.DATASET.DATASET).Dataset(
    #     cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
    #     transforms.Compose([
    #         transforms.ToTensor(),
    #         normalize,
    #     ])
    # )
    #
    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset,
    #     batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
    #     shuffle=cfg.TRAIN.SHUFFLE,
    #     num_workers=cfg.WORKERS,
    #     pin_memory=cfg.PIN_MEMORY
    # )
    # valid_loader = torch.utils.data.DataLoader(
    #     valid_dataset,
    #     batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
    #     shuffle=False,
    #     num_workers=cfg.WORKERS,
    #     pin_memory=cfg.PIN_MEMORY
    # )

    db_trains = []
    db_vals = []
    final_full_test = cfg.TRAIN.FULL_DATA
    normalize_1 = transforms.Normalize(mean=[0.282, 0.168, 0.084],
                                       std=[0.189, 0.110, 0.062])
    train_dataset_1 = importlib.import_module('dataset.' +
                                              cfg.DATASET.DATASET).Dataset(
                                                  cfg, cfg.DATASET.ROOT,
                                                  cfg.DATASET.TRAIN_SET_1,
                                                  True,
                                                  transforms.Compose([
                                                      transforms.ToTensor(),
                                                      normalize_1,
                                                  ]))
    db_trains.append(train_dataset_1)

    normalize_2 = transforms.Normalize(mean=[0.409, 0.270, 0.215],
                                       std=[0.288, 0.203, 0.160])
    train_dataset_2 = importlib.import_module('dataset.' +
                                              cfg.DATASET.DATASET).Dataset(
                                                  cfg, cfg.DATASET.ROOT,
                                                  cfg.DATASET.TRAIN_SET_2,
                                                  True,
                                                  transforms.Compose([
                                                      transforms.ToTensor(),
                                                      normalize_2,
                                                  ]))
    db_trains.append(train_dataset_2)

    if final_full_test is True:
        normalize_3 = transforms.Normalize(mean=[0.404, 0.271, 0.222],
                                           std=[0.284, 0.202, 0.163])
        train_dataset_3 = importlib.import_module(
            'dataset.' + cfg.DATASET.DATASET).Dataset(
                cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, True,
                transforms.Compose([
                    transforms.ToTensor(),
                    normalize_3,
                ]))
        db_trains.append(train_dataset_3)

    train_dataset = ConcatDataset(db_trains)
    logger.info("Combined Dataset: Total {} images".format(len(train_dataset)))

    train_batch_size = cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=train_batch_size,
                                               shuffle=cfg.TRAIN.SHUFFLE,
                                               num_workers=cfg.WORKERS,
                                               pin_memory=cfg.PIN_MEMORY)

    normalize = transforms.Normalize(mean=[0.404, 0.271, 0.222],
                                     std=[0.284, 0.202, 0.163])
    val_dataset_1 = importlib.import_module('dataset.' +
                                            cfg.DATASET.DATASET).Dataset(
                                                cfg, cfg.DATASET.ROOT,
                                                cfg.DATASET.TEST_SET, False,
                                                transforms.Compose([
                                                    transforms.ToTensor(),
                                                    normalize,
                                                ]))
    db_vals.append(val_dataset_1)

    if final_full_test is True:
        normalize_1 = transforms.Normalize(mean=[0.282, 0.168, 0.084],
                                           std=[0.189, 0.110, 0.062])
        val_dataset_2 = importlib.import_module('dataset.' +
                                                cfg.DATASET.DATASET).Dataset(
                                                    cfg, cfg.DATASET.ROOT,
                                                    cfg.DATASET.TRAIN_SET_1,
                                                    False,
                                                    transforms.Compose([
                                                        transforms.ToTensor(),
                                                        normalize_1,
                                                    ]))
        db_vals.append(val_dataset_2)

        normalize_2 = transforms.Normalize(mean=[0.409, 0.270, 0.215],
                                           std=[0.288, 0.203, 0.160])
        val_dataset_3 = importlib.import_module('dataset.' +
                                                cfg.DATASET.DATASET).Dataset(
                                                    cfg, cfg.DATASET.ROOT,
                                                    cfg.DATASET.TRAIN_SET_2,
                                                    False,
                                                    transforms.Compose([
                                                        transforms.ToTensor(),
                                                        normalize_2,
                                                    ]))
        db_vals.append(val_dataset_3)

    valid_dataset = ConcatDataset(db_vals)

    logger.info("Val Dataset: Total {} images".format(len(valid_dataset)))

    test_batch_size = cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    logger.info("Train len: {}, batch_size: {}; Test len: {}, batch_size: {}" \
                .format(len(train_loader), train_batch_size, len(valid_loader), test_batch_size))

    best_metric = 1e6
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH

    if cfg.TEST.MODEL_FILE:
        checkpoint_file = cfg.TEST.MODEL_FILE
    else:
        checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        # begin_epoch = checkpoint['epoch']
        begin_epoch = 0  # changed by xiaofeng
        best_metric = checkpoint['metric']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    if cfg.TRAIN.LR_EXP:
        # lr = base_lr * gamma ** epoch
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                              cfg.TRAIN.GAMMA1,
                                                              last_epoch=-1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            cfg.TRAIN.LR_STEP,
            cfg.TRAIN.LR_FACTOR,
            last_epoch=last_epoch)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        start_time = timer()

        lr_scheduler.step()

        # evaluate on validation set
        # lr_metric, hr_metric, final_metric = validate(
        #     cfg, valid_loader, valid_dataset, model, criterion,
        #     final_output_dir, tb_log_dir, writer_dict, db_vals
        # )
        # print("validation before training spent time:")
        # timer(start_time)  # timing ends here for "start_time" variable

        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        print("epoch %d train spent time:" % (epoch))
        train_time = timer(
            start_time)  # timing ends here for "start_time" variable

        # if epoch >= int(cfg.TRAIN.END_EPOCH/10):
        # evaluate on validation set
        lr_metric, hr_metric, final_metric = validate(
            cfg, valid_loader, valid_dataset, model, criterion,
            final_output_dir, tb_log_dir, writer_dict, db_vals)

        print("validation spent time:")
        val_time = timer(
            train_time)  # timing ends here for "train_time" variable

        min_metric = min(lr_metric, hr_metric, final_metric)
        if min_metric <= best_metric:
            best_metric = min_metric
            best_model = True
            logger.info('=> epoch [{}] best model result: {}'.format(
                epoch, best_metric))
        else:
            best_model = False

        # changed by xiaofeng
        if best_model is True:
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            # transfer the model to CPU before saving to fix unstable bug:
            # github.com/pytorch/pytorch/issues/10577

            model = model.cpu()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': cfg.MODEL.NAME,
                    'state_dict': model.state_dict(),
                    'best_state_dict': model.module.state_dict(),
                    'metric': final_metric,
                    'optimizer': optimizer.state_dict(),
                }, best_model, final_output_dir)
            model = model.cuda()

            print("saving spent time:")
            end_time = timer(
                val_time)  # timing ends here for "val_time" variable
        elif (epoch % 60 == 0) and (epoch != 0):
            logger.info('=> saving epoch {} checkpoint to {}'.format(
                epoch, final_output_dir))
            # transfer the model to CPU before saving to fix unstable bug:
            # github.com/pytorch/pytorch/issues/10577

            time_str = time.strftime('%Y-%m-%d-%H-%M')
            if cfg.TRAIN.HRNET_ONLY:
                checkpoint_filename = 'checkpoint_HRNET_epoch%d_%s.pth' % (
                    epoch, time_str)
            else:
                checkpoint_filename = 'checkpoint_Hybrid_epoch%d_%s.pth' % (
                    epoch, time_str)
            model = model.cpu()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': cfg.MODEL.NAME,
                    'state_dict': model.state_dict(),
                    'best_state_dict': model.module.state_dict(),
                    'metric': final_metric,
                    'optimizer': optimizer.state_dict(),
                }, best_model, final_output_dir, checkpoint_filename)
            model = model.cuda()

    # changed by xiaofeng
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    if cfg.TRAIN.HRNET_ONLY:
        model_name = 'final_state_HRNET_%s.pth' % (time_str)
    else:
        model_name = 'final_state_Hybrid_%s.pth' % (time_str)

    final_model_state_file = os.path.join(final_output_dir, model_name)
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()

    # save a final checkpoint
    model = model.cpu()
    save_checkpoint(
        {
            'epoch': epoch + 1,
            'model': cfg.MODEL.NAME,
            'state_dict': model.state_dict(),
            'best_state_dict': model.module.state_dict(),
            'metric': final_metric,
            'optimizer': optimizer.state_dict(),
        }, best_model, final_output_dir, "checkpoint_final_state.pth")
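The checkpoint blocks in Example #3 move the model to the CPU before saving and back to the GPU afterwards, citing github.com/pytorch/pytorch/issues/10577. Stripped to its essentials, the pattern looks like this (plain torch.save stands in for the project's save_checkpoint wrapper):

import torch


def save_cpu_checkpoint(model: torch.nn.Module, path: str) -> None:
    # Serialize from CPU to sidestep the instability described in the
    # issue above, then move the model back for further training.
    model = model.cpu()
    torch.save(model.state_dict(), path)
    model.cuda()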
Example #4
File: train.py Project: tgkyrie/Face_Xray
def main():
    args = parse_args()

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    nnb = models.nnb.get_nnb(config)  # parameters not frozen  TODO: add parameters to the optimizer mid-training
    # nnb = models.ae.get_ae()
    # nnb = models.fcn.get_fcn(config)
    # during training, make nnc's softmax a no-op
    nnc = models.nnc.get_nnc(config)

    writer_dict = {
        'writer': SummaryWriter(
            log_dir='./output/facexray/tensorboard/tensorboard' + '_' +
            datetime.now().strftime('%Y%m%d_%H%M%S')),
        'train_global_steps': 0,
        'valid_global_steps': 0,
        'test_global_steps': 0,
    }

    # log init
    save_dir = os.path.join('./output/facexray/log/log' + '_' +
                            datetime.now().strftime('%Y%m%d_%H%M%S'))
    if os.path.exists(save_dir):
        raise NameError('model dir exists!')
    os.makedirs(save_dir)
    logging = init_log(save_dir)
    _print = logging.info

    gpus = list(config.GPUS)
    nnb = torch.nn.DataParallel(nnb, device_ids=[0]).cuda()
    nnc = torch.nn.DataParallel(nnc, device_ids=[0]).cuda()

    # define loss function (criterion) and optimizer
    criterion = Loss()

    # some hyper-parameters
    # initialize the optimizer; train all parameters except nnb's original hrnet parameters
    optimizer = get_optimizer(config, [nnb, nnc])  # TODO: for now, just initialize everything
    NNB_GRAD = False
    nnb.module.pretrained_grad(NNB_GRAD)
    last_iter = config.TRAIN.BEGIN_ITER
    best_perf = 0.0

    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_iter - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       config.TRAIN.LR_STEP,
                                                       config.TRAIN.LR_FACTOR,
                                                       last_iter - 1)

    # Data loading code
    # the transform cannot handle other input shapes yet; it should produce [256, 256, 3]


    # train_dataset = eval('dataset.' + config.DATASET.TRAIN_SET + '.' + config.DATASET.TRAIN_SET)(
    #     root=config.DATASET.TRAIN_ROOT, list_name=config.DATASET.TRAIN_LIST, mode='train', Transform='simple')

    # valid_dataset = eval('dataset.' + config.DATASET.EVAL_SET + '.' + config.DATASET.EVAL_SET)(
    #     root=config.DATASET.VALID_ROOT, list_name=config.DATASET.VALID_LIST, mode='valid', Transform='simple')

    # test_dataset = eval('dataset.' + config.DATASET.EVAL_SET + '.' + config.DATASET.EVAL_SET)(
    #     root=config.DATASET.TEST_ROOT, list_name=config.DATASET.TEST_LIST, mode='test', Transform='simple')
    train_dataset = mydataset(datapath + 'train15k', datapath + 'origin5k')
    valid_dataset = mydataset(datapath + 'generatorBlendedRandomGaussian',
                              datapath + 'origin')
    test_dataset = mydataset(datapath + 'test1k', datapath + 'test_o500')
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY)

    def cycle(loader):
        while True:
            for x in loader:
                yield x
            op = getattr(loader.dataset, "generate", None)
            if callable(op):
                op()

    train_generator = iter(cycle(train_loader))

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY)

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY)

    for iteration in range(last_iter, config.TRAIN.END_ITER,
                           config.TRAIN.EVAL_ITER):

        # freeze the original hrnet layer parameters for the first 50000 iterations; train all parameters afterwards
        if not NNB_GRAD and iteration >= 50000:
            if len(gpus) > 0:
                nnb.module.pretrained_grad(True)
            else:
                nnb.pretrained_grad(True)
            NNB_GRAD = True

        # train for one epoch
        train(config,
              train_generator,
              nnb,
              nnc,
              criterion,
              optimizer,
              iteration,
              writer_dict,
              _print,
              lr_scheduler=lr_scheduler)
        # evaluate on validation set
        perf_indicator = validate(config, valid_loader, nnb, nnc, criterion,
                                  writer_dict, _print)
        test(config, test_loader, nnb, nnc, criterion, writer_dict, _print)

        # save the model with the highest accuracy so far
        # if perf_indicator > best_perf:
        #    best_perf = perf_indicator
        #    torch.save(model.module.state_dict(), './output/BI_dataset/bestfaceXray_'+str(best_perf)+'.pth')
        #    _print('[Save best model] ./output/BI_dataset/bestfaceXray_'+str(best_perf)+'.pth\t')

        iter_now = iteration + config.TRAIN.EVAL_ITER
        if (iteration // config.TRAIN.EVAL_ITER) % 2 == 0:
            torch.save(
                nnb.module.state_dict(),
                './output/BI_dataset2/faceXray_' + str(iter_now) + '.pth')
            torch.save(nnc.module.state_dict(),
                       './output/BI_dataset2/nnc' + str(iter_now) + '.pth')
            _print('[Save model] ./output/BI_dataset2/faceXray_' +
                   str(iter_now) + '.pth\t')
            _print('[Save the last model] ./output/BI_dataset2/nnc' +
                   str(iter_now) + '.pth\t')
        # lr_scheduler.step()

    # the final model
    torch.save(nnb.module.state_dict(), './output/BI_dataset/faceXray.pth')
    torch.save(nnc.module.state_dict(), './output/BI_dataset/nnc.pth')
    _print('[Save the last model] ./output/BI_dataset/faceXray.pth\t')
    _print('[Save the last model] ./output/BI_dataset/nnc.pth\t')
    writer_dict['writer'].close()
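Example #4 freezes the pretrained HRNet layers with nnb.module.pretrained_grad(False) and unfreezes them after 50000 iterations. pretrained_grad is project-specific; a generic version of the same idea using requires_grad directly might look like this (the 'hrnet' name prefix is a hypothetical naming convention):

import torch.nn as nn


def set_pretrained_grad(model: nn.Module, requires_grad: bool,
                        prefix: str = 'hrnet') -> None:
    # Freeze or unfreeze every parameter whose name starts with `prefix`.
    # Generic stand-in for the project's pretrained_grad().
    for name, param in model.named_parameters():
        if name.startswith(prefix):
            param.requires_grad = requires_grad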
Example #5
def main_worker(gpu, ngpus_per_node, args, final_output_dir, tb_log_dir):

    args.gpu = gpu
    args.rank = args.rank * ngpus_per_node + gpu
    print('Init process group: dist_url: {}, world_size: {}, rank: {}'.format(cfg.DIST_URL, args.world_size, args.rank))
    dist.init_process_group(backend=cfg.DIST_BACKEND, init_method=cfg.DIST_URL, world_size=args.world_size, rank=args.rank)

    update_config(cfg, args)

    # setup logger
    logger, _ = setup_logger(final_output_dir, args.rank, 'train')

    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(cfg, is_train=True)
    logger.info(get_model_summary(model, torch.zeros(1, 3, *cfg.MODEL.IMAGE_SIZE)))

    # copy model file
    if not cfg.MULTIPROCESSING_DISTRIBUTED or (cfg.MULTIPROCESSING_DISTRIBUTED and args.rank % ngpus_per_node == 0):
        this_dir = os.path.dirname(__file__)
        shutil.copy2(os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'), final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    if not cfg.MULTIPROCESSING_DISTRIBUTED or (cfg.MULTIPROCESSING_DISTRIBUTED and args.rank % ngpus_per_node == 0):
        dump_input = torch.rand((1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
        writer_dict['writer'].add_graph(model, (dump_input, ))
        # logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    if cfg.MODEL.SYNC_BN:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    
    torch.cuda.set_device(args.gpu)
    model.cuda(args.gpu)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda(args.gpu)

    # Data loading code
    train_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    )
    valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    )
    
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=(train_sampler is None),
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY,
        sampler=train_sampler
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY
    )
    logger.info(train_loader.dataset)

    best_perf = -1
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')
    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
        last_epoch=last_epoch)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        
        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        # In PyTorch 1.1.0 and later, you should call `lr_scheduler.step()` after `optimizer.step()`.
        lr_scheduler.step()

        # evaluate on validation set
        perf_indicator = validate(
            args, cfg, valid_loader, valid_dataset, model, criterion,
            final_output_dir, tb_log_dir, writer_dict
        )

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        if not cfg.MULTIPROCESSING_DISTRIBUTED or (
                cfg.MULTIPROCESSING_DISTRIBUTED
                and args.rank == 0
        ):
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint({
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(
        final_output_dir, 'final_state{}.pth.tar'.format(gpu)
    )

    logger.info('saving final model state to {}'.format(
        final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
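Example #5 is the standard single-machine DistributedDataParallel recipe: init_process_group, one device per rank, a DistributedSampler, and a DDP wrapper. Condensed to its essentials (the NCCL backend and env:// rendezvous are assumptions carried over from typical setups):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def setup_ddp(model, dataset, rank, world_size, batch_size):
    dist.init_process_group('nccl', init_method='env://',
                            world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)
    model = DDP(model.cuda(rank), device_ids=[rank])
    sampler = DistributedSampler(dataset)  # shards the data across ranks
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    # Remember to call sampler.set_epoch(epoch) at the top of each epoch.
    return model, loader, sampler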
Example #6
File: train.py Project: tgkyrie/Face_Xray
def main():
    args = parse_args()

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval('models.' + config.MODEL.NAME + '.get_nnb')(config)

    writer_dict = {
        'writer': SummaryWriter(log_dir='./output/facexray'),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    gpus = list(config.GPUS)
    model = torch.nn.DataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = Loss()

    optimizer = get_optimizer(config, model)

    last_epoch = config.TRAIN.BEGIN_EPOCH

    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       config.TRAIN.LR_STEP,
                                                       config.TRAIN.LR_FACTOR,
                                                       last_epoch - 1)

    # Data loading code
    # list_name is not annotated separately in the .yaml file
    # the transform cannot handle other input shapes yet; it should produce [256, 256, 3]
    train_dataset = eval('dataset.' + config.DATASET.DATASET + '.' +
                         config.DATASET.DATASET)(
                             config.DATASET.ROOT, config.DATASET.TRAIN_SET,
                             None, transforms.Compose([transforms.ToTensor()]))

    valid_dataset = eval('dataset.' + config.DATASET.DATASET + '.' +
                         config.DATASET.DATASET)(config.DATASET.ROOT,
                                                 config.DATASET.TEST_SET, None,
                                                 transforms.Compose(
                                                     [transforms.ToTensor()]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY)

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # freeze the original hrnet layer parameters for the first 50000 iterations; train all parameters afterwards
        if epoch == 150000:
            for k, v in model.named_parameters():
                v.requires_grad = True

        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              writer_dict)
        # evaluate on validation set
        validate(config, valid_loader, model, criterion, writer_dict)

    torch.save(model.module.state_dict(), './output/BI_dataset/faceXray.pth')
    writer_dict['writer'].close()
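Examples #4, #6, and #8 all choose a scheduler by type: a list-valued LR_STEP means milestones (MultiStepLR), a scalar means a fixed period (StepLR). The dispatch in isolation, with toy values; note that in PyTorch 1.1+ scheduler.step() belongs after optimizer.step(), not at the top of the epoch loop as in the older code above:

import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

lr_step = [30, 60, 90]  # toy config value; a scalar would mean a fixed period
if isinstance(lr_step, list):
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=lr_step, gamma=0.1)
else:
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=lr_step, gamma=0.1)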
Example #7
def main():
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(config)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)
    distributed = len(gpus) > 1
    device = torch.device('cuda:{}'.format(args.local_rank))

    # build model
    model = eval('models.' + config.MODEL.NAME + '.get_seg_model')(config)

    if args.local_rank == 0:
        # provide the summary of model
        dump_input = torch.rand(
            (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]))
        logger.info(get_model_summary(model.to(device), dump_input.to(device)))

        # copy model file
        this_dir = os.path.dirname(__file__)
        models_dst_dir = os.path.join(final_output_dir, 'models')
        if os.path.exists(models_dst_dir):
            shutil.rmtree(models_dst_dir)
        shutil.copytree(os.path.join(this_dir, '../lib/models'),
                        models_dst_dir)

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
        )

    # prepare data
    crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    train_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TRAIN_SET,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=config.TRAIN.MULTI_SCALE,
        flip=config.TRAIN.FLIP,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TRAIN.BASE_SIZE,
        crop_size=crop_size,
        downsample_rate=config.TRAIN.DOWNSAMPLERATE,
        scale_factor=config.TRAIN.SCALE_FACTOR)

    if distributed:
        train_sampler = DistributedSampler(train_dataset)
    else:
        train_sampler = None

    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE and train_sampler is None,
        num_workers=config.WORKERS,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler)

    if config.DATASET.EXTRA_TRAIN_SET:
        extra_train_dataset = eval('datasets.' + config.DATASET.DATASET)(
            root=config.DATASET.ROOT,
            list_path=config.DATASET.EXTRA_TRAIN_SET,
            num_samples=None,
            num_classes=config.DATASET.NUM_CLASSES,
            multi_scale=config.TRAIN.MULTI_SCALE,
            flip=config.TRAIN.FLIP,
            ignore_label=config.TRAIN.IGNORE_LABEL,
            base_size=config.TRAIN.BASE_SIZE,
            crop_size=crop_size,
            downsample_rate=config.TRAIN.DOWNSAMPLERATE,
            scale_factor=config.TRAIN.SCALE_FACTOR)

        if distributed:
            extra_train_sampler = DistributedSampler(extra_train_dataset)
        else:
            extra_train_sampler = None

        extra_trainloader = torch.utils.data.DataLoader(
            extra_train_dataset,
            batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
            shuffle=config.TRAIN.SHUFFLE and extra_train_sampler is None,
            num_workers=config.WORKERS,
            pin_memory=True,
            drop_last=True,
            sampler=extra_train_sampler)

    test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
    test_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=config.TEST.NUM_SAMPLES,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=test_size,
        center_crop_test=config.TEST.CENTER_CROP_TEST,
        downsample_rate=1)

    if distributed:
        test_sampler = DistributedSampler(test_dataset)
    else:
        test_sampler = None

    testloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True,
        sampler=test_sampler)

    # criterion
    if config.LOSS.USE_OHEM:
        criterion = OhemCrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                     thres=config.LOSS.OHEMTHRES,
                                     min_kept=config.LOSS.OHEMKEEP,
                                     weight=train_dataset.class_weights)
    else:
        criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                 weight=train_dataset.class_weights)

    model = FullModel(model, criterion)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(device)
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[args.local_rank],
                                                output_device=args.local_rank)

    # optimizer
    if config.TRAIN.OPTIMIZER == 'sgd':
        optimizer = torch.optim.SGD(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 model.parameters()),
                'lr': config.TRAIN.LR
            }],
            lr=config.TRAIN.LR,
            momentum=config.TRAIN.MOMENTUM,
            weight_decay=config.TRAIN.WD,
            nesterov=config.TRAIN.NESTEROV,
        )
    else:
        raise ValueError('Only Support SGD optimizer')

    epoch_iters = int(len(train_dataset) /
                      config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))
    best_mIoU = 0
    last_epoch = 0
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file,
                                    map_location=lambda storage, loc: storage)
            best_mIoU = checkpoint['best_mIoU']
            last_epoch = checkpoint['epoch']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

    start = timeit.default_timer()
    end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters
    extra_iters = config.TRAIN.EXTRA_EPOCH * epoch_iters

    for epoch in range(last_epoch, end_epoch):
        if distributed:
            train_sampler.set_epoch(epoch)
        if epoch >= config.TRAIN.END_EPOCH:
            train(config, epoch - config.TRAIN.END_EPOCH,
                  config.TRAIN.EXTRA_EPOCH, epoch_iters, config.TRAIN.EXTRA_LR,
                  extra_iters, extra_trainloader, optimizer, model,
                  writer_dict, device)
        else:
            train(config, epoch, config.TRAIN.END_EPOCH, epoch_iters,
                  config.TRAIN.LR, num_iters, trainloader, optimizer, model,
                  writer_dict, device)

        valid_loss, mean_IoU, IoU_array = validate(config, testloader, model,
                                                   writer_dict, device)

        if args.local_rank == 0:
            logger.info(
                '=> saving checkpoint to {}'.format(final_output_dir +
                                                    'checkpoint.pth.tar'))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'best_mIoU': best_mIoU,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(final_output_dir, 'checkpoint.pth.tar'))

            if mean_IoU > best_mIoU:
                best_mIoU = mean_IoU
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'best.pth'))
            msg = 'Loss: {:.3f}, MeanIU: {: 4.4f}, Best_mIoU: {: 4.4f}'.format(
                valid_loss, mean_IoU, best_mIoU)
            logging.info(msg)
            logging.info(IoU_array)

            if epoch == end_epoch - 1:
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'final_state.pth'))

                writer_dict['writer'].close()
                end = timeit.default_timer()
                logger.info('Hours: %d' % int((end - start) / 3600))
                logger.info('Done')
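Example #7 resumes with torch.load(..., map_location=lambda storage, loc: storage), which forces every tensor onto the CPU no matter where it was saved. The same effect in its more readable modern spelling (the path is a placeholder):

import torch

# Equivalent to the lambda above: load all tensors onto the CPU, which
# avoids device-mismatch errors when resuming on different hardware.
checkpoint = torch.load('checkpoint.pth.tar', map_location='cpu')
# model.module.load_state_dict(checkpoint['state_dict'])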
Example #8
def main():
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED
    dataset = config.DATASET.DATASET

    if dataset.startswith('imagenet'):
        model = eval('models.' + config.MODEL.NAME + '.get_cls_net')(config)
    else:
        model = eval('models.' + config.MODEL.NAME +
                     '.get_cls_net_cifar')(config)

    dump_input = torch.rand(
        (1, 3, config.MODEL.IMAGE_SIZE[1], config.MODEL.IMAGE_SIZE[0]))
    logger.info(get_model_summary(model, dump_input))
    # n_flops, n_params = measure_model(model, 32, 32, torch.zeros(1,3,32,32))
    # logger.info("param size = %fMB", n_params[0]/1e6)
    # logger.info("flops = %fM", n_flops[0]/1e6)

    # copy model file
    this_dir = os.path.dirname(__file__)
    models_dst_dir = os.path.join(final_output_dir, 'models')
    if os.path.exists(models_dst_dir):
        shutil.rmtree(models_dst_dir)
    shutil.copytree(os.path.join(this_dir, '../lib/models'), models_dst_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    gpus = list(config.GPUS)
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    optimizer = get_optimizer(config, model)

    best_perf = 0.0
    best_model = False
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_perf = checkpoint['perf']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
            best_model = True

    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       config.TRAIN.LR_STEP,
                                                       config.TRAIN.LR_FACTOR,
                                                       last_epoch - 1)

    # Data loading code
    if dataset.startswith('imagenet'):
        traindir = os.path.join(config.DATASET.ROOT, config.DATASET.TRAIN_SET)
        valdir = os.path.join(config.DATASET.ROOT, config.DATASET.TEST_SET)

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
            shuffle=True,
            num_workers=config.WORKERS,
            pin_memory=True)

        valid_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.Resize(int(config.MODEL.IMAGE_SIZE[0] / 0.875)),
                    transforms.CenterCrop(config.MODEL.IMAGE_SIZE[0]),
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
            shuffle=False,
            num_workers=config.WORKERS,
            pin_memory=True)

    else:
        CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
        CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
        if config.TRAIN.AUGMENT == 'autoaugment':
            print("==>use autoaugment")
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4, fill=128),
                transforms.RandomHorizontalFlip(),
                CIFAR10Policy(),
                transforms.ToTensor(),
                Cutout_v2(n_holes=1, length=16),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
            ])
        else:
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
            ])
            if config.TRAIN.AUGMENT == 'cutout':
                train_transform.transforms.append(Cutout(16))

        valid_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
        ])

        train_dataset = datasets.CIFAR10(root=config.DATASET.ROOT,
                                         train=True,
                                         download=True,
                                         transform=train_transform)
        valid_dataset = datasets.CIFAR10(root=config.DATASET.ROOT,
                                         train=False,
                                         download=True,
                                         transform=valid_transform)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
            shuffle=True,
            num_workers=config.WORKERS,
            pin_memory=True)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
            shuffle=False,
            num_workers=config.WORKERS,
            pin_memory=True)

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        lr_scheduler.step()
        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        # evaluate on validation set
        perf_indicator = validate(config, valid_loader, model, criterion,
                                  final_output_dir, tb_log_dir, writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': config.MODEL.NAME,
                'state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            },
            best_model,
            final_output_dir,
            filename='checkpoint.pth.tar')

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
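Example #8's Cutout and Cutout_v2 transforms come from the project. The standard technique just zeroes one random square patch per image; a minimal sketch, assuming it runs after ToTensor on a CHW tensor:

import torch


class Cutout:
    """Zero out one random length x length square per CHW image tensor."""

    def __init__(self, length: int = 16):
        self.length = length

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        _, h, w = img.shape
        y = torch.randint(h, (1,)).item()
        x = torch.randint(w, (1,)).item()
        y1, y2 = max(0, y - self.length // 2), min(h, y + self.length // 2)
        x1, x2 = max(0, x - self.length // 2), min(w, x + self.length // 2)
        img[:, y1:y2, x1:x2] = 0.0
        return img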
Example #9
def main():
    args = parse_args()
    reset_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, "train")

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval("models." + config.MODEL.NAME + ".get_pose_net")(
        config, is_train=True)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, "../lib/models", config.MODEL.NAME + ".py"),
        final_output_dir,
    )

    writer_dict = {
        "writer": SummaryWriter(log_dir=tb_log_dir),
        "train_global_steps": 0,
        "valid_global_steps": 0,
    }

    dump_input = torch.rand((
        config.TRAIN.BATCH_SIZE,
        3,
        config.MODEL.IMAGE_SIZE[1],
        config.MODEL.IMAGE_SIZE[0],
    ))
    writer_dict["writer"].add_graph(model, (dump_input, ), verbose=False)

    gpus = [int(i) for i in config.GPUS.split(",")]
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()

    optimizer = get_optimizer(config, model)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval("dataset." + config.DATASET.DATASET)(
        config,
        config.DATASET.ROOT,
        config.DATASET.TRAIN_SET,
        True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
    )
    valid_dataset = eval("dataset." + config.DATASET.DATASET)(
        config,
        config.DATASET.ROOT,
        config.DATASET.TEST_SET,
        False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True,
    )

    best_perf = 0.0
    best_model = False
    for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # train for one epoch
        train(
            config,
            train_loader,
            model,
            criterion,
            optimizer,
            epoch,
            final_output_dir,
            tb_log_dir,
            writer_dict,
        )

        # evaluate on validation set
        perf_indicator = validate(
            config,
            valid_loader,
            valid_dataset,
            model,
            criterion,
            final_output_dir,
            tb_log_dir,
            writer_dict,
        )

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info("=> saving checkpoint to {}".format(final_output_dir))
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "model": get_model_name(config),
                "state_dict": model.state_dict(),
                "perf": perf_indicator,
                "optimizer": optimizer.state_dict(),
            },
            best_model,
            final_output_dir,
        )

    final_model_state_file = os.path.join(final_output_dir,
                                          "final_state.pth.tar")
    logger.info(
        "saving final model state to {}".format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict["writer"].close()
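Almost every example calls save_checkpoint(states, is_best, output_dir, ...) without showing it. In HRNet-family codebases this is usually a thin wrapper over torch.save plus a copy of the best weights; a sketch under that assumption:

import os
import torch


def save_checkpoint(states: dict, is_best: bool, output_dir: str,
                    filename: str = 'checkpoint.pth.tar') -> None:
    # Assumed implementation matching the call sites above; the real
    # helper differs from project to project.
    torch.save(states, os.path.join(output_dir, filename))
    if is_best and 'state_dict' in states:
        torch.save(states['state_dict'],
                   os.path.join(output_dir, 'model_best.pth.tar'))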
Example #10
def main():
    args = parse_args()
    update_config(cfg, args)

    if args.prevModelDir and args.modelDir:
        # copy pre models for philly
        copy_prev_models(args.prevModelDir, args.modelDir)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    sparsity_criterion = None
    if cfg.DYNCONV.ENABLED:
        sparsity_criterion = dynconv.SparsityCriterion(cfg.DYNCONV.TARGET,
                                                       cfg.TRAIN.END_EPOCH,
                                                       cfg.DYNCONV.WEIGHT)

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    logger.info("=> checkpoint file '{}'".format(checkpoint_file))
    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))
    else:
        logger.info('=> Did not load checkpoint')

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)
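    # predefine epoch so the return value below is valid even if the training
    # range is empty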
    epoch = begin_epoch
    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
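        # NOTE: since PyTorch 1.1 the recommended order is to call
        # lr_scheduler.step() after the epoch's optimizer steps; the step-first
        # ordering here is kept from the original pre-1.1 style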
        lr_scheduler.step()

        # train for one epoch
        train(cfg, train_loader, model, criterion, sparsity_criterion,
              optimizer, epoch, final_output_dir, tb_log_dir, writer_dict)

        # evaluate on validation set
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, tb_log_dir,
                                  epoch, writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            },
            best_model,
            final_output_dir,
            epoch=epoch)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
    return {'epoch': epoch + 1, 'perf': best_perf}
Example #11
def main():
    args = parse_args()

    if args.seed > 0:
        import random
        print('Seeding with', args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(config)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)
    distributed = args.local_rank >= 0
    if distributed:
        device = torch.device('cuda:{}'.format(args.local_rank))
        print(device)
        torch.cuda.set_device(device)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
        )

    # build model
    model = eval('models.' + config.MODEL.NAME + '.get_seg_model')(config)

    # dump_input = torch.rand(
    #     (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    # )
    # logger.info(get_model_summary(model.cuda(), dump_input.cuda()))

    # copy model file
    if distributed and args.local_rank == 0:
        this_dir = os.path.dirname(__file__)
        models_dst_dir = os.path.join(final_output_dir, 'models')
        if os.path.exists(models_dst_dir):
            shutil.rmtree(models_dst_dir)
        shutil.copytree(os.path.join(this_dir, '../lib/models'),
                        models_dst_dir)

    if distributed:
        batch_size = config.TRAIN.BATCH_SIZE_PER_GPU
    else:
        batch_size = config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus)

    # prepare data
    crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    train_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TRAIN_SET,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=config.TRAIN.MULTI_SCALE,
        flip=config.TRAIN.FLIP,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TRAIN.BASE_SIZE,
        crop_size=crop_size,
        downsample_rate=config.TRAIN.DOWNSAMPLERATE,
        scale_factor=config.TRAIN.SCALE_FACTOR)

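    # get_sampler is assumed to return a DistributedSampler in distributed mode
    # and None otherwise, hence the shuffle guard below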
    train_sampler = get_sampler(train_dataset)
    trainloader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=batch_size,
                                              shuffle=config.TRAIN.SHUFFLE
                                              and train_sampler is None,
                                              num_workers=config.WORKERS,
                                              pin_memory=True,
                                              drop_last=True,
                                              sampler=train_sampler)

    extra_epoch_iters = 0
    if config.DATASET.EXTRA_TRAIN_SET:
        extra_train_dataset = eval('datasets.' + config.DATASET.DATASET)(
            root=config.DATASET.ROOT,
            list_path=config.DATASET.EXTRA_TRAIN_SET,
            num_samples=None,
            num_classes=config.DATASET.NUM_CLASSES,
            multi_scale=config.TRAIN.MULTI_SCALE,
            flip=config.TRAIN.FLIP,
            ignore_label=config.TRAIN.IGNORE_LABEL,
            base_size=config.TRAIN.BASE_SIZE,
            crop_size=crop_size,
            downsample_rate=config.TRAIN.DOWNSAMPLERATE,
            scale_factor=config.TRAIN.SCALE_FACTOR)
        extra_train_sampler = get_sampler(extra_train_dataset)
        extra_trainloader = torch.utils.data.DataLoader(
            extra_train_dataset,
            batch_size=batch_size,
            shuffle=config.TRAIN.SHUFFLE and extra_train_sampler is None,
            num_workers=config.WORKERS,
            pin_memory=True,
            drop_last=True,
            sampler=extra_train_sampler)
        extra_epoch_iters = int(
            len(extra_train_dataset) / config.TRAIN.BATCH_SIZE_PER_GPU /
            len(gpus))

    test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
    test_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=config.TEST.NUM_SAMPLES,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=test_size,
        downsample_rate=1)

    test_sampler = get_sampler(test_dataset)
    testloader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=config.WORKERS,
                                             pin_memory=True,
                                             sampler=test_sampler)

    # criterion
    if config.LOSS.USE_OHEM:
        criterion = OhemCrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                     thres=config.LOSS.OHEMTHRES,
                                     min_kept=config.LOSS.OHEMKEEP,
                                     weight=train_dataset.class_weights)
    else:
        criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                 weight=train_dataset.class_weights)

    model = FullModel(model, criterion)
    if distributed:
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            find_unused_parameters=True,
            device_ids=[args.local_rank],
            output_device=args.local_rank)
    else:
        model = nn.DataParallel(model, device_ids=gpus).cuda()

    # optimizer
    if config.TRAIN.OPTIMIZER == 'sgd':

        params_dict = dict(model.named_parameters())
        if config.TRAIN.NONBACKBONE_KEYWORDS:
            bb_lr = []
            nbb_lr = []
            nbb_keys = set()
            for k, param in params_dict.items():
                if any(part in k
                       for part in config.TRAIN.NONBACKBONE_KEYWORDS):
                    nbb_lr.append(param)
                    nbb_keys.add(k)
                else:
                    bb_lr.append(param)
            print(nbb_keys)
            params = [{
                'params': bb_lr,
                'lr': config.TRAIN.LR
            }, {
                'params': nbb_lr,
                'lr': config.TRAIN.LR * config.TRAIN.NONBACKBONE_MULT
            }]
        else:
            params = [{
                'params': list(params_dict.values()),
                'lr': config.TRAIN.LR
            }]

        optimizer = torch.optim.SGD(
            params,
            lr=config.TRAIN.LR,
            momentum=config.TRAIN.MOMENTUM,
            weight_decay=config.TRAIN.WD,
            nesterov=config.TRAIN.NESTEROV,
        )
    else:
        raise ValueError('Only the SGD optimizer is supported')

    epoch_iters = int(
        len(train_dataset) / config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))

    best_mIoU = 0
    last_epoch = 0
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file,
                                    map_location={'cuda:0': 'cpu'})
            best_mIoU = checkpoint['best_mIoU']
            last_epoch = checkpoint['epoch']
            state_dict = checkpoint['state_dict']

            # the checkpoint stores FullModel weights, so keep only the wrapped
            # model's keys and strip their 'model.' prefix
            model.module.model.load_state_dict({
                k.replace('model.', ''): v
                for k, v in state_dict.items()
                if k.startswith('model.')
            })
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        if distributed:
            torch.distributed.barrier()

    start = timeit.default_timer()
    end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters
    extra_iters = config.TRAIN.EXTRA_EPOCH * extra_epoch_iters

    for epoch in range(last_epoch, end_epoch):

        current_trainloader = extra_trainloader if epoch >= config.TRAIN.END_EPOCH else trainloader
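        # a DistributedSampler shuffles deterministically from the epoch
        # number, so it must be told the current epoch to get a fresh ordering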
        if current_trainloader.sampler is not None and hasattr(
                current_trainloader.sampler, 'set_epoch'):
            current_trainloader.sampler.set_epoch(epoch)

        # valid_loss, mean_IoU, IoU_array = validate(config,
        #             testloader, model, writer_dict)

        if epoch >= config.TRAIN.END_EPOCH:
            train(config, epoch - config.TRAIN.END_EPOCH,
                  config.TRAIN.EXTRA_EPOCH, extra_epoch_iters,
                  config.TRAIN.EXTRA_LR, extra_iters, extra_trainloader,
                  optimizer, model, writer_dict)
        else:
            train(config, epoch, config.TRAIN.END_EPOCH, epoch_iters,
                  config.TRAIN.LR, num_iters, trainloader, optimizer, model,
                  writer_dict)

        valid_loss, mean_IoU, IoU_array = validate(config, testloader, model,
                                                   writer_dict)

        if args.local_rank <= 0:
            logger.info('=> saving checkpoint to {}'.format(
                os.path.join(final_output_dir, 'checkpoint.pth.tar')))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'best_mIoU': best_mIoU,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(final_output_dir, 'checkpoint.pth.tar'))
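            # note: the 'best_mIoU' stored above is the value from before this
            # epoch's comparison; the checkpoint catches up on the next save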
            if mean_IoU > best_mIoU:
                best_mIoU = mean_IoU
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'best.pth'))
            msg = 'Loss: {:.3f}, MeanIU: {: 4.4f}, Best_mIoU: {: 4.4f}'.format(
                valid_loss, mean_IoU, best_mIoU)
            logger.info(msg)
            logger.info(IoU_array)

    if args.local_rank <= 0:

        torch.save(model.module.state_dict(),
                   os.path.join(final_output_dir, 'final_state.pth'))

        writer_dict['writer'].close()
        end = timeit.default_timer()
        logger.info('Hours: %d' % int((end - start) / 3600))
        logger.info('Done')
Example #12
def main():
    args = parse_args()
    reset_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval('models.' + config.MODEL.NAME + '.get_pose_net')(
        config, is_train=True)
    logger.info(">>> total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))
    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', config.MODEL.NAME + '.py'),
        final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (config.TRAIN.BATCH_SIZE, 3, config.MODEL.IMAGE_SIZE[1],
         config.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ), verbose=False)
    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()

    optimizer = get_optimizer(config, model)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + config.DATASET.DATASET)(
        config, config.DATASET.ROOT, config.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + config.DATASET.DATASET)(
        config, config.DATASET.ROOT, config.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    best_perf = 0.0
    best_model = False
    for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # train for one epoch
        """metabatch: 
        args: 1.dataset_num 2.batchsize 3.total_epoch"""
        #####################   METABATCH  #####################################
        dataset_num = len(train_dataset)
        batch_size = config.TRAIN.BATCH_SIZE
        total_epoch = config.TRAIN.END_EPOCH
        logger.info('dataset_size={}, batch_size={}, total_epoch={}'.format(
            dataset_num, batch_size, total_epoch))
        SEU_YS = MetaData_Container(dataset_num, batch_size, total_epoch)
        #########################################################################
        train(config, SEU_YS, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        SEU_YS.Output_CSV_Table()  # write the per-epoch table to CSV and print it

        #########################################################################
        # evaluate on validation set
        perf_indicator = validate(config, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': get_model_name(config),
                'state_dict': model.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Example #13
def main():
    args = parse_args()
    update_config(cfg, args)

    if args.prevModelDir and args.modelDir:
        # copy pre models for philly
        copy_prev_models(args.prevModelDir, args.modelDir)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    #logger.info(pprint.pformat(args))
    #logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

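    # dump_input mirrors the expected number of input channels for each warping
    # mode; note that it is built here but never used later in this script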
    if cfg.MODEL.USE_WARPING_TRAIN:
        if cfg.MODEL.USE_GT_INPUT_TRAIN:
            dump_input = torch.rand(
                (1, 23, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
        else:
            dump_input = torch.rand(
                (1, 6, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    else:
        dump_input = torch.rand(
            (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        #logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        #logger.info("=> loaded checkpoint '{}' (epoch {})".format(
        #    checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)

    ### importing train/validate functions
    if not cfg.MODEL.SPATIOTEMPORAL_POSE_AGGREGATION:
        from core.function import train
        from core.function import validate
    else:
        from core.function_PoseAgg import train
        from core.function_PoseAgg import validate
    ####

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        # evaluate on validation set
        if epoch == cfg.TRAIN.END_EPOCH - 1:
            if cfg.MODEL.EVALUATE:
                perf_indicator = validate(cfg, valid_loader, valid_dataset,
                                          model, criterion, final_output_dir,
                                          tb_log_dir, writer_dict)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Example #14
def main():
    args = parse_args()
    update_config(cfg, args)

    if args.prevModelDir and args.modelDir:
        # copy pre models for philly
        copy_prev_models(args.prevModelDir, args.modelDir)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ))

    logger.info(get_model_summary(model, dump_input))

    model = torch.nn.DataParallel(model, device_ids=[0, 1]).cuda()

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    # if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
    #     logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
    #     checkpoint = torch.load(checkpoint_file)
    #     begin_epoch = checkpoint['epoch']
    #     best_perf = checkpoint['perf']
    #     last_epoch = checkpoint['epoch']
    #     model.load_state_dict(checkpoint['state_dict'])
    #
    #     optimizer.load_state_dict(checkpoint['optimizer'])
    #     logger.info("=> loaded checkpoint '{}' (epoch {})".format(
    #         checkpoint_file, checkpoint['epoch']))

    # checkpoint = torch.load('output/jd/pose_hrnet/crop_face/checkpoint.pth')
    # model.load_state_dict(checkpoint['state_dict'])

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        # evaluate on validation set
        # perf_indicator = validate(
        #     cfg, valid_loader, valid_dataset, model, criterion,
        #     final_output_dir, tb_log_dir, writer_dict
        # )
        #
        # if perf_indicator >= best_perf:
        #     best_perf = perf_indicator
        #     best_model = True
        # else:
        #     best_model = False

        # import tqdm
        # import cv2
        # import numpy as np
        # from lib.utils.imutils import im_to_numpy, im_to_torch
        # flip = True
        # full_result = []
        # for i, (inputs,target, target_weight, meta) in enumerate(valid_loader):
        #     with torch.no_grad():
        #         input_var = torch.autograd.Variable(inputs.cuda())
        #         if flip == True:
        #             flip_inputs = inputs.clone()
        #             for i, finp in enumerate(flip_inputs):
        #                 finp = im_to_numpy(finp)
        #                 finp = cv2.flip(finp, 1)
        #                 flip_inputs[i] = im_to_torch(finp)
        #             flip_input_var = torch.autograd.Variable(flip_inputs.cuda())
        #
        #         # compute output
        #         refine_output = model(input_var)
        #         score_map = refine_output.data.cpu()
        #         score_map = score_map.numpy()
        #
        #         if flip == True:
        #             flip_output = model(flip_input_var)
        #             flip_score_map = flip_output.data.cpu()
        #             flip_score_map = flip_score_map.numpy()
        #
        #             for i, fscore in enumerate(flip_score_map):
        #                 fscore = fscore.transpose((1, 2, 0))
        #                 fscore = cv2.flip(fscore, 1)
        #                 fscore = list(fscore.transpose((2, 0, 1)))
        #                 for (q, w) in train_dataset.flip_pairs:
        #                     fscore[q], fscore[w] = fscore[w], fscore[q]
        #                 fscore = np.array(fscore)
        #                 score_map[i] += fscore
        #                 score_map[i] /= 2
        #
        #         # ids = meta['imgID'].numpy()
        #         # det_scores = meta['det_scores']
        #         for b in range(inputs.size(0)):
        #             # details = meta['augmentation_details']
        #             # imgid = meta['imgid'][b]
        #             # print(imgid)
        #             # category = meta['category'][b]
        #             # print(category)
        #             single_result_dict = {}
        #             single_result = []
        #
        #             single_map = score_map[b]
        #             r0 = single_map.copy()
        #             r0 /= 255
        #             r0 += 0.5
        #             v_score = np.zeros(106)
        #             for p in range(106):
        #                 single_map[p] /= np.amax(single_map[p])
        #                 border = 10
        #                 dr = np.zeros((112 + 2 * border, 112 + 2 * border))
        #                 dr[border:-border, border:-border] = single_map[p].copy()
        #                 dr = cv2.GaussianBlur(dr, (7, 7), 0)
        #                 lb = dr.argmax()
        #                 y, x = np.unravel_index(lb, dr.shape)
        #                 dr[y, x] = 0
        #                 lb = dr.argmax()
        #                 py, px = np.unravel_index(lb, dr.shape)
        #                 y -= border
        #                 x -= border
        #                 py -= border + y
        #                 px -= border + x
        #                 ln = (px ** 2 + py ** 2) ** 0.5
        #                 delta = 0.25
        #                 if ln > 1e-3:
        #                     x += delta * px / ln
        #                     y += delta * py / ln
        #                 x = max(0, min(x, 112 - 1))
        #                 y = max(0, min(y, 112 - 1))
        #                 resy = float((4 * y + 2) / 112 * (450))
        #                 resx = float((4 * x + 2) / 112 * (450))
        #                 # resy = float((4 * y + 2) / cfg.data_shape[0] * (450))
        #                 # resx = float((4 * x + 2) / cfg.data_shape[1] * (450))
        #                 v_score[p] = float(r0[p, int(round(y) + 1e-10), int(round(x) + 1e-10)])
        #                 single_result.append(resx)
        #                 single_result.append(resy)
        #             if len(single_result) != 0:
        #                 result = []
        #                 # result.append(imgid)
        #                 j = 0
        #                 while j < len(single_result):
        #                     result.append(float(single_result[j]))
        #                     result.append(float(single_result[j + 1]))
        #                     j += 2
        #                 full_result.append(result)
        model.eval()

        import numpy as np
        from core.inference import get_final_preds
        from utils.transforms import flip_back
        import csv
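        # the imports above run on every epoch; Python caches modules, so this
        # is harmless, though they would normally live at module scope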

        num_samples = len(valid_dataset)
        all_preds = np.zeros((num_samples, 106, 3), dtype=np.float32)
        all_boxes = np.zeros((num_samples, 6))
        image_path = []
        filenames = []
        imgnums = []
        idx = 0
        full_result = []
        with torch.no_grad():
            for i, (input, target, target_weight,
                    meta) in enumerate(valid_loader):
                # compute output
                outputs = model(input)
                if isinstance(outputs, list):
                    output = outputs[-1]
                else:
                    output = outputs

                if cfg.TEST.FLIP_TEST:
                    # flip via numpy, since PyTorch does not support negative-step indexing
                    # input_flipped = model(input[:, :, :, ::-1])
                    input_flipped = np.flip(input.cpu().numpy(), 3).copy()
                    input_flipped = torch.from_numpy(input_flipped).cuda()
                    outputs_flipped = model(input_flipped)

                    if isinstance(outputs_flipped, list):
                        output_flipped = outputs_flipped[-1]
                    else:
                        output_flipped = outputs_flipped

                    output_flipped = flip_back(output_flipped.cpu().numpy(),
                                               valid_dataset.flip_pairs)
                    output_flipped = torch.from_numpy(
                        output_flipped.copy()).cuda()

                    # feature is not aligned, shift flipped heatmap for higher accuracy
                    if cfg.TEST.SHIFT_HEATMAP:
                        output_flipped[:, :, :, 1:] = \
                            output_flipped.clone()[:, :, :, 0:-1]

                    output = (output + output_flipped) * 0.5

                target = target.cuda(non_blocking=True)
                target_weight = target_weight.cuda(non_blocking=True)

                loss = criterion(output, target, target_weight)

                num_images = input.size(0)
                # measure accuracy and record loss

                c = meta['center'].numpy()
                s = meta['scale'].numpy()
                # print(c.shape)
                # print(s.shape)
                # print(c[:3, :])
                # print(s[:3, :])
                score = meta['score'].numpy()

                preds, maxvals = get_final_preds(cfg,
                                                 output.clone().cpu().numpy(),
                                                 c, s)

                # print(preds.shape)
                for b in range(input.size(0)):
                    result = []
                    # pic_name=meta['image'][b].split('/')[-1]
                    # result.append(pic_name)
                    for points in range(106):
                        # result.append(str(int(preds[b][points][0])) + ' ' + str(int(preds[b][points][1])))
                        result.append(float(preds[b][points][0]))
                        result.append(float(preds[b][points][1]))

                    full_result.append(result)

                all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
                all_preds[idx:idx + num_images, :, 2:3] = maxvals
                # double check this all_boxes parts
                all_boxes[idx:idx + num_images, 0:2] = c[:, 0:2]
                all_boxes[idx:idx + num_images, 2:4] = s[:, 0:2]
                all_boxes[idx:idx + num_images, 4] = np.prod(s * 200, 1)
                all_boxes[idx:idx + num_images, 5] = score
                image_path.extend(meta['image'])

                idx += num_images

        # with open('res.csv', 'w', newline='') as f:
        #     writer = csv.writer(f)
        #     writer.writerows(full_result)
        gt = []
        with open("/home/sk49/workspace/cy/jd/val.txt") as f:
            for line in f.readlines():
                rows = list(map(float, line.strip().split(' ')[1:]))
                gt.append(rows)

        error = 0
        for i in range(len(gt)):
            error += NME(full_result[i], gt[i])
        print(error)

        log_file = []
        log_file.append(
            [epoch,
             optimizer.state_dict()['param_groups'][0]['lr'], error])

        with open('log_file.csv', 'a', newline='') as f:
            writer1 = csv.writer(f)
            writer1.writerows(log_file)
            # logger.close()

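        # best_model is always False here because the validation block above is
        # commented out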
        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                # 'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            },
            best_model,
            final_output_dir)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Example #15
def main():
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval('models.'+config.MODEL.NAME+'.get_cls_net')(
        config)

    dump_input = torch.rand(
        (1, 3, config.MODEL.IMAGE_SIZE[1], config.MODEL.IMAGE_SIZE[0])
    )
    logger.info(get_model_summary(model, dump_input))

    # copy model file
    this_dir = os.path.dirname(__file__)
    models_dst_dir = os.path.join(final_output_dir, 'models')
    if os.path.exists(models_dst_dir):
        shutil.rmtree(models_dst_dir)
    shutil.copytree(os.path.join(this_dir, '../lib/models'), models_dst_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

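    # NOTE: GPU ids are hardcoded here rather than taken from config.GPUS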
    gpus = [0,1]
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    optimizer = get_optimizer(config, model)

    best_perf = 0.0
    best_model = False
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir,
                                        'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_perf = checkpoint['perf']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})"
                        .format(checkpoint['epoch']))
            best_model = True
            
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch-1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch-1
        )

    # Data loading code
    traindir = os.path.join(config.DATASET.ROOT, config.DATASET.TRAIN_SET)
    valdir = os.path.join(config.DATASET.ROOT, config.DATASET.TEST_SET)

    train_dataset = datasets(dataset_root='./Data/', split='train',
                             size=config.MODEL.IMAGE_SIZE[0])
    test_dataset = datasets(dataset_root='./Data/', split='test',
                            size=config.MODEL.IMAGE_SIZE[0])

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU*len(gpus),
        shuffle=True,
        num_workers=config.WORKERS,
        pin_memory=True
    )

    valid_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU*len(gpus),
        shuffle=True,
        num_workers=config.WORKERS,
        pin_memory=True
    )
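    # note: shuffle=True on the validation loader is unusual; aggregate metrics
    # are unaffected, but shuffle=False would give a stable evaluation order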

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        lr_scheduler.step()
        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        # evaluate on validation set
        perf_indicator = validate(config, valid_loader, model, criterion,
                                  final_output_dir, tb_log_dir, writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
            # save the best weights (the original mistakenly saved the boolean flag)
            torch.save(model.module.state_dict(), './best-model.pth')
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint({
            'epoch': epoch + 1,
            'model': config.MODEL.NAME,
            'state_dict': model.module.state_dict(),
            'perf': perf_indicator,
            'optimizer': optimizer.state_dict(),
        }, best_model, final_output_dir, filename='checkpoint.pth.tar')

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info('saving final model state to {}'.format(
        final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Example #16
File: train.py Project: sheeranshan/PRTR
def main():
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=True
    )

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

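    # DETR-style supervision: build_matcher is assumed to return a Hungarian
    # matcher that pairs predicted queries with ground-truth keypoint sets;
    # SetCriterion sums weighted class and keypoint losses, plus per-decoder-
    # layer auxiliary losses when AUX_LOSS is enabled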
    matcher = build_matcher(cfg.MODEL.NUM_JOINTS)
    weight_dict = {'loss_ce': 1, 'loss_kpts': cfg.MODEL.EXTRA.KPT_LOSS_COEF}
    if cfg.MODEL.EXTRA.AUX_LOSS:
        aux_weight_dict = {}
        for i in range(cfg.MODEL.EXTRA.DEC_LAYERS - 1):
            aux_weight_dict.update(
                {k + f'_{i}': v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)
    criterion = SetCriterion(model.num_classes, matcher, weight_dict, cfg.MODEL.EXTRA.EOS_COEF, [
        'labels', 'kpts', 'cardinality']).cuda()

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # Data loading code
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY
    )

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(
        final_output_dir, 'checkpoint.pth'
    )

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(model_key_helper(checkpoint['state_dict']))

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

        if 'train_global_steps' in checkpoint.keys():
            writer_dict['train_global_steps'] = checkpoint['train_global_steps']
            writer_dict['valid_global_steps'] = checkpoint['valid_global_steps']

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
        last_epoch=last_epoch
    )

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        lr_scheduler.step()

        # evaluate on validation set
        perf_indicator = validate(
            cfg, valid_loader, valid_dataset, model, criterion,
            final_output_dir, tb_log_dir, writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        ckpt = {
            'epoch': epoch + 1,
            'model': cfg.MODEL.NAME,
            'state_dict': model.state_dict(),
            'best_state_dict': model.module.state_dict(),
            'perf': perf_indicator,
            'optimizer': optimizer.state_dict(),
            'train_global_steps': writer_dict['train_global_steps'],
            'valid_global_steps': writer_dict['valid_global_steps'],
        }

        if epoch % cfg.SAVE_FREQ == 0:
            save_checkpoint(ckpt, best_model, final_output_dir,
                            filename=f'checkpoint_{epoch}.pth')

        save_checkpoint(ckpt, best_model, final_output_dir)

    final_model_state_file = os.path.join(
        final_output_dir, 'final_state.pth'
    )
    logger.info('=> saving final model state to {}'.format(
        final_model_state_file)
    )
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Example #17
def main():
    args = parse_args()

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    dist.init_process_group(backend=args.dist_backend,
                            init_method=args.dist_url)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval('models.' + config.MODEL.NAME + '.get_cls_net')(config)

    dump_input = torch.rand(
        (1, 3, config.MODEL.IMAGE_SIZE[1], config.MODEL.IMAGE_SIZE[0]))
    logger.info(get_model_summary(model, dump_input))

    # copy model file
    # this_dir = os.path.dirname(__file__)
    # models_dst_dir = os.path.join(final_output_dir, 'models')
    # if os.path.exists(models_dst_dir):
    #     shutil.rmtree(models_dst_dir)
    # shutil.copytree(os.path.join(this_dir, '../lib/models'), models_dst_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    gpus = list(config.GPUS)
    '''
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
    '''
    # Change DP to DDP
    torch.cuda.set_device(args.local_rank)
    model = model.to(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    optimizer = get_optimizer(config, model)

    best_perf = 0.0
    best_model = False
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_perf = checkpoint['perf']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
            best_model = True

    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       config.TRAIN.LR_STEP,
                                                       config.TRAIN.LR_FACTOR,
                                                       last_epoch - 1)

    # Data loading code
    traindir = os.path.join(config.DATASET.ROOT, config.DATASET.TRAIN_SET)
    valdir = os.path.join(config.DATASET.ROOT, config.DATASET.TEST_SET)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    '''
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    )
    '''
    # Change to TSV dataset instance
    train_dataset = TSVInstance(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    # DDP requires DistributedSampler
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=(train_sampler is None),
        num_workers=config.WORKERS,
        pin_memory=True,
        sampler=train_sampler)

    valid_loader = torch.utils.data.DataLoader(
        TSVInstance(
            valdir,
            transforms.Compose([
                transforms.Resize(int(config.MODEL.IMAGE_SIZE[0] / 0.875)),
                transforms.CenterCrop(config.MODEL.IMAGE_SIZE[0]),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
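        # note: train_sampler.set_epoch(epoch) is never called, so the
        # DistributedSampler reuses the same shuffling order every epoch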
        lr_scheduler.step()
        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        # evaluate on validation set
        perf_indicator = validate(config, valid_loader, model, criterion,
                                  final_output_dir, tb_log_dir, writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': config.MODEL.NAME,
                'state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            },
            best_model,
            final_output_dir,
            filename='checkpoint.pth.tar')

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Example #18
def main():
    
    # define the main paths
    data_path = './data'
    train_dir = Path(data_path, 'images/train_imgs')

    # load the config file
    args = parse_args()
    update_config(cfg, args)

    lr = cfg.TRAIN.LR
    lamb = cfg.LAMB
    test_option = eval(cfg.test_option)
    
    input_w = cfg.MODEL.IMAGE_SIZE[1]
    input_h = cfg.MODEL.IMAGE_SIZE[0]
    
    # minimize sources of randomness
    RANDOM_SEED = int(cfg.RANDOMSEED)
    np.random.seed(RANDOM_SEED)  # numpy
    torch.manual_seed(RANDOM_SEED)  # torch (cpu)
    random.seed(RANDOM_SEED)  # python
    os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)  # python hash builtin
    torch.backends.cudnn.deterministic = True  # needed for reproducibility
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed(RANDOM_SEED)
    torch.cuda.manual_seed_all(RANDOM_SEED)  # for multi-GPU

    
    # create the logger and the final output directory
    logger, final_output_dir, tb_log_dir = create_logger(cfg, args.cfg, f'lr_{str(lr)}', 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)
    
    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
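    # this reapplies the config value and may override the benchmark=False set
    # above for reproducibility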

    # create the annotation file if it does not already exist
    if not os.path.isfile(data_path + '/annotations/train_annotation.pkl'):
        make_annotations(data_path)
    
    # load the model to train
    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=True
    )
    
    # adjust the model head and run weight initialization
    model = initialize_model(model, cfg)
    
    
    # back up the model file and train.py
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    
    shutil.copy2(
        os.path.join(this_dir, '../tools', 'train.py'),
        final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }
    
    
    # move the model to the CUDA device when a GPU is available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    # define the loss
    criterion = nn.MSELoss().cuda()

    # define the data augmentations
    A_transforms = {
        
        'val':
            A.Compose([
                A.Resize(input_h, input_w, always_apply=True),
                A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ToTensorV2()
            ], bbox_params=A.BboxParams(format="coco", min_visibility=0.05, label_fields=['class_labels'])),
        
        'test':
            A.Compose([
                A.Resize(input_h, input_w, always_apply=True),
                A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ToTensorV2()
            ])
        }
        
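    # RandomRotate90 swaps height and width, so it is only safe for square
    # inputs; rectangular inputs use crop-and-pad augmentations instead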
    if input_h == input_w:
        
        A_transforms['train'] = A.Compose([
                A.Resize(input_h, input_w, always_apply=True),
                A.OneOf([A.HorizontalFlip(p=1),
                         A.VerticalFlip(p=1),
                         A.Rotate(p=1),
                         A.RandomRotate90(p=1)
                ], p=0.5),
                A.OneOf([A.MotionBlur(p=1),
                         A.GaussNoise(p=1),
                         A.ColorJitter(p=1)
                ], p=0.5),

                A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ToTensorV2()
            ], bbox_params=A.BboxParams(format="coco", min_visibility=0.05, label_fields=['class_labels']))
        
    else :
        A_transforms['train'] = A.Compose([
                A.Resize(input_h, input_w, always_apply=True),
                A.OneOf([A.HorizontalFlip(p=1),
                         A.VerticalFlip(p=1),
                         A.Rotate(p=1),
                ], p=0.5),
                A.OneOf([A.MotionBlur(p=1),
                         A.GaussNoise(p=1)
                         
                ], p=0.5),
                A.OneOf([A.CropAndPad(percent=0.1, p=1),
                         A.CropAndPad(percent=0.2, p=1),
                         A.CropAndPad(percent=0.3, p=1)
                ], p=0.5),

                A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ToTensorV2()
            ], bbox_params=A.BboxParams(format="coco", min_visibility=0.05, label_fields=['class_labels']))
    

    # set the training parameters
    batch_size = int(cfg.TRAIN.BATCH_SIZE_PER_GPU)
    test_ratio = float(cfg.TEST_RATIO)
    num_epochs = cfg.TRAIN.END_EPOCH
    
    # patience value used for early stopping
    num_earlystop = num_epochs
    
    # build the dataset objects used by torch
    imgs, bbox, class_labels = make_train_data(data_path)

    since = time.time()
    
    """
    # test_option : train, valid로 데이터를 나눌 때 test data를 고려할지 결정합니다.
        * True일 경우 test file을 10% 뺍니다.
        * False일 경우 test file 빼지 않습니다.
    """
    if test_option == True :
        X_train, X_test, y_train, y_test = train_test_split(imgs, bbox, test_size=0.1, random_state=RANDOM_SEED)
        test_dataset = [X_test, y_test]
        with open(final_output_dir+'/test_dataset.pkl', 'wb') as f:
            pickle.dump(test_dataset, f)
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=test_ratio, random_state=RANDOM_SEED)
        test_data = Dataset(train_dir, X_test, y_test, data_transforms=A_transforms, class_labels=class_labels, phase='val')
        test_loader = data_utils.DataLoader(test_data, batch_size=batch_size, shuffle=False)
    
    else :
        X_train, X_val, y_train, y_val = train_test_split(imgs, bbox, test_size=test_ratio, random_state=RANDOM_SEED)
        
    train_data = Dataset(train_dir, X_train, y_train, data_transforms=A_transforms, class_labels=class_labels, phase='train')
    
    val_data = Dataset(train_dir, X_val, y_val, data_transforms=A_transforms, class_labels=class_labels, phase='val')
    train_loader = data_utils.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = data_utils.DataLoader(val_data, batch_size=batch_size, shuffle=False)
    
    
    # initialize variables used to track the best loss
    best_perf = 10000000000
    test_loss = None
    best_model = False
    
    # define the optimizer
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr
    )
    
    # if a checkpoint exists, resume training from that epoch
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(
        final_output_dir, 'checkpoint.pth'
    )
    
    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))
    
    # define the lr_scheduler
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
        last_epoch=-1
    )
    
    # counter used for early stopping
    count = 0
    val_losses = []
    train_losses = []
    
    # start training
    for epoch in range(begin_epoch, num_epochs):
        epoch_since = time.time()
        
        lr_scheduler.step()
        
        # train for one epoch
        train_loss = train(cfg, device, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict, lamb=lamb)

        
        # evaluate on validation set
        perf_indicator = validate(
            cfg, device, val_loader, val_data, model, criterion,
            final_output_dir, tb_log_dir, writer_dict, lamb=lamb
        )
        
        # decide whether this epoch produced the best model, judged on the validation loss
        if perf_indicator <= best_perf:
            best_perf = perf_indicator
            best_model = True
            count = 0
            
        else:
            best_model = False
            count +=1
            
        
        
        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint({
            'epoch': epoch + 1,
            'model': cfg.MODEL.NAME,
            'state_dict': model.state_dict(),
            'best_state_dict': model.state_dict(),
            'perf': perf_indicator,
            'optimizer': optimizer.state_dict(),
        }, best_model, final_output_dir)
        
        # record the losses
        val_losses.append(perf_indicator)
        train_losses.append(train_loss)
        if count == num_earlystop :
            break
        
        
        epoch_time_elapsed = time.time() - epoch_since
        print(f'epoch : {epoch}'
              f' train loss : {round(train_loss, 3)}'
              f' valid loss : {round(perf_indicator, 3)}'
              f' Elapsed time: {int(epoch_time_elapsed // 60)}m {int(epoch_time_elapsed % 60)}s')
        
    # save the final model state and logs
    final_model_state_file = os.path.join(
        final_output_dir, 'final_state.pth'
    )
    logger.info('=> saving final model state to {}'.format(
        final_model_state_file)
    )
    torch.save(model.state_dict(), final_model_state_file)
    writer_dict['writer'].close()

    time_elapsed = time.time() - since
    print('Training and Validation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best validation loss: {:4f}\n'.format(best_perf))
    
    # if test_option is True, evaluate the trained model on the held-out 10% of the data
    if test_option:
        # test data
        model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
            cfg, is_train=True)
        
        model = initialize_model(model, cfg)
        parameters = f'{final_output_dir}/model_best.pth'
        
        model = model.to(device)
        model.load_state_dict(torch.load(parameters))
        
        test_loss = validate(
                cfg, device, test_loader, test_data, model, criterion,
                final_output_dir, tb_log_dir, writer_dict, lamb=lamb
            )
    
    print(f'test loss : {test_loss}')
    
    # save the loss results separately as a pickle file
    result_dict = {}
    result_dict['val_loss'] = val_losses
    result_dict['train_loss'] = train_losses
    result_dict['best_loss'] = best_perf
    result_dict['test_loss'] = test_loss
    result_dict['lr'] = lr
    with open(final_output_dir+'/result.pkl', 'wb') as f:
        pickle.dump(result_dict, f)
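A minimal sketch of how the saved result.pkl could be inspected afterwards; the key layout mirrors result_dict above, while the file path and the matplotlib dependency are assumptions:

import pickle

import matplotlib.pyplot as plt  # assumed dependency, not used by the example above

with open('output/result.pkl', 'rb') as f:  # hypothetical final_output_dir
    result = pickle.load(f)

plt.plot(result['train_loss'], label='train')
plt.plot(result['val_loss'], label='valid')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.title(f"lr={result['lr']}, best valid loss={result['best_loss']:.4f}")
plt.savefig('loss_curve.png')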
Example #19
0
def main():
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    if torch.cuda.is_available():
        train_batch_size = cfg.TRAIN.BATCH_SIZE_PER_GPU * torch.cuda.device_count(
        )
        test_batch_size = cfg.TEST.BATCH_SIZE_PER_GPU * torch.cuda.device_count(
        )
        logger.info("Let's use %d GPUs!" % torch.cuda.device_count())

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
        cfg, is_train=True).cuda()

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # dump_input = torch.rand((1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0])).cuda()
    # writer_dict['writer'].add_graph(model, (dump_input, ))
    # logger.info(get_model_summary(model, dump_input))

    model = torch.nn.DataParallel(model)
    # model = torch.nn.DataParallel(model, device_ids=cfg.GPUS)

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        cfg=cfg,
        target_type=cfg.MODEL.TARGET_TYPE,
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    ''' Due to the imbalance of the dataset, adjust the sampling weight of
        each class according to the class distribution
    '''
    cls_prop = train_dataset.cls_stat / train_dataset.cls_stat.sum()
    cls_weights = 1 / (cls_prop + 0.02)
    str_index = 'Class idx  '
    str_prop = 'Proportion '
    str_weigh = 'Weights    '
    for i in range(len(cls_prop)):
        str_index += '| %5d ' % (i)
        str_prop += '| %5.2f ' % cls_prop[i]
        str_weigh += '| %5.2f ' % cls_weights[i]
    logger.info('Training Data Analysis:')
    logger.info(str_index)
    logger.info(str_prop)
    logger.info(str_weigh)
    sample_list_of_cls = train_dataset.sample_list_of_cls
    sample_list_of_weights = list(
        map(lambda x: cls_weights[x], sample_list_of_cls))
    train_sampler = torch.utils.data.WeightedRandomSampler(
        sample_list_of_weights, len(train_dataset))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        # batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        batch_size=train_batch_size,
        # shuffle=cfg.TRAIN.SHUFFLE,
        sampler=train_sampler,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        # batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)
    logger.info("=> Start training...")

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # train for one epoch
        train(cfg, train_loader, train_dataset, model, criterion, optimizer,
              epoch, final_output_dir, tb_log_dir, writer_dict)

        torch.save(model.module.state_dict(),
                   final_output_dir + '/epoch-%d.pth' % epoch)
        # evaluate on validation set
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)
        logger.info('# Best AP {}'.format(best_perf))

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
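The class-balancing step above generalizes to any labelled dataset: smoothed inverse-frequency weights are computed per class, expanded to one weight per sample, and handed to a WeightedRandomSampler. A self-contained sketch with toy labels (the 0.02 smoothing term mirrors the code above):

import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([0, 0, 0, 0, 1, 2])                # toy class ids, skewed towards class 0
cls_prop = torch.bincount(labels).float() / len(labels)  # class proportions
cls_weights = 1.0 / (cls_prop + 0.02)                    # smoothed inverse class frequency
sample_weights = cls_weights[labels]                     # one weight per sample
sampler = WeightedRandomSampler(sample_weights.tolist(), num_samples=len(labels))
print(list(sampler))  # sampled indices, drawn roughly uniformly over classes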
Example #20
0
def main():
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    model_func = 'models.{0}.get_{0}'.format(config.MODEL.NAME)
    model = eval(model_func)(config, pretrained=True)

    dump_input = torch.rand((1, 3, config.MODEL.IMAGE_SIZE[0], config.MODEL.IMAGE_SIZE[1]))
    print(get_model_summary(model, dump_input))

    gpus = list(config.GPUS)
    device = torch.device('cuda:{}'.format(gpus[0]) if torch.cuda.is_available() else 'cpu')
    model.to(device)

    optimizer = get_optimizer(config, model)
    if config.FP16:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    criterion = {'reid': nn.CrossEntropyLoss(), 'sr': nn.MSELoss()}
    if config.TRAIN.SR_FILTER:
        criterion['sr'] = MSEFilterLoss()

    best_perf = 0.0
    best_model = False
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir,
                                        'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_perf = checkpoint['perf']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if config.FP16:
                amp.load_state_dict(checkpoint['amp'])
            logger.info("=> loaded checkpoint (epoch {})"
                        .format(checkpoint['epoch']))
            best_model = True

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    train_dataset = ImagePersonDataset(config, mode='train')
    train_sampler = RandomIdentitySampler(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        num_instances=4
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
        sampler=train_sampler,
        num_workers=config.WORKERS,
        pin_memory=True
    )

    valid_loader = torch.utils.data.DataLoader(
        ImagePersonDataset(config, mode='val'),
        batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
        num_workers=config.WORKERS,
        pin_memory=True
    )

    since = time.time()

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        print('Epoch {}/{}'.format(epoch + 1, config.TRAIN.END_EPOCH))

        train(config, train_loader, model, criterion, optimizer, epoch, device,
              writer_dict)

        perf_indicator = validate(config, valid_loader, model, criterion, device,
                                  writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        checkpoint = {
            'epoch': epoch + 1,
            'model': config.MODEL.NAME,
            'state_dict': model.state_dict(),
            'perf': perf_indicator,
            'optimizer': optimizer.state_dict()
        }
        if config.FP16:
            checkpoint['amp'] = amp.state_dict()
        save_checkpoint(checkpoint, best_model, final_output_dir,
                        filename='checkpoint.pth.tar')

    h, m, s = get_hms_from_sec(time.time() - since)
    print('=> total training time: {}h {}m {}s'.format(h, m, s))

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info('saving final model state to {}'.format(
        final_model_state_file))
    torch.save(model.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
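The train() function is not shown here, but once apex amp has been initialized at opt_level 'O1' as above, its backward pass is expected to go through the loss scaler. A hedged sketch of that standard pattern (apex is the same assumed dependency the example already uses):

from apex import amp  # same dependency the example initializes with opt_level='O1'

def backward_step(loss, optimizer):
    # scale the loss so fp16 gradients do not underflow, then step as usual
    optimizer.zero_grad()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()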
Example #21
0
def main():
    args = parse_args()
    reset_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    # print code version info
    repo = Repo('')
    repo_git = repo.git
    working_tree_diff_head = repo_git.diff('HEAD')
    this_commit_hash = repo.commit()
    cur_branches = repo_git.branch('--list')
    logger.info('Current Code Version is {}'.format(this_commit_hash))
    logger.info('Current Branch Info :\n{}'.format(cur_branches))
    logger.info(
        'Working Tree diff with HEAD: \n{}'.format(working_tree_diff_head))

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    backbone_model = eval('models.' + config.BACKBONE_MODEL + '.get_pose_net')(
        config, is_train=True)
    model = models.multiview_pose_net.get_multiview_pose_net(
        backbone_model, config)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # dump_input = torch.rand(
    #     (config.TRAIN.BATCH_SIZE, 3,  # config.NETWORK.NUM_JOINTS,
    #      config.NETWORK.IMAGE_SIZE[1], config.NETWORK.IMAGE_SIZE[0]))
    # writer_dict['writer'].add_graph(model, dump_input)

    gpus = [int(i) for i in config.GPUS.split(',')]
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()
    # criterion_fuse = JointsMSELoss(use_target_weight=True).cuda()

    optimizer = get_optimizer(config, model)
    start_epoch = config.TRAIN.BEGIN_EPOCH
    ckpt_perf = 0.0  # default best perf when not resuming
    if config.TRAIN.RESUME:
        start_epoch, model, optimizer, ckpt_perf = load_checkpoint(
            model, optimizer, final_output_dir)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)

    # Data loading
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + config.DATASET.TRAIN_DATASET)(
        config, config.DATASET.TRAIN_SUBSET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + config.DATASET.TEST_DATASET)(
        config, config.DATASET.TEST_SUBSET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        collate_fn=totalcapture_collate,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        collate_fn=totalcapture_collate,
        pin_memory=True)

    best_perf = ckpt_perf
    best_epoch = -1
    best_model = False
    for epoch in range(start_epoch, config.TRAIN.END_EPOCH):
        lr_scheduler.step()
        extra_param = dict()
        # extra_param['loss2'] = criterion_fuse
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, writer_dict, **extra_param)

        perf_indicator = validate(config, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, writer_dict,
                                  **extra_param)

        logger.info(
            '=> perf indicator at epoch {} is {}. old best is {} '.format(
                epoch, perf_indicator, best_perf))

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
            best_epoch = epoch
            logger.info(
                '====> found new best model at end of epoch {} (epochs count from 0)'.
                format(epoch))
        else:
            best_model = False
        logger.info(
            'epoch of best validation results is {}'.format(best_epoch))

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': get_model_name(config),
                'state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

        # save final state at every epoch
        final_model_state_file = os.path.join(
            final_output_dir, 'final_state_ep{}.pth.tar'.format(epoch))
        logger.info(
            'saving final model state to {}'.format(final_model_state_file))
        torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Example #22
0
def main_worker(rank, args, config, num_gpus):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend='nccl', rank=rank, world_size=num_gpus)
    print('Rank: {} finished initializing, PID: {}'.format(rank, os.getpid()))

    if rank == 0:
        logger, final_output_dir, tb_log_dir = create_logger(
            config, args.cfg, 'train')
        logger.info(pprint.pformat(args))
        logger.info(pprint.pformat(config))
    else:
        final_output_dir = None
        tb_log_dir = None

    # Gracefully kill all subprocesses by signalling the rank-0 process (e.g. 'kill <pid>')
    signal.signal(signal.SIGTERM, signal_handler)
    if rank == 0:
        logger.info('Rank {} has registered signal handler'.format(rank))

    # device in current process
    device = torch.device('cuda', rank)

    backbone_model = eval('models.' + config.BACKBONE_MODEL + '.get_pose_net')(
        config, is_train=True)
    base_model = eval('models.' + config.MODEL + '.get_multiview_pose_net')(
        backbone_model, config)

    model_dict = OrderedDict()
    model_dict['base_model'] = base_model.to(device)

    if config.LOSS.USE_GLOBAL_MI_LOSS:
        global_discriminator = models.discriminator.GlobalDiscriminator(config)
        model_dict['global_discriminator'] = global_discriminator.to(device)
    if config.LOSS.USE_LOCAL_MI_LOSS:
        local_discriminator = models.discriminator.LocalDiscriminator(config)
        model_dict['local_discriminator'] = local_discriminator.to(device)
    if config.LOSS.USE_DOMAIN_TRANSFER_LOSS:
        domain_discriminator = models.discriminator.DomainDiscriminator(config)
        model_dict['domain_discriminator'] = domain_discriminator.to(device)
    if config.LOSS.USE_VIEW_MI_LOSS:
        view_discriminator = models.discriminator.ViewDiscriminator(config)
        model_dict['view_discriminator'] = view_discriminator.to(device)
    if config.LOSS.USE_JOINTS_MI_LOSS:
        joints_discriminator = models.discriminator.JointsDiscriminator(config)
        model_dict['joints_discriminator'] = joints_discriminator.to(device)
    if config.LOSS.USE_HEATMAP_MI_LOSS:
        heatmap_discriminator = models.discriminator.HeatmapDiscriminator(config)
        model_dict['heatmap_discriminator'] = heatmap_discriminator.to(device)

    # copy model files and print model config
    if rank == 0:
        this_dir = os.path.dirname(__file__)
        shutil.copy2(
            os.path.join(this_dir, '../../lib/models', config.MODEL + '.py'),
            final_output_dir)
        shutil.copy2(args.cfg, final_output_dir)
        logger.info(pprint.pformat(model_dict['base_model']))
        if config.LOSS.USE_GLOBAL_MI_LOSS:
            logger.info(pprint.pformat(model_dict['global_discriminator']))
        if config.LOSS.USE_LOCAL_MI_LOSS:
            logger.info(pprint.pformat(model_dict['local_discriminator']))
        if config.LOSS.USE_DOMAIN_TRANSFER_LOSS:
            logger.info(pprint.pformat(model_dict['domain_discriminator']))
        if config.LOSS.USE_VIEW_MI_LOSS:
            logger.info(pprint.pformat(model_dict['view_discriminator']))
        if config.LOSS.USE_JOINTS_MI_LOSS:
            logger.info(pprint.pformat(model_dict['joints_discriminator']))
        if config.LOSS.USE_HEATMAP_MI_LOSS:
            logger.info(pprint.pformat(model_dict['heatmap_discriminator']))
        if config.LOSS.USE_GLOBAL_MI_LOSS or config.LOSS.USE_LOCAL_MI_LOSS \
            or config.LOSS.USE_DOMAIN_TRANSFER_LOSS or config.LOSS.USE_VIEW_MI_LOSS \
            or config.LOSS.USE_JOINTS_MI_LOSS or config.LOSS.USE_HEATMAP_MI_LOSS:
            shutil.copy2(
                os.path.join(this_dir, '../../lib/models', 'discriminator.py'),
                final_output_dir)

    # tensorboard writer
    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    } if rank == 0 else None

    # dump_input = torch.rand(
    #     (config.TRAIN.BATCH_SIZE, 3,
    #      config.NETWORK.IMAGE_SIZE[1], config.NETWORK.IMAGE_SIZE[0]))
    # writer_dict['writer'].add_graph(model, (dump_input,))

    # first resume, then parallel
    for key in model_dict.keys():
        model_dict[key] = torch.nn.parallel.DistributedDataParallel(model_dict[key], device_ids=[rank], output_device=rank)
        # one by one
        dist.barrier()

    # get optimizer
    optimizer_dict = {}
    optimizer_base_model = get_optimizer(config, model_dict['base_model'])
    optimizer_dict['base_model'] = optimizer_base_model
    if config.LOSS.USE_GLOBAL_MI_LOSS:
        optimizer_global = get_optimizer(config, model_dict['global_discriminator'], is_discriminator=True)
        optimizer_dict['global_discriminator'] = optimizer_global
    if config.LOSS.USE_LOCAL_MI_LOSS:
        optimizer_local = get_optimizer(config, model_dict['local_discriminator'], is_discriminator=True)
        optimizer_dict['local_discriminator'] = optimizer_local
    if config.LOSS.USE_DOMAIN_TRANSFER_LOSS:
        optimizer_domain = get_optimizer(config, model_dict['domain_discriminator'], is_discriminator=True)
        optimizer_dict['domain_discriminator'] = optimizer_domain
    if config.LOSS.USE_VIEW_MI_LOSS:
        optimizer_view = get_optimizer(config, model_dict['view_discriminator'], is_discriminator=True)
        optimizer_dict['view_discriminator'] = optimizer_view
    if config.LOSS.USE_JOINTS_MI_LOSS:
        optimizer_joints = get_optimizer(config, model_dict['joints_discriminator'], is_discriminator=True)
        optimizer_dict['joints_discriminator'] = optimizer_joints
    if config.LOSS.USE_HEATMAP_MI_LOSS:
        optimizer_heatmap = get_optimizer(config, model_dict['heatmap_discriminator'], is_discriminator=True)
        optimizer_dict['heatmap_discriminator'] = optimizer_heatmap

    # resume
    if config.TRAIN.RESUME:
        assert config.TRAIN.RESUME_PATH != '', 'You must designate a path for config.TRAIN.RESUME_PATH, rank: {}'.format(rank)
        if rank == 0:
            logger.info('=> loading model from {}'.format(config.TRAIN.RESUME_PATH))
        # !!! map_location must be cpu, otherwise a lot of memory will be allocated on gpu:0.
        state_dict = torch.load(config.TRAIN.RESUME_PATH, map_location=torch.device('cpu'))
        if 'state_dict_base_model' in state_dict:
            if rank == 0:
                logger.info('=> new loading mode')
            for key in model_dict.keys():
                # delete params of the aggregation layer
                if key == 'base_model' and not config.NETWORK.AGGRE:
                    for param_key in list(state_dict['state_dict_base_model'].keys()):
                        if 'aggre_layer' in param_key:
                            state_dict['state_dict_base_model'].pop(param_key)
                model_dict[key].module.load_state_dict(state_dict['state_dict_' + key])
        else:
            if rank == 0:
                logger.info('=> old loading mode')
            # delete params of the aggregation layer
            if not config.NETWORK.AGGRE:
                for param_key in list(state_dict.keys()):
                    if 'aggre_layer' in param_key:
                        state_dict.pop(param_key)
            model_dict['base_model'].module.load_state_dict(state_dict)

    # Training on a server cluster; resume when interrupted
    start_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.ON_SERVER_CLUSTER:
        start_epoch, model_dict, optimizer_dict, loaded_iteration = load_checkpoint(model_dict, optimizer_dict,
                                                        final_output_dir)
        if args.iteration < loaded_iteration:
            # this training process should be skipped
            if rank == 0:
                logger.info('=> Skipping training iteration #{}'.format(args.iteration))
            return

    # lr schedulers have different starting points yet share the same decay strategy.
    lr_scheduler_dict = {}
    for key in optimizer_dict.keys():
        lr_scheduler_dict[key] = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_dict[key], config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)

    # torch.set_num_threads(8)

    criterion_dict = {}
    criterion_dict['mse_weights'] = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).to(device)
    criterion_dict['mse'] = torch.nn.MSELoss(reduction='mean').to(device)

    if config.LOSS.USE_FUNDAMENTAL_LOSS:
        criterion_dict['fundamental'] = FundamentalLoss(config)

    if config.LOSS.USE_GLOBAL_MI_LOSS or config.LOSS.USE_LOCAL_MI_LOSS:
        criterion_dict['mutual_info'] = MILoss(config, model_dict)

    if config.LOSS.USE_DOMAIN_TRANSFER_LOSS:
        criterion_dict['bce'] = torch.nn.BCELoss().to(device)

    if config.LOSS.USE_VIEW_MI_LOSS:
        criterion_dict['view_mi'] = ViewMILoss(config, model_dict)

    if config.LOSS.USE_JOINTS_MI_LOSS:
        criterion_dict['joints_mi'] = JointsMILoss(config, model_dict)

    if config.LOSS.USE_HEATMAP_MI_LOSS:
        criterion_dict['heatmap_mi'] = HeatmapMILoss(config, model_dict)

    # Data loading code
    if rank == 0:
        logger.info('=> loading dataset')
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + config.DATASET.TRAIN_DATASET)(
        config, config.DATASET.TRAIN_SUBSET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
        config.DATASET.PSEUDO_LABEL_PATH,
        config.DATASET.NO_DISTORTION)
    valid_dataset = eval('dataset.' + config.DATASET.TEST_DATASET)(
        config, config.DATASET.TEST_SUBSET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
        '',
        config.DATASET.NO_DISTORTION)
    # Debug ##################
    # print('len of mixed dataset:', len(train_dataset))
    # print('len of multiview h36m dataset:', len(valid_dataset))

    train_loader, train_sampler = get_training_loader(train_dataset, config)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE,  # no need to multiply len(gpus)
        shuffle=False,
        num_workers=int(config.WORKERS / num_gpus),
        pin_memory=False)

    best_perf = 0
    best_model = False

    dist.barrier()

    for epoch in range(start_epoch, config.TRAIN.END_EPOCH):
        for lr_scheduler in lr_scheduler_dict.values():
            lr_scheduler.step()

        train_sampler.set_epoch(epoch)

        train(config, train_loader, model_dict, criterion_dict, optimizer_dict, epoch,
                final_output_dir, writer_dict, rank)
        perf_indicator = validate(config, valid_loader, valid_dataset, model_dict,
                                  criterion_dict, final_output_dir, writer_dict, rank)

        if rank == 0:
            if perf_indicator > best_perf:
                best_perf = perf_indicator
                best_model = True
            else:
                best_model = False

            logger.info('=> saving checkpoint to {}'.format(final_output_dir))

            save_dict = {
                'epoch': epoch + 1,
                'model': get_model_name(config),
                'perf': perf_indicator,
                'iteration': args.iteration
            }
            model_state_dict = {}
            optimizer_state_dict = {}
            for key, model in model_dict.items():
                model_state_dict['state_dict_' + key] = model.module.state_dict()
                optimizer_state_dict['optimizer_' + key] = optimizer_dict[key].state_dict()
            save_dict.update(model_state_dict)
            save_dict.update(optimizer_state_dict)
            save_checkpoint(save_dict, best_model, final_output_dir)
        dist.barrier()

    if rank == 0:
        final_model_state_file = os.path.join(final_output_dir,
                                              'final_state.pth.tar')
        logger.info('saving final model state to {}'.format(final_model_state_file))
        torch.save(model_state_dict, final_model_state_file)
        writer_dict['writer'].close()

    print('Rank {} exit'.format(rank))
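main_worker takes the process rank as its first argument, which matches the calling convention of torch.multiprocessing.spawn; the launcher itself is not part of the example, so the sketch below is only an assumption about how it is invoked:

import torch
import torch.multiprocessing as mp

def launch(args, config):
    # one worker process per visible GPU; mp.spawn supplies the rank argument
    num_gpus = torch.cuda.device_count()
    mp.spawn(main_worker, args=(args, config, num_gpus), nprocs=num_gpus)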
Example #23
0
def main():
    args = parse_args()
    reset_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    backbone_model = eval('models.' + config.BACKBONE_MODEL + '.get_pose_net')(
        config, is_train=True)

    model = eval('models.' + config.MODEL + '.get_multiview_pose_net')(
        backbone_model, config)
    print(model)

    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../../lib/models', config.MODEL + '.py'),
        final_output_dir)
    shutil.copy2(args.cfg, final_output_dir)
    logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }


    gpus = [int(i) for i in config.GPUS.split(',')]
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()

    optimizer = get_optimizer(config, model)
    start_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        start_epoch, model, optimizer = load_checkpoint(model, optimizer,
                                                        final_output_dir)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)

    # Data loading code
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + config.DATASET.TRAIN_DATASET)(
        config, config.DATASET.TRAIN_SUBSET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + config.DATASET.TEST_DATASET)(
        config, config.DATASET.TEST_SUBSET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    best_perf = 0.0
    best_model = False
    for epoch in range(start_epoch, config.TRAIN.END_EPOCH):
        lr_scheduler.step()

        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, writer_dict)

        perf_indicator = validate(config, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint({
            'epoch': epoch + 1,
            'model': get_model_name(config),
            'state_dict': model.module.state_dict(),
            'perf': perf_indicator,
            'optimizer': optimizer.state_dict(),
        }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info('saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
Example #24
0
def main():
    # parse the input arguments
    args = parse_args()
    # update cfg according to the input arguments
    update_config(cfg, args)

    # create a logger to record printed information during training
    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    # some GPU-related settings
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # build the network according to the config file
    # two models are available, models.pose_hrnet and models.pose_resnet; get_pose_net returns the network structure
    print('models.' + cfg.MODEL.NAME + '.get_pose_net')
    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)

    # copy model file
    # copy the lib/models/pose_hrnet.py file into the output directory
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    # used to visualize training information (TensorBoard)
    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # used to visualize the model graph
    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ))

    logger.info(get_model_summary(model, dump_input))

    # enable multi-GPU training for the model
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # define loss function (criterion) and optimizer
    # used to compute the loss
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading code
    # normalize the input image data
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # create the training and test data iterators
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    # model loading and optimization-strategy configuration
    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)

    # training loop over epochs
    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        # evaluate on validation set
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    # save the final model
    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
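The loop above steps MultiStepLR once per epoch so the learning rate decays at the configured milestones. A standalone toy check of that behaviour (dummy parameter; the milestones and factor are illustrative, not taken from the config):

import torch

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([p], lr=0.1)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[20, 30], gamma=0.1)
for _ in range(25):  # simulate 25 epochs
    opt.step()
    sched.step()
print(opt.param_groups[0]['lr'])  # ~0.01: the epoch-20 decay has been applied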
Example #25
0
def main():
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    t_checkpoints = cfg.KD.TEACHER  # note: modify this in the student config file
    train_type = cfg.KD.TRAIN_TYPE  # note: modify this in the student config file
    train_type = get_train_type(train_type, t_checkpoints)
    logger.info('=> train type is {} '.format(train_type))

    if train_type == 'FPD':
        cfg_name = 'student_' + os.path.basename(args.cfg).split('.')[0]
    else:
        cfg_name = os.path.basename(args.cfg).split('.')[0]
    save_yaml_file(cfg_name, cfg, final_output_dir)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)

    # fpd method, default NORMAL
    if train_type == 'FPD':
        tcfg = cfg.clone()
        tcfg.defrost()
        tcfg.merge_from_file(args.tcfg)
        tcfg.freeze()
        tcfg_name = 'teacher_' + os.path.basename(args.tcfg).split('.')[0]
        save_yaml_file(tcfg_name, tcfg, final_output_dir)
        # teacher model
        tmodel = eval('models.' + tcfg.MODEL.NAME + '.get_pose_net')(
            tcfg, is_train=False)

        load_checkpoint(t_checkpoints,
                        tmodel,
                        strict=True,
                        model_info='teacher_' + tcfg.MODEL.NAME)

        tmodel = torch.nn.DataParallel(tmodel, device_ids=cfg.GPUS).cuda()
        # define kd_pose loss function (criterion) and optimizer
        kd_pose_criterion = JointsMSELoss(
            use_target_weight=tcfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ))

    logger.info(get_model_summary(model, dump_input))

    if cfg.TRAIN.CHECKPOINT:
        load_checkpoint(cfg.TRAIN.CHECKPOINT,
                        model,
                        strict=True,
                        model_info='student_' + cfg.MODEL.NAME)
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # you can choose or replace the pose_loss and kd_pose_loss types, including mse, kl, ohkm loss, etc.
    # define pose loss function (criterion) and optimizer
    pose_criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)

    # evaluate on validation set (the teacher model only exists in FPD mode)
    if train_type == 'FPD':
        validate(cfg, valid_loader, valid_dataset, tmodel, pose_criterion,
                 final_output_dir, tb_log_dir, writer_dict)
    validate(cfg, valid_loader, valid_dataset, model, pose_criterion,
             final_output_dir, tb_log_dir, writer_dict)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # fpd method, default NORMAL
        if train_type == 'FPD':
            # train for one epoch
            fpd_train(cfg, train_loader, model, tmodel, pose_criterion,
                      kd_pose_criterion, optimizer, epoch, final_output_dir,
                      tb_log_dir, writer_dict)
        else:
            # train for one epoch
            train(cfg, train_loader, model, pose_criterion, optimizer, epoch,
                  final_output_dir, tb_log_dir, writer_dict)

        # evaluate on validation set
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  pose_criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
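fpd_train is defined elsewhere; in FPD-style distillation the student is usually supervised by both the ground-truth heatmaps and the teacher's predictions. A hedged sketch of that combination (the alpha weighting and the helper name are assumptions, not the repository's API):

def fpd_loss_sketch(student_out, teacher_out, target, target_weight,
                    pose_criterion, kd_pose_criterion, alpha=0.5):
    # supervised pose loss against the ground-truth heatmaps
    hard = pose_criterion(student_out, target, target_weight)
    # distillation loss: match the teacher's (detached) heatmaps
    soft = kd_pose_criterion(student_out, teacher_out.detach(), target_weight)
    return (1 - alpha) * hard + alpha * soft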
Example #26
0
def main():
    # convert to train mode
    config.MODE = 'train'
    extra()

    # create a logger
    logger = create_logger(config, 'train')

    # logging configurations
    logger.info(pprint.pformat(config))

    # random seed
    if config.IF_DETERMINISTIC:
        torch.manual_seed(config.RANDOM_SEED_TORCH)
        config.CUDNN.DETERMINISTIC = True
        config.CUDNN.BENCHMARK = False
        np.random.seed(config.RANDOM_SEED_NUMPY)
        random.seed(config.RANDOM_SEED_RANDOM)
    else:
        logger.info('torch random seed: {}'.format(torch.initial_seed()))

        seed = random.randint(0, 2**32 - 1)
        np.random.seed(seed)
        logger.info('numpy random seed: {}'.format(seed))

        seed = random.randint(0, 2**32 - 1)
        random.seed(seed)
        logger.info('random random seed: {}'.format(seed))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    # create a model
    gpus = [int(i) for i in config.GPUS.split(',')]

    model_rgb = create_model()
    if config.TRAIN.RESUME_RGB:
        model_rgb.my_load_state_dict(torch.load(config.TRAIN.STATE_DICT_RGB),
                                     strict=True)

    model_rgb = model_rgb.cuda(gpus[0])
    model_rgb = torch.nn.DataParallel(model_rgb, device_ids=gpus)

    model_flow = create_model()
    if config.TRAIN.RESUME_FLOW:
        model_flow.my_load_state_dict(torch.load(config.TRAIN.STATE_DICT_FLOW),
                                      strict=True)

    model_flow = model_flow.cuda(gpus[0])
    model_flow = torch.nn.DataParallel(model_flow, device_ids=gpus)

    # create a conditional-vae
    cvae_rgb = create_cvae()
    cvae_rgb = cvae_rgb.cuda(gpus[0])
    cvae_rgb = torch.nn.DataParallel(cvae_rgb, device_ids=gpus)

    cvae_flow = create_cvae()
    cvae_flow = cvae_flow.cuda(gpus[0])
    cvae_flow = torch.nn.DataParallel(cvae_flow, device_ids=gpus)

    # create an optimizer
    optimizer_rgb = create_optimizer(config, model_rgb)
    optimizer_flow = create_optimizer(config, model_flow)
    optimizer_cvae_rgb = create_optimizer(config, cvae_rgb)
    optimizer_cvae_flow = create_optimizer(config, cvae_flow)

    # create a learning rate scheduler (one cosine schedule per optimizer)
    def make_cosine_scheduler(optimizer):
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=(config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH) //
            config.TRAIN.TEST_EVERY_EPOCH,
            eta_min=config.TRAIN.LR / 10)

    lr_scheduler_rgb = make_cosine_scheduler(optimizer_rgb)
    lr_scheduler_flow = make_cosine_scheduler(optimizer_flow)
    lr_scheduler_cvae_rgb = make_cosine_scheduler(optimizer_cvae_rgb)
    lr_scheduler_cvae_flow = make_cosine_scheduler(optimizer_cvae_flow)

    # load data
    train_dataset_rgb = get_dataset(mode='train', modality='rgb')
    train_dataset_flow = get_dataset(mode='train', modality='flow')
    test_dataset_rgb = get_dataset(mode='test', modality='rgb')
    test_dataset_flow = get_dataset(mode='test', modality='flow')

    train_loader_rgb = torch.utils.data.DataLoader(
        train_dataset_rgb,
        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
        shuffle=True,
        num_workers=config.WORKERS,
        pin_memory=True,
        drop_last=True)
    train_loader_flow = torch.utils.data.DataLoader(
        train_dataset_flow,
        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
        shuffle=True,
        num_workers=config.WORKERS,
        pin_memory=True,
        drop_last=True)
    test_loader_rgb = torch.utils.data.DataLoader(
        test_dataset_rgb,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)
    test_loader_flow = torch.utils.data.DataLoader(
        test_dataset_flow,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    # training and validating

    best_perf = 0

    best_model_rgb = create_model()
    best_model_rgb = best_model_rgb.cuda(gpus[0])

    best_model_flow = create_model()
    best_model_flow = best_model_flow.cuda(gpus[0])

    for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH,
                       config.TRAIN.TEST_EVERY_EPOCH):
        # train rgb for **config.TRAIN.TEST_EVERY_EPOCH** epochs
        train(train_loader_rgb, model_rgb, cvae_rgb, optimizer_rgb,
              optimizer_cvae_rgb, epoch, config.TRAIN.TEST_EVERY_EPOCH, 'rgb')

        # evaluate on validation set
        result_file_path_rgb = test_final(test_dataset_rgb, model_rgb.module,
                                          test_dataset_flow, best_model_flow)
        perf_indicator = eval_mAP(config.DATASET.GT_JSON_PATH,
                                  result_file_path_rgb)

        if best_perf < perf_indicator:
            logger.info("(rgb) new best perf: {:3f}".format(perf_indicator))
            best_perf = perf_indicator
            best_model_rgb.my_load_state_dict(model_rgb.state_dict(),
                                              strict=True)

            logger.info("=> saving final result into {}".format(
                os.path.join(config.OUTPUT_DIR,
                             'final_rgb_{}.pth'.format(best_perf))))
            torch.save(
                best_model_rgb.state_dict(),
                os.path.join(config.OUTPUT_DIR,
                             'final_rgb_{}.pth'.format(best_perf)))

            logger.info("=> saving final result into {}".format(
                os.path.join(config.OUTPUT_DIR,
                             'final_flow_{}.pth'.format(best_perf))))
            torch.save(
                best_model_flow.state_dict(),
                os.path.join(config.OUTPUT_DIR,
                             'final_flow_{}.pth'.format(best_perf)))

        # lr_scheduler_rgb.step(perf_indicator)
        # lr_scheduler_cvae_rgb.step()

        # train flow for **config.TRAIN.TEST_EVERY_EPOCH** epochs
        train(train_loader_flow, model_flow, cvae_flow, optimizer_flow,
              optimizer_cvae_flow, epoch, config.TRAIN.TEST_EVERY_EPOCH,
              'flow')

        # evaluate on validation set
        result_file_path_flow = test_final(test_dataset_rgb, best_model_rgb,
                                           test_dataset_flow,
                                           model_flow.module)
        perf_indicator = eval_mAP(config.DATASET.GT_JSON_PATH,
                                  result_file_path_flow)

        if best_perf < perf_indicator:
            logger.info("(flow) new best perf: {:3f}".format(perf_indicator))
            best_perf = perf_indicator
            best_model_flow.my_load_state_dict(model_flow.state_dict(),
                                               strict=True)

            logger.info("=> saving final result into {}".format(
                os.path.join(config.OUTPUT_DIR,
                             'final_rgb_{}.pth'.format(best_perf))))
            torch.save(
                best_model_rgb.state_dict(),
                os.path.join(config.OUTPUT_DIR,
                             'final_rgb_{}.pth'.format(best_perf)))

            logger.info("=> saving final result into {}".format(
                os.path.join(config.OUTPUT_DIR,
                             'final_flow_{}.pth'.format(best_perf))))
            torch.save(
                best_model_flow.state_dict(),
                os.path.join(config.OUTPUT_DIR,
                             'final_flow_{}.pth'.format(best_perf)))
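In this example the cosine schedulers are created but their step() calls are left commented out; if they were active, each would be stepped once per train/eval round. A minimal, self-contained sketch of that pattern (illustrative optimizer and T_max, not the config values above):

import torch

# one CosineAnnealingLR step per train/eval round, annealing the lr
# from 1e-3 down to eta_min=1e-4 over T_max=10 steps
net = torch.nn.Linear(8, 2)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10, eta_min=1e-4)
for round_idx in range(10):
    opt.step()  # stands in for one full train/eval round
    sched.step()
    print(round_idx, opt.param_groups[0]['lr'])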
Example #27
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    args = parse_args()
    print(args)

    reset_config(config, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval('models.' + config.MODEL.NAME + '.get_pose_net')(
        config, is_train=True)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', config.MODEL.NAME + '.py'),
        final_output_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (config.TRAIN.BATCH_SIZE, 3, config.MODEL.IMAGE_SIZE[1],
         config.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ), verbose=False)

    gpus = [int(i) for i in config.GPUS.split(',')]
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()

    optimizer = get_optimizer(config, model)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR)

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = eval('dataset.' + config.DATASET.DATASET)(
        config, config.DATASET.ROOT, config.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + config.DATASET.DATASET)(
        config, config.DATASET.ROOT, config.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    best_perf = 0.0
    best_model = False

    for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH):

        #print("model check!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        #for i,p in enumerate(model.parameters()):
        #    print(p.requires_grad)

        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        # evaluate on validation set
        perf_indicator = validate(config, valid_loader, valid_dataset, model,
                                  criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': get_model_name(config),
                'state_dict': model.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

        # step the scheduler after this epoch's updates (PyTorch >= 1.1 order)
        lr_scheduler.step()

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info(
        'saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
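Since PyTorch 1.1, lr_scheduler.step() is expected to run after the epoch's optimizer updates; stepping it at the top of the loop shifts the whole schedule one epoch early. A minimal sketch of the recommended ordering (illustrative model and milestones):

import torch

net = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[2, 4], gamma=0.1)
for epoch in range(6):
    opt.step()    # one epoch of training (placeholder)
    sched.step()  # decay the lr only after the epoch's updates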
Example #28
def main():

    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"]="2"
    #specify which gpu to use
    logger, final_output_dir, tb_log_dir = \
        utils.create_logger(config, args.cfg, 'train')

    # model = torchvision.models.resnet18(pretrained=config.MODEL.PRETRAINED)
    # num_ftrs = model.fc.in_features
    # model.fc = nn.Sequential(
    #     nn.Dropout(0.5),
    #     nn.Linear(num_ftrs, config.MODEL.OUTPUT_SIZE[0]))

    model = ResModel(config)

    # tensorboard writer and global step counters
    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    gpus = list(config.GPUS)
    model = nn.DataParallel(model, device_ids=gpus).cuda()
    # loss
    pos_weight = torch.tensor([2.6, 3.4, 3.0, 1.2, 1.1, 1.0, 1.1, 1.2, 3.4, 1.7, 3.6, 3.8], dtype=torch.float32)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight).cuda()  #
    criterion_val = torch.nn.BCEWithLogitsLoss().cuda()

    optimizer = utils.get_optimizer(config, model)

    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir,
                                        'latest.pth')
        if os.path.islink(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint (epoch {})"
                  .format(checkpoint['epoch']))
        else:
            print("=> no checkpoint found")

    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch-1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch-1
        )
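    # note: constructing MultiStepLR/StepLR with last_epoch >= 0 (a resume)
    # requires each optimizer param group to contain an 'initial_lr' key,
    # otherwise PyTorch raises a KeyError; group.setdefault('initial_lr',
    # group['lr']) before this point is one way to guarantee it.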
    dataset_type = get_dataset(config)

    train_loader = DataLoader(
        dataset=dataset_type(config,
                             is_train=True),
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU*len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY)

    val_loader = DataLoader(
        dataset=dataset_type(config,
                             is_train=False),
        batch_size=config.TEST.BATCH_SIZE_PER_GPU*len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY
    )

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        function.train(config, train_loader, model, criterion,
                       optimizer, epoch, writer_dict)

        # evaluate
        predictions = function.validate(config, val_loader, model,
                                        criterion_val, epoch, writer_dict)

        # step the scheduler after this epoch's updates (PyTorch >= 1.1 order)
        lr_scheduler.step()

        if epoch % 5 == 0:
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            torch.save(
                model.module.state_dict(),
                os.path.join(final_output_dir,
                             'checkpoint_{}.pth'.format(epoch)))
        # utils.save_checkpoint(
        #     {"state_dict": model,
        #      "epoch": epoch + 1,
        #      "optimizer": optimizer.state_dict(),
        #      }, predictions, final_output_dir, 'checkpoint_{}.pth'.format(epoch))

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth')
    logger.info('saving final model state to {}'.format(
        final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
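The resume block above assumes 'latest.pth' is maintained as a symlink to the newest checkpoint (hence the os.path.islink check). A minimal, self-contained sketch of the same save/resume round trip, with a hypothetical path and a toy model:

import os
import torch

ckpt_path = '/tmp/latest.pth'  # hypothetical path
net = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(net.parameters(), lr=0.01)

# save the training state
torch.save({'epoch': 5,
            'state_dict': net.state_dict(),
            'optimizer': opt.state_dict()}, ckpt_path)

# resume it later
if os.path.isfile(ckpt_path):
    ckpt = torch.load(ckpt_path)
    net.load_state_dict(ckpt['state_dict'])
    opt.load_state_dict(ckpt['optimizer'])
    last_epoch = ckpt['epoch']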
Example #29
def main():
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    # benchmark speeds up training; deterministic mode avoids its randomness
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
        cfg, is_train=True)  # eval() evaluates the string and returns the result

    # copy model file
    this_dir = os.path.dirname(__file__)  # directory of the current file
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ))

    logger.info(get_model_summary(model, dump_input))  # log the model summary

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    #model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
    # 多GPU训练
    # define loss function (criterion) and optimizer
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()
    regress_loss = RegLoss(use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()
    # Data loading code
    normalize = transforms.Normalize(
        # normalize with the ImageNet mean and std
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))  # image preprocessing

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY,
    )

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):

        # train for one epoch
        train(cfg, train_loader, model, criterion, regress_loss, optimizer,
              epoch, final_output_dir, tb_log_dir, writer_dict)

        # evaluate on validation set
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  criterion, regress_loss, final_output_dir,
                                  tb_log_dir, writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

        # step the scheduler after this epoch's updates (PyTorch >= 1.1 order)
        lr_scheduler.step()

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
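save_checkpoint is defined elsewhere in these projects; a plausible minimal version, assuming the common convention of writing a rolling checkpoint.pth and copying it to model_best.pth when performance improves (the helper's real body may differ):

import os
import shutil
import torch

def save_checkpoint(states, is_best, output_dir, filename='checkpoint.pth'):
    # write the rolling checkpoint, then snapshot it when perf improves
    torch.save(states, os.path.join(output_dir, filename))
    if is_best:
        shutil.copyfile(os.path.join(output_dir, filename),
                        os.path.join(output_dir, 'model_best.pth'))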
Example #30
def main():
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(config)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)
    distributed = len(gpus) > 1
    device = torch.device('cuda:{}'.format(args.local_rank))

    # build model
    model = eval('models.' + config.MODEL.NAME + '.get_seg_model')(config)

    if args.local_rank == 0:
        logger.info(model)
        tot_params = sum(p.numel() for p in model.parameters()) / 1000000.0
        logger.info(f">>> total params: {tot_params:.2f}M")

        # provide the summary of model
        dump_input = torch.rand(
            (1, 3, config.TRAIN.IMAGE_SIZE[0], config.TRAIN.IMAGE_SIZE[1]))
        logger.info(get_model_summary(model.to(device), dump_input.to(device)))

        # copy model file
        this_dir = os.path.dirname(__file__)
        models_dst_dir = os.path.join(final_output_dir, 'models')
        if os.path.exists(models_dst_dir):
            shutil.rmtree(models_dst_dir)
        shutil.copytree(os.path.join(this_dir, '../lib/models'),
                        models_dst_dir)

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
        )

    # prepare data
    train_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TRAIN_SET,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=config.TRAIN.MULTI_SCALE,
        flip=config.TRAIN.FLIP,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TRAIN.BASE_SIZE,
        crop_size=tuple(config.TRAIN.IMAGE_SIZE),  # (height, width)
        scale_factor=config.TRAIN.SCALE_FACTOR)

    if distributed:
        train_sampler = DistributedSampler(train_dataset)
    else:
        train_sampler = None

    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE and train_sampler is None,
        num_workers=config.WORKERS,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler)

    test_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=config.TEST.NUM_SAMPLES,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=tuple(config.TEST.IMAGE_SIZE),  # (height, width)
    )

    if distributed:
        test_sampler = DistributedSampler(test_dataset)
    else:
        test_sampler = None

    testloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True,
        sampler=test_sampler)

    # criterion
    if config.LOSS.USE_OHEM:
        criterion = OhemCrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                     weight=train_dataset.class_weights,
                                     thresh=config.LOSS.OHEMTHRESH,
                                     min_kept=config.LOSS.OHEMKEEP)
    else:
        criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                 weight=train_dataset.class_weights)

    model_state_file = config.MODEL.PRETRAINED
    logger.info('=> Loading model from {}'.format(model_state_file))
    pretrained_dict = torch.load(model_state_file)
    model_dict = model.state_dict()
    pretrained_dict = {
        k[6:]: v
        for k, v in pretrained_dict.items() if k[6:] in model_dict.keys()
    }
    for k, _ in pretrained_dict.items():
        logger.info('=> Loading {} from pretrained model'.format(k))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    model = FullModel(model, criterion)
    if distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(device)
    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
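    # note: convert_sync_batchnorm is applied before the DDP wrap so that the
    # converted SyncBatchNorm layers are what DistributedDataParallel manages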

    # optimizer
    optimizer = get_optimizer(config, model)

    epoch_iters = len(train_dataset) // (config.TRAIN.BATCH_SIZE_PER_GPU *
                                         len(gpus))
    best_mIoU = 0
    last_epoch = 0
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file,
                                    map_location=lambda storage, loc: storage)
            best_mIoU = checkpoint['best_mIoU']
            last_epoch = checkpoint['epoch']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

    start = timeit.default_timer()
    end_epoch = config.TRAIN.END_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters

    # learning rate scheduler
    lr_scheduler_dict = {
        'optimizer': optimizer,
        'milestones': [s * epoch_iters for s in config.TRAIN.LR_STEP],
        'gamma': config.TRAIN.LR_FACTOR,
        'max_iters': num_iters,
        'last_epoch': last_epoch,
        'epoch_iters': epoch_iters
    }
    lr_scheduler = get_lr_scheduler(config.TRAIN.LR_SCHEDULER,
                                    **lr_scheduler_dict)

    for epoch in range(last_epoch, end_epoch):
        if distributed:
            train_sampler.set_epoch(epoch)
        train(config, epoch, end_epoch, epoch_iters, trainloader, optimizer,
              lr_scheduler, model, writer_dict, device)

        valid_loss, mean_IoU = validate(config, testloader, model, writer_dict,
                                        device)

        if args.local_rank == 0:
            logger.info(
                '=> saving checkpoint to {}'.format(final_output_dir +
                                                    '/checkpoint.pth.tar'))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'best_mIoU': best_mIoU,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, os.path.join(final_output_dir, 'checkpoint.pth.tar'))

            if mean_IoU > best_mIoU:
                best_mIoU = mean_IoU
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'best.pth'))
            msg = (f'Loss: {valid_loss:.4f}, MeanIU: {mean_IoU:4.4f}, '
                   f'Best_mIoU: {best_mIoU:4.4f}')

            logger.info(msg)

            if epoch == end_epoch - 1:
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'final_state.pth'))

                writer_dict['writer'].close()
                end = timeit.default_timer()
                logger.info(f'Hours: {int((end - start) / 3600)}')
                logger.info('Done!')
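The k[6:] slice in the pretrained-weight loading above drops a fixed six-character key prefix (consistent with a 'model.' prefix, though the exact prefix depends on how the checkpoint was saved). A more explicit sketch of the same filtering, with the prefix named rather than hard-coded as a width:

import torch

def strip_prefix(state_dict, prefix='model.'):
    # keep only keys carrying the prefix, with the prefix removed
    return {k[len(prefix):]: v for k, v in state_dict.items()
            if k.startswith(prefix)}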