def main():
    args = parse_args()
    config_path = args.config_file_path

    config = get_config(config_path, new_keys_allowed=True)

    config.defrost()
    config.experiment_dir = os.path.join(config.log_dir, config.experiment_name)
    config.tb_dir = os.path.join(config.experiment_dir, 'tb')
    config.model.best_checkpoint_path = os.path.join(config.experiment_dir, 'best_checkpoint.pt')
    config.model.last_checkpoint_path = os.path.join(config.experiment_dir, 'last_checkpoint.pt')
    config.config_save_path = os.path.join(config.experiment_dir, 'segmentation_config.yaml')
    config.freeze()

    init_experiment(config)
    set_random_seed(config.seed)

    train_dataset = make_dataset(config.train.dataset)
    train_loader = make_data_loader(config.train.loader, train_dataset)

    val_dataset = make_dataset(config.val.dataset)
    val_loader = make_data_loader(config.val.loader, val_dataset)

    device = torch.device(config.device)
    model = make_model(config.model).to(device)

    optimizer = make_optimizer(config.optim, model.parameters())
    scheduler = None

    loss_f = make_loss(config.loss)

    early_stopping = EarlyStopping(
        **config.stopper.params
    )

    train_writer = SummaryWriter(log_dir=os.path.join(config.tb_dir, 'train'))
    val_writer = SummaryWriter(log_dir=os.path.join(config.tb_dir, 'val'))

    for epoch in range(1, config.epochs + 1):
        print(f'Epoch {epoch}')
        train_metrics = train(model, optimizer, train_loader, loss_f, device)
        write_metrics(epoch, train_metrics, train_writer)
        print_metrics('Train', train_metrics)

        val_metrics = val(model, val_loader, loss_f, device)
        write_metrics(epoch, val_metrics, val_writer)
        print_metrics('Val', val_metrics)

        early_stopping(val_metrics['loss'])
        if config.model.save and early_stopping.counter == 0:
            torch.save(model.state_dict(), config.model.best_checkpoint_path)
            print('Saved best model checkpoint to disk.')
        if early_stopping.early_stop:
            print(f'Early stopping after {epoch} epochs.')
            break

        if scheduler:
            scheduler.step()

    train_writer.close()
    val_writer.close()

    if config.model.save:
        torch.save(model.state_dict(), config.model.last_checkpoint_path)
        print('Saved last model checkpoint to disk.')
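Example #1 assumes an EarlyStopping helper exposing a counter attribute (0 right after an improvement) and an early_stop flag, neither of which is shown. A minimal patience-based sketch that matches that usage; the patience and min_delta parameter names are assumptions, not taken from the original code:

class EarlyStopping:
    # Sketch only: counter is 0 right after an improvement (so the best-checkpoint
    # branch above fires), and early_stop flips once `patience` epochs pass without one.
    # `patience` and `min_delta` are assumed names, not from the source.
    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True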
Example #2
def train():

    # 1. make dataloader
    # prepare the train/val img_info lists; each element is a tuple
    train_loader, val_loader, num_class = cifar_make_data_loader(cfg)

    # 2. make model
    model = build_model(cfg, num_class)

    # 3. make optimizer
    optimizer = make_optimizer(cfg, model)

    # 4. make lr_scheduler
    scheduler = make_lr_scheduler(cfg, optimizer)

    # 5. make loss: softmax loss by default
    loss_fn = make_loss(cfg, num_class)

    # get parameters
    device = cfg.MODEL.DEVICE
    use_gpu = device == "cuda"
    pretrained = cfg.MODEL.PRETRAIN_PATH != ""
    parallel = cfg.MODEL.PARALLEL

    log_period = cfg.OUTPUT.LOG_PERIOD
    ckpt_period = cfg.OUTPUT.CKPT_PERIOD
    eval_period = cfg.OUTPUT.EVAL_PERIOD
    output_dir = cfg.OUTPUT.DIRS
    ckpt_save_path = output_dir + cfg.OUTPUT.CKPT_DIRS

    epochs = cfg.SOLVER.MAX_EPOCHS
    batch_size = cfg.SOLVER.BATCH_SIZE
    grad_clip = cfg.SOLVER.GRAD_CLIP

    batch_num = len(train_loader)
    log_iters = batch_num // log_period

    if not os.path.exists(ckpt_save_path):
        os.makedirs(ckpt_save_path)

    # create *_result.xlsx
    # save the results for later analysis
    name = (cfg.OUTPUT.LOG_NAME).split(".")[0] + ".xlsx"
    result_path = cfg.OUTPUT.DIRS + name

    wb = xl.Workbook()
    sheet = wb.worksheets[0]
    titles = [
        'size/M', 'speed/ms', 'final_planes', 'acc', 'loss', 'acc', 'loss',
        'acc', 'loss'
    ]
    sheet.append(titles)
    check_epochs = [40, 80, 120, 160, 200, 240, 280, 320, 360, epochs]
    values = []

    logger = logging.getLogger("CDNet.train")
    size = count_parameters(model)
    values.append(format(size, '.2f'))
    values.append(model.final_planes)

    logger.info("the model size is {:.2} M".format(size))
    logger.info("Starting Training CDNetwork")
    best_acc = 0.
    is_best = False
    avg_loss, avg_acc = RunningAverageMeter(), RunningAverageMeter()
    avg_time, global_avg_time = AverageMeter(), AverageMeter()

    if parallel:
        model = nn.DataParallel(model)

    if use_gpu:
        model = model.to(device)

    for epoch in range(epochs):

        scheduler.step()
        lr = scheduler.get_lr()[0]
        # if epoch k was saved, training resumes from epoch k + 1
        if pretrained and epoch < model.start_epoch:
            continue

        # reset the meters
        model.train()
        avg_loss.reset()
        avg_acc.reset()
        avg_time.reset()

        for i, batch in enumerate(train_loader):

            t0 = time.time()
            imgs, labels = batch

            if use_gpu:
                imgs = imgs.to(device)
                labels = labels.to(device)

            res = model(imgs)

            loss, acc = compute_loss_acc(res, labels, loss_fn)
            loss.backward()

            if grad_clip != 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            t1 = time.time()
            avg_time.update((t1 - t0) / batch_size)
            avg_loss.update(loss)
            avg_acc.update(acc)

            # log info
            if (i + 1) % log_iters == 0:
                logger.info(
                    "epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".
                    format(epoch + 1, i + 1, batch_num, avg_loss.avg,
                           avg_acc.avg))

        logger.info(
            "end epochs {}/{} with lr: {:.5f} and avg_time is: {:.3f} ms".
            format(epoch + 1, epochs, lr, avg_time.avg * 1000))
        global_avg_time.update(avg_time.avg)

        # test the model
        if (epoch + 1) % eval_period == 0 or (epoch + 1) in check_epochs:

            model.eval()
            logger.info("begin eval the model")
            val_acc = AverageMeter()
            with torch.no_grad():

                for vi, batch in enumerate(val_loader):

                    imgs, labels = batch

                    if use_gpu:
                        imgs = imgs.to(device)
                        labels = labels.to(device)

                    res = model(imgs)
                    _, acc = compute_loss_acc(res, labels, loss_fn)
                    val_acc.update(acc)

                logger.info("validation results at epoch:{}".format(epoch + 1))
                logger.info("acc:{:.2%}".format(val_acc.avg))

                # determine whether current model is the best
                if val_acc.avg > best_acc:
                    logger.info("get a new best acc")
                    best_acc = val_acc.avg
                    is_best = True

                # add the result to sheet
                if (epoch + 1) in check_epochs:
                    val = [
                        format(val_acc.avg * 100, '.2f'),
                        format(avg_loss.avg, '.3f')
                    ]
                    values.extend(val)

        # whether to save the model
        if (epoch + 1) % ckpt_period == 0 or is_best:
            if parallel:
                torch.save(
                    model.module.state_dict(),
                    ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
            else:
                torch.save(
                    model.state_dict(),
                    ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
            logger.info("checkpoint {} was saved".format(epoch + 1))

            if is_best:
                if parallel:
                    torch.save(model.module.state_dict(),
                               ckpt_save_path + "best_ckpt.pth")
                else:
                    torch.save(model.state_dict(),
                               ckpt_save_path + "best_ckpt.pth")
                logger.info("best_checkpoint was saved")
                is_best = False

    values.insert(1, format(global_avg_time.avg * 1000, '.2f'))
    sheet.append(values)
    wb.save(result_path)
    logger.info("best_acc:{:.2%}".format(best_acc))
    logger.info("Ending training CDNetwork on cifar")
Example #3
def train():
    """
	# get an image for test the model 
	train_transform = build_transforms(cfg, is_train = True)
	imgs = get_image("1.jpg")
	img_tensor = train_transform(imgs[0])
	# c,h,w = img_tensor.shape
	# img_tensor = img_tensor.view(-1,c,h,w)
	# add an axis
	img_tensor = img_tensor.unsqueeze(0)
	"""
    # 1. make dataloader
    train_loader, val_loader, num_query, num_class = make_data_loader(cfg)
    #print("num_query:{},num_class:{}".format(num_query,num_class))

    # 2. make model
    model = build_model(cfg, num_class)

    # model.eval()
    # x = model(img_tensor)
    # print(x.shape)
    # 3. make optimizer
    optimizer = make_optimizer(cfg, model)

    # 4. make lr_scheduler
    scheduler = make_lr_scheduler(cfg, optimizer)

    # 5. make loss_func
    if cfg.MODEL.PCB_NECK:
        # make the loss specifically for PCB
        loss_func = get_softmax_triplet_loss_fn(cfg, num_class)
    else:
        loss_func = make_loss(cfg, num_class)

    # get parameters
    log_period = cfg.OUTPUT.LOG_PERIOD
    ckpt_period = cfg.OUTPUT.CHECKPOINT_PERIOD
    eval_period = cfg.OUTPUT.EVAL_PERIOD
    output_dir = cfg.OUTPUT.ROOT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS
    use_gpu = device == "cuda"
    use_neck = cfg.MODEL.NECK or cfg.MODEL.LEARN_REGION
    # how many batches between log outputs
    batch_size = cfg.SOLVER.IMGS_PER_BATCH
    batch_num = len(train_loader)

    log_iters = batch_num // log_period
    pretrained = cfg.MODEL.PRETRAIN_PATH != ''
    parallel = cfg.MODEL.PARALLEL
    grad_clip = cfg.DARTS.GRAD_CLIP

    feat_norm = cfg.TEST.FEAT_NORM
    ckpt_save_path = cfg.OUTPUT.ROOT_DIR + cfg.OUTPUT.CKPT_DIR
    if not os.path.exists(ckpt_save_path):
        os.makedirs(ckpt_save_path)

    # create *_result.xlsx
    # save the results for later analysis
    name = (cfg.OUTPUT.LOG_NAME).split(".")[0] + ".xlsx"
    result_path = cfg.OUTPUT.ROOT_DIR + name

    wb = xl.Workbook()
    sheet = wb.worksheets[0]
    titles = [
        'size/M', 'speed/ms', 'final_planes', 'acc', 'mAP', 'r1', 'r5', 'r10',
        'loss', 'acc', 'mAP', 'r1', 'r5', 'r10', 'loss', 'acc', 'mAP', 'r1',
        'r5', 'r10', 'loss'
    ]
    sheet.append(titles)
    check_epochs = [40, 80, 120, 160, 200, 240, 280, 320, 360, epochs]
    values = []

    logger = logging.getLogger('MobileNetReID.train')

    # count parameters
    size = count_parameters(model)
    logger.info("the param number of the model is {:.2f} M".format(size))

    values.append(format(size, '.2f'))
    values.append(model.final_planes)

    logger.info("Start training")

    #count = 183, x, y = batch -> 11712 for train
    if pretrained:
        start_epoch = model.start_epoch

    if parallel:
        model = nn.DataParallel(model)

    if use_gpu:
        # model = nn.DataParallel(model)
        model.to(device)

    # save the best model
    best_mAP, best_r1 = 0., 0.
    is_best = False
    # batch : img, pid, camid, img_path
    avg_loss, avg_acc = RunningAverageMeter(), RunningAverageMeter()
    avg_time, global_avg_time = AverageMeter(), AverageMeter()
    global_avg_time.reset()
    for epoch in range(epochs):
        scheduler.step()

        if pretrained and epoch < start_epoch - 1:
            continue

        model.train()
        # sum_loss, sum_acc = 0., 0.
        avg_loss.reset()
        avg_acc.reset()
        avg_time.reset()
        for i, batch in enumerate(train_loader):

            t0 = time.time()
            imgs, labels = batch

            if use_gpu:
                imgs = imgs.to(device)
                labels = labels.to(device)

            res = model(imgs)
            # score, feat = model(imgs)
            # loss = loss_func(score, feat, labels)
            loss, acc = compute_loss_acc(use_neck, res, labels, loss_func)

            loss.backward()
            if grad_clip != 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # in-place version; clip_grad_norm is deprecated

            optimizer.step()

            optimizer.zero_grad()

            # acc = (score.max(1)[1] == labels).float().mean()

            # sum_loss += loss
            # sum_acc += acc
            t1 = time.time()
            avg_time.update((t1 - t0) / batch_size)
            avg_loss.update(loss)
            avg_acc.update(acc)

            #log the info
            if (i + 1) % log_iters == 0:

                logger.info(
                    "epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".
                    format(epoch + 1, i + 1, batch_num, avg_loss.avg,
                           avg_acc.avg))

        lr = optimizer.state_dict()['param_groups'][0]['lr']
        logger.info(
            "end epochs {}/{} with lr: {:.5f} and avg_time is {:.3f} ms".
            format(epoch + 1, epochs, lr, avg_time.avg * 1000))
        global_avg_time.update(avg_time.avg)
        # change the lr

        # eval the model
        if (epoch + 1) % eval_period == 0 or (epoch + 1) == epochs:

            model.eval()
            metrics = R1_mAP(num_query, use_gpu=use_gpu, feat_norm=feat_norm)

            with torch.no_grad():

                for vi, batch in enumerate(val_loader):

                    imgs, labels, camids = batch

                    if use_gpu:
                        imgs = imgs.to(device)

                    feats = model(imgs)
                    metrics.update((feats, labels, camids))

                #compute cmc and mAP
                cmc, mAP = metrics.compute()
                logger.info("validation results at epoch:{}".format(epoch + 1))
                logger.info("mAP:{:.2%}".format(mAP))
                for r in [1, 5, 10]:
                    logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(
                        r, cmc[r - 1]))

                # determine whether cur model is the best
                if mAP > best_mAP:
                    is_best = True
                    best_mAP = mAP
                    logger.info("Get a new best mAP")
                if cmc[0] > best_r1:
                    is_best = True
                    best_r1 = cmc[0]
                    logger.info("Get a new best r1")

                # add the result to sheet
                if (epoch + 1) in check_epochs:
                    val = [avg_acc.avg, mAP, cmc[0], cmc[4], cmc[9]]
                    change = [format(v * 100, '.2f') for v in val]
                    change.append(format(avg_loss.avg, '.3f'))
                    values.extend(change)

        # eval_period should equal ckpt_period or be an integer multiple of it
        # whether to save the model
        if (epoch + 1) % ckpt_period == 0 or is_best:

            if parallel:
                torch.save(
                    model.module.state_dict(),
                    ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
            else:
                torch.save(
                    model.state_dict(),
                    ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))

            logger.info("checkpoint {} saved !".format(epoch + 1))

            if is_best:
                if parallel:
                    torch.save(model.module.state_dict(),
                               ckpt_save_path + "best_ckpt.pth")
                else:
                    torch.save(model.state_dict(),
                               ckpt_save_path + "best_ckpt.pth")
                logger.info("best checkpoint was saved")
                is_best = False

    values.insert(1, format(global_avg_time.avg * 1000, '.2f'))
    sheet.append(values)
    wb.save(result_path)

    logger.info("training is end, time for per imgs is {} ms".format(
        global_avg_time.avg * 1000))
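compute_loss_acc itself is never defined in these listings, but the commented-out lines in Example #3 (score, feat = model(imgs); loss = loss_func(score, feat, labels); acc = (score.max(1)[1] == labels).float().mean()) hint at its shape. A hedged sketch of one plausible implementation of the use_neck variant; the unpacking logic is inferred from those comments, not copied from the original helper:

def compute_loss_acc(use_neck, res, labels, loss_func):
    # Sketch: when a neck / feature branch is enabled the model is assumed to
    # return (score, feat) and the loss takes both; otherwise a bare score.
    if use_neck:
        score, feat = res
        loss = loss_func(score, feat, labels)
    else:
        score = res
        loss = loss_func(score, labels)
    acc = (score.max(1)[1] == labels).float().mean()
    return loss, acc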
Example #4
def train():

    use_gpu = cfg.MODEL.DEVICE == "cuda"
    # 1. make dataloader
    train_loader, val_loader, test_loader, num_query, num_class = darts_make_data_loader(
        cfg)
    # print(num_query, num_class)

    # 2. make model
    model = CNetwork(num_class, cfg)
    # tensor = torch.randn(2, 3, 256, 128)
    # res = model(tensor)
    # print(res[0].size()) [2, 751]

    # 3. make optimizer
    optimizer = make_optimizer(cfg, model)
    # make architecture optimizer
    arch_optimizer = torch.optim.Adam(
        model._arch_parameters(),
        lr=cfg.SOLVER.ARCH_LR,
        betas=(0.5, 0.999),
        weight_decay=cfg.SOLVER.ARCH_WEIGHT_DECAY)

    # 4. make lr scheduler
    lr_scheduler = make_lr_scheduler(cfg, optimizer)
    # make the architecture lr scheduler
    arch_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        arch_optimizer, [80, 160], 0.1)

    # 5. make loss
    loss_fn = darts_make_loss(cfg)

    # get parameters
    device = cfg.MODEL.DEVICE
    use_gpu = device == "cuda"
    pretrained = cfg.MODEL.PRETRAINED != ""

    log_period = cfg.OUTPUT.LOG_PERIOD
    ckpt_period = cfg.OUTPUT.CKPT_PERIOD
    eval_period = cfg.OUTPUT.EVAL_PERIOD
    output_dir = cfg.OUTPUT.DIRS
    ckpt_save_path = output_dir + cfg.OUTPUT.CKPT_DIRS

    epochs = cfg.SOLVER.MAX_EPOCHS
    batch_size = cfg.SOLVER.BATCH_SIZE
    grad_clip = cfg.SOLVER.GRAD_CLIP

    batch_num = len(train_loader)
    log_iters = batch_num // log_period

    if not os.path.exists(ckpt_save_path):
        os.makedirs(ckpt_save_path)

    # create *_result.xlsx
    # save the results for later analysis
    name = (cfg.OUTPUT.LOG_NAME).split(".")[0] + ".xlsx"
    result_path = cfg.OUTPUT.DIRS + name

    wb = xl.Workbook()
    sheet = wb.worksheets[0]
    titles = [
        'size/M', 'speed/ms', 'final_planes', 'acc', 'mAP', 'r1', 'r5', 'r10',
        'loss', 'acc', 'mAP', 'r1', 'r5', 'r10', 'loss', 'acc', 'mAP', 'r1',
        'r5', 'r10', 'loss'
    ]
    sheet.append(titles)
    check_epochs = [40, 80, 120, 160, 200, 240, 280, 320, 360, epochs]
    values = []

    logger = logging.getLogger("CNet_Search.train")
    size = count_parameters(model)
    values.append(format(size, '.2f'))
    values.append(model.final_planes)

    logger.info("the param number of the model is {:.2f} M".format(size))

    logger.info("Starting Search CNetwork")

    best_mAP, best_r1 = 0., 0.
    is_best = False
    avg_loss, avg_acc = RunningAverageMeter(), RunningAverageMeter()
    avg_time, global_avg_time = AverageMeter(), AverageMeter()

    if use_gpu:
        model = model.to(device)

    if pretrained:
        logger.info("load self-pretrained checkpoint to init the model")
        model.load_pretrained_model(cfg.MODEL.PRETRAINED)
    else:
        logger.info("use Kaiming init for the model")
        model.kaiming_init_()

    for epoch in range(epochs):

        lr_scheduler.step()
        lr = lr_scheduler.get_lr()[0]
        # step the architecture lr scheduler
        arch_lr_scheduler.step()

        # if epoch k was saved, training resumes from epoch k + 1
        if pretrained and epoch < model.start_epoch:
            continue

        # print(epoch)
        # exit(1)
        model.train()
        avg_loss.reset()
        avg_acc.reset()
        avg_time.reset()

        for i, batch in enumerate(train_loader):

            t0 = time.time()
            imgs, labels = batch
            val_imgs, val_labels = next(iter(val_loader))

            if use_gpu:
                imgs = imgs.to(device)
                labels = labels.to(device)
                val_imgs = val_imgs.to(device)
                val_labels = val_labels.to(device)

            # 1. update the network weights on the train batch
            optimizer.zero_grad()
            res = model(imgs)

            # loss = loss_fn(scores, feats, labels)
            loss, acc = compute_loss_acc(res, labels, loss_fn)
            loss.backward()

            if grad_clip != 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

            # 2. update the architecture parameters (alpha) on the val batch
            arch_optimizer.zero_grad()
            res = model(val_imgs)

            val_loss, val_acc = compute_loss_acc(res, val_labels, loss_fn)
            val_loss.backward()
            arch_optimizer.step()

            # compute the acc
            # acc = (scores.max(1)[1] == labels).float().mean()

            t1 = time.time()
            avg_time.update((t1 - t0) / batch_size)
            avg_loss.update(loss)
            avg_acc.update(acc)

            # log info
            if (i + 1) % log_iters == 0:
                logger.info(
                    "epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".
                    format(epoch + 1, i + 1, batch_num, avg_loss.avg,
                           avg_acc.avg))

        logger.info(
            "end epochs {}/{} with lr: {:.5f} and avg_time is: {:.3f} ms".
            format(epoch + 1, epochs, lr, avg_time.avg * 1000))
        global_avg_time.update(avg_time.avg)

        # test the model
        if (epoch + 1) % eval_period == 0:

            model.eval()
            metrics = R1_mAP(num_query, use_gpu=use_gpu)

            with torch.no_grad():
                for vi, batch in enumerate(test_loader):
                    # break
                    # print(len(batch))
                    imgs, labels, camids = batch
                    if use_gpu:
                        imgs = imgs.to(device)

                    feats = model(imgs)
                    metrics.update((feats, labels, camids))

                #compute cmc and mAP
                cmc, mAP = metrics.compute()
                logger.info("validation results at epoch {}".format(epoch + 1))
                logger.info("mAP:{:2%}".format(mAP))
                for r in [1, 5, 10]:
                    logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(
                        r, cmc[r - 1]))

                # determine whether current model is the best
                if mAP > best_mAP:
                    is_best = True
                    best_mAP = mAP
                    logger.info("Get a new best mAP")
                if cmc[0] > best_r1:
                    is_best = True
                    best_r1 = cmc[0]
                    logger.info("Get a new best r1")

                # add the result to sheet
                if (epoch + 1) in check_epochs:
                    val = [avg_acc.avg, mAP, cmc[0], cmc[4], cmc[9]]
                    change = [format(v * 100, '.2f') for v in val]
                    change.append(format(avg_loss.avg, '.3f'))
                    values.extend(change)

        # whether to save the model
        if (epoch + 1) % ckpt_period == 0 or is_best:
            torch.save(model.state_dict(),
                       ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
            model._parse_genotype(file=ckpt_save_path +
                                  "genotype_{}.json".format(epoch + 1))
            logger.info("checkpoint {} was saved".format(epoch + 1))

            if is_best:
                torch.save(model.state_dict(),
                           ckpt_save_path + "best_ckpt.pth")
                model._parse_genotype(file=ckpt_save_path +
                                      "best_genotype.json")
                logger.info("best_checkpoint was saved")
                is_best = False
        # exit(1)

    values.insert(1, format(global_avg_time.avg * 1000, '.2f'))
    sheet.append(values)
    wb.save(result_path)

    logger.info("Ending Search CNetwork")
Example #5
def train():

    # 1. make dataloader
    train_loader, val_loader, num_class = imagenet_make_data_loader(cfg)
    #print("num_query:{},num_class:{}".format(num_query,num_class))

    # 2. make model
    model = build_model(cfg, num_class)

    # 3. make optimizer
    optimizer = make_optimizer(cfg, model)

    # 4. make lr_scheduler
    scheduler = make_lr_scheduler(cfg, optimizer)

    # 5. make loss_func
    # directly use F.cross_entropy

    # get parameters
    log_period = cfg.OUTPUT.LOG_PERIOD
    ckpt_period = cfg.OUTPUT.CHECKPOINT_PERIOD
    eval_period = cfg.OUTPUT.EVAL_PERIOD
    output_dir = cfg.OUTPUT.ROOT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS
    use_gpu = device == "cuda"

    # how many batches between log outputs
    batch_size = cfg.SOLVER.IMGS_PER_BATCH
    dataset = train_loader.dataset

    batch_num = len(dataset) // batch_size
    print("batch number: ", batch_num)

    log_iters = batch_num // log_period
    # print(log_iters)
    # exit(1)

    pretrained = cfg.MODEL.PRETRAIN_PATH != ''
    parallel = cfg.MODEL.PARALLEL

    ckpt_save_path = cfg.OUTPUT.ROOT_DIR + cfg.OUTPUT.CKPT_DIR
    if not os.path.exists(ckpt_save_path):
        os.makedirs(ckpt_save_path)

    logger = logging.getLogger('MobileNetReID.train')

    # count parameters
    size = count_parameters(model)
    logger.info("the param number of the model is {:.2f} M".format(size))

    logger.info("Start training")

    #count = 183, x, y = batch -> 11712 for train
    if pretrained:
        start_epoch = model.start_epoch

    if parallel:
        model = nn.DataParallel(model)

    if use_gpu:
        # model = nn.DataParallel(model)
        model.to(device)

    is_best = False
    best_acc = 0.
    # batch: imgs, labels
    avg_loss, avg_acc = RunningAverageMeter(), RunningAverageMeter()
    avg_time, global_avg_time = AverageMeter(), AverageMeter()
    global_avg_time.reset()
    for epoch in range(epochs):
        scheduler.step()

        if pretrained and epoch < start_epoch - 1:
            continue

        model.train()
        # sum_loss, sum_acc = 0., 0.
        avg_loss.reset()
        avg_acc.reset()
        avg_time.reset()
        for i, batch in enumerate(train_loader):

            t0 = time.time()
            imgs, labels = batch

            if use_gpu:
                imgs = imgs.to(device)
                labels = labels.to(device)

            scores = model(imgs)
            loss = F.cross_entropy(scores, labels)

            loss.backward()

            optimizer.step()

            optimizer.zero_grad()

            acc = (scores.max(1)[1] == labels).float().mean()

            t1 = time.time()
            avg_time.update((t1 - t0) / batch_size)
            avg_loss.update(loss)
            avg_acc.update(acc)

            #log the info
            if (i + 1) % log_iters == 0:

                logger.info(
                    "epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".
                    format(epoch + 1, i + 1, batch_num, avg_loss.avg,
                           avg_acc.avg))

        lr = optimizer.state_dict()['param_groups'][0]['lr']
        logger.info(
            "end epochs {}/{} with lr: {:.5f} and avg_time is {:.3f} ms".
            format(epoch + 1, epochs, lr, avg_time.avg * 1000))
        global_avg_time.update(avg_time.avg)
        # change the lr

        # eval the model
        if (epoch + 1) % eval_period == 0 or (epoch + 1) == epochs:

            model.eval()

            val_acc = RunningAverageMeter()
            with torch.no_grad():

                for vi, batch in enumerate(val_loader):

                    imgs, labels = batch

                    if use_gpu:
                        imgs = imgs.to(device)
                        labels = labels.to(device)

                    scores = model(imgs)
                    acc = (scores.max(1)[1] == labels).float().mean()
                    val_acc.update(acc)

                logger.info("validation results at epoch:{}".format(epoch + 1))
                logger.info("acc:{:.2%}".format(val_acc.avg))

                if val_acc.avg > best_acc:
                    logger.info("get a new best acc")
                    best_acc = val_acc.avg
                    is_best = True

        # eval_period should equal ckpt_period or be an integer multiple of it
        # whether to save the model
        if (epoch + 1) % ckpt_period == 0 or is_best:

            if parallel:
                torch.save(
                    model.module.state_dict(),
                    ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
            else:
                torch.save(
                    model.state_dict(),
                    ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))

            logger.info("checkpoint {} saved !".format(epoch + 1))

            if is_best:
                if parallel:
                    torch.save(model.module.state_dict(),
                               ckpt_save_path + "best_ckpt.pth")
                else:
                    torch.save(model.state_dict(),
                               ckpt_save_path + "best_ckpt.pth")
                logger.info("best checkpoint was saved")
                is_best = False

    logger.info("training is end, time for per imgs is {} ms".format(
        global_avg_time.avg * 1000))
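Every loop above calls scheduler.step() at the top of the epoch, before any optimizer.step() has run. On PyTorch 1.1+ this emits a warning and skips the first value of the learning-rate schedule; the usual remedy is to step the scheduler once per epoch after the inner training loop. A minimal sketch of that ordering, reusing names from the examples (device transfer and logging elided):

for epoch in range(epochs):
    model.train()
    for imgs, labels in train_loader:
        optimizer.zero_grad()
        loss, acc = compute_loss_acc(model(imgs), labels, loss_fn)
        loss.backward()
        optimizer.step()
    scheduler.step()                           # advance the LR schedule after this epoch's updates
    lr = optimizer.param_groups[0]["lr"]       # current LR without the deprecated get_lr()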
Example #6
def train():

	# 1. make dataloader
	# prepare the train/val img_info lists; each element is a tuple
	train_loader, val_loader, num_query, num_class = make_data_loader(cfg)
	
	# 2. make model
	model = build_model(cfg, num_class)

	# 3. make optimizer
	optimizer = make_optimizer(cfg, model)

	# 4. make lr_scheduler
	scheduler = make_lr_scheduler(cfg, optimizer)

	# 5. make loss
	loss_fn = make_loss(cfg, num_class)

	# get parameters 
	device = cfg.MODEL.DEVICE
	use_gpu = device == "cuda"
	pretrained = cfg.MODEL.PRETRAIN_PATH != ""
	parallel = cfg.MODEL.PARALLEL

	log_period = cfg.OUTPUT.LOG_PERIOD
	ckpt_period = cfg.OUTPUT.CKPT_PERIOD
	eval_period = cfg.OUTPUT.EVAL_PERIOD
	output_dir = cfg.OUTPUT.DIRS
	ckpt_save_path = output_dir + cfg.OUTPUT.CKPT_DIRS
	
	epochs = cfg.SOLVER.MAX_EPOCHS
	batch_size = cfg.SOLVER.BATCH_SIZE
	grad_clip = cfg.SOLVER.GRAD_CLIP

	batch_num = len(train_loader)
	log_iters = batch_num // log_period 

	if not os.path.exists(ckpt_save_path):
		os.makedirs(ckpt_save_path)

	# create *_result.xlsx
	# save the results for later analysis
	name = (cfg.OUTPUT.LOG_NAME).split(".")[0] + ".xlsx"
	result_path = cfg.OUTPUT.DIRS + name

	wb = xl.Workbook()
	sheet = wb.worksheets[0]
	titles = ['size/M','speed/ms','final_planes', 'acc', 'mAP', 'r1', 'r5', 'r10', 'loss',
			  'acc', 'mAP', 'r1', 'r5', 'r10', 'loss','acc', 'mAP', 'r1', 'r5', 'r10', 'loss']
	sheet.append(titles)
	check_epochs = [40, 80, 120, 160, 200, 240, 280, 320, 360, epochs]
	values = []

	logger = logging.getLogger("CDNet.train")
	size = count_parameters(model)
	values.append(format(size, '.2f'))
	values.append(model.final_planes)
	
	logger.info("the param number of the model is {:.2f} M".format(size))
	infer_size = infer_count_parameters(model)
	logger.info("the infer param number of the model is {:.2f}M".format(infer_size))

	shape = [1, 3]
	shape.extend(cfg.DATA.IMAGE_SIZE)
	
	# if cfg.MODEL.NAME == 'cdnet' :
	# 	infer_model = CDNetwork(num_class, cfg)
	# elif cfg.MODEL.NAME == 'cnet':
	# 	infer_model = CNetwork(num_class, cfg)
	# else:
	# 	infer_model = model 

	# for scaling experiment
	flops, _ = get_model_infos(model, shape)
	logger.info("the total flops number of the model is {:.2f} M".format(flops))
	
	logger.info("Starting Training CDNetwork")
	
	best_mAP, best_r1 = 0., 0.
	is_best = False
	avg_loss, avg_acc = RunningAverageMeter(),RunningAverageMeter()
	avg_time, global_avg_time = AverageMeter(), AverageMeter()

	if parallel:
		model = nn.DataParallel(model)
		
	if use_gpu:
		model = model.to(device)

	for epoch in range(epochs):
		
		scheduler.step()
		lr = scheduler.get_lr()[0]
		# if epoch k was saved, training resumes from epoch k + 1
		if pretrained and epoch < model.start_epoch:
			continue

		# reset the meters
		model.train()
		avg_loss.reset()
		avg_acc.reset()
		avg_time.reset()

		for i, batch in enumerate(train_loader):

			t0 = time.time()
			imgs, labels = batch 

			if use_gpu:
				imgs = imgs.to(device)
				labels = labels.to(device)

			res = model(imgs)
		
			loss, acc = compute_loss_acc(res, labels, loss_fn)
			loss.backward()

			if grad_clip != 0:
				nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

			optimizer.step()
			optimizer.zero_grad()

			t1 = time.time()
			avg_time.update((t1 - t0) / batch_size)
			avg_loss.update(loss)
			avg_acc.update(acc)

			# log info
			if (i+1) % log_iters == 0:
				logger.info("epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".format(
					epoch+1, i+1, batch_num, avg_loss.avg, avg_acc.avg))

		logger.info("end epochs {}/{} with lr: {:.5f} and avg_time is: {:.3f} ms".format(epoch+1, epochs, lr, avg_time.avg * 1000))
		global_avg_time.update(avg_time.avg)

		# test the model
		if (epoch + 1) % eval_period == 0 or (epoch + 1) in check_epochs:

			model.eval()
			metrics = R1_mAP(num_query, use_gpu = use_gpu)

			with torch.no_grad():
				for vi, batch in enumerate(val_loader):
					
					imgs, labels, camids = batch
					if use_gpu:
						imgs = imgs.to(device)

					feats = model(imgs)
					metrics.update((feats, labels, camids))

				#compute cmc and mAP
				cmc, mAP = metrics.compute()
				logger.info("validation results at epoch {}".format(epoch + 1))
				logger.info("mAP:{:2%}".format(mAP))
				for r in [1,5,10]:
					logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(r, cmc[r-1]))

				# determine whether current model is the best
				if mAP > best_mAP:
					is_best = True
					best_mAP = mAP
					logger.info("Get a new best mAP")
				if cmc[0] > best_r1:
					is_best = True
					best_r1 = cmc[0]
					logger.info("Get a new best r1")

				# add the result to sheet
				if (epoch + 1) in check_epochs:
					val = [avg_acc.avg, mAP, cmc[0], cmc[4], cmc[9]]
					change = [format(v * 100, '.2f') for v in val]
					change.append(format(avg_loss.avg, '.3f'))
					values.extend(change)
					
		# whether to save the model
		if (epoch + 1) % ckpt_period == 0 or is_best:
			torch.save(model.state_dict(), ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
			logger.info("checkpoint {} was saved".format(epoch + 1))

			if is_best:
				torch.save(model.state_dict(), ckpt_save_path + "best_ckpt.pth")
				logger.info("best_checkpoint was saved")
				is_best = False
		

	values.insert(1, format(global_avg_time.avg * 1000, '.2f'))
	values.append(format(infer_size, '.2f'))
	sheet.append(values)
	wb.save(result_path)
	logger.info("best_mAP:{:.2%}, best_r1:{:.2%}".format(best_mAP, best_r1))
	logger.info("Ending training CDNetwork")