Example #1
0
def test():
    """Evaluate the best CDNet checkpoint on the validation set.

    Relies on the module-level ``cfg``: builds the dataloaders and model,
    reports parameter count and FLOPs, loads ``TEST.BEST_CKPT``, extracts
    features for the query + gallery images and logs mAP / CMC together
    with feature-extraction and matching times.
    """
    logger = logging.getLogger('CDNet.test')

    # prepare dataloader
    train_loader, val_loader, num_query, num_class = make_data_loader(cfg)
    # prepare model
    model = build_model(cfg, num_class)

    # inference-time parameter count (in millions)
    infer_size = infer_count_parameters(model)
    logger.info(
        "the infer param number of the model is {:.2f}M".format(infer_size))

    # FLOPs are computed for a single image: [1, 3, H, W]
    shape = [1, 3]
    shape.extend(cfg.DATA.IMAGE_SIZE)
    flops, _ = get_model_infos(model, shape)
    logger.info("the total flops is: {:.2f} M".format(flops))

    # load param
    ckpt_path = cfg.OUTPUT.DIRS + cfg.OUTPUT.CKPT_DIRS + cfg.TEST.BEST_CKPT

    if os.path.isfile(ckpt_path):
        model.load_best_checkpoint(ckpt_path)
    else:
        logger.info("file: {} is not found".format(ckpt_path))
        exit(1)

    use_gpu = cfg.MODEL.DEVICE == 'cuda'
    if cfg.MODEL.PARALLEL:
        model = nn.DataParallel(model)
    if use_gpu:
        model = model.cuda()
    model.eval()
    metrics = R1_mAP(num_query, use_gpu=use_gpu)

    with torch.no_grad():
        begin = time.time()
        for batch in val_loader:
            imgs, pids, camids = batch

            if use_gpu:
                imgs = imgs.cuda()
            feats = model(imgs)
            metrics.update((feats, pids, camids))
        end1 = time.time()
        cmc, mAP = metrics.compute()
        end2 = time.time()
        logger.info("extract feature time is:{:.2f} s".format(end1 - begin))
        logger.info("match time is:{:.2f} s".format(end2 - end1))

        logger.info("test result as follows")
        logger.info("mAP:{:.2%}".format(mAP))
        for r in [1, 5, 10]:
            # fix: "CMC cure" -> "CMC curve", consistent with the train loggers
            logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(r, cmc[r - 1]))

        print("test is ending")  # fix: "endding" typo
Example #2
0
def test():
    """Evaluate the best MobileNetReID checkpoint on the validation set.

    Relies on the module-level ``cfg``: builds the dataloaders and model,
    loads ``TEST.BEST_CKPT``, extracts features for all query + gallery
    images and logs mAP and the CMC curve at rank 1/5/10.
    """
    logger = logging.getLogger('MobileNetReID.test')

    # prepare dataloader
    train_loader, val_loader, num_query, num_class = make_data_loader(cfg)
    # prepare model
    model = build_model(cfg, num_class)

    # load param
    ckpt_path = cfg.OUTPUT.ROOT_DIR + cfg.OUTPUT.CKPT_DIR + cfg.TEST.BEST_CKPT
    if os.path.isfile(ckpt_path):
        model.load_param(ckpt_path)
    else:
        logger.info("file: {} is not found".format(ckpt_path))
        exit(1)

    use_gpu = cfg.MODEL.DEVICE == 'cuda'
    device = cfg.MODEL.DEVICE_ID

    if use_gpu:
        # fix: was `nn.DataPararallel` (typo) -> AttributeError at runtime
        model = nn.DataParallel(model)
        model.to(device)

    model.eval()
    metrics = R1_mAP(num_query, use_gpu=use_gpu)

    with torch.no_grad():
        for batch in val_loader:
            # fix: batch was unpacked into `data` but used as `imgs` below
            imgs, pids, camids = batch

            if use_gpu:
                # fix: Tensor.to() is not in-place -- rebind the result
                imgs = imgs.to(device)
            feats = model(imgs)
            # fix: `labels` was undefined; R1_mAP.update takes one tuple
            # (see the sibling test()/train() blocks in this file)
            metrics.update((feats, pids, camids))

        cmc, mAP = metrics.compute()
        logger.info("test result as follows")
        # fix: "{:2%}" is a width-2 spec; ".2%" gives a 2-decimal percentage
        logger.info("mAP:{:.2%}".format(mAP))
        for r in [1, 5, 10]:
            logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(r, cmc[r - 1]))

        print("test is ending")  # fix: "endding" typo
Example #3
0
def train():
    """Train MobileNetReID as configured by the module-level ``cfg``.

    Pipeline: build dataloaders, model, optimizer, lr scheduler and loss;
    run the epoch loop with running loss/acc meters; evaluate mAP and CMC
    every ``EVAL_PERIOD`` epochs; save a checkpoint every
    ``CHECKPOINT_PERIOD`` epochs and whenever a new best mAP or rank-1
    appears; finally append a timing/accuracy summary row to
    ``*_result.xlsx``.
    """
    # 1、make dataloader
    train_loader, val_loader, num_query, num_class = make_data_loader(cfg)

    # 2、make model
    model = build_model(cfg, num_class)

    # 3、 make optimizer
    optimizer = make_optimizer(cfg, model)

    # 4、 make lr_scheduler
    scheduler = make_lr_scheduler(cfg, optimizer)

    # 5、 make loss_func
    if cfg.MODEL.PCB_NECK:
        # make loss specifically for pcb
        loss_func = get_softmax_triplet_loss_fn(cfg, num_class)
    else:
        loss_func = make_loss(cfg, num_class)

    # get parameters
    log_period = cfg.OUTPUT.LOG_PERIOD
    ckpt_period = cfg.OUTPUT.CHECKPOINT_PERIOD
    eval_period = cfg.OUTPUT.EVAL_PERIOD
    output_dir = cfg.OUTPUT.ROOT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS
    use_gpu = device == "cuda"
    use_neck = cfg.MODEL.NECK or cfg.MODEL.LEARN_REGION
    # how many batch for each log
    batch_size = cfg.SOLVER.IMGS_PER_BATCH
    batch_num = len(train_loader)

    log_iters = batch_num // log_period
    pretrained = cfg.MODEL.PRETRAIN_PATH != ''
    parallel = cfg.MODEL.PARALLEL
    grad_clip = cfg.DARTS.GRAD_CLIP

    feat_norm = cfg.TEST.FEAT_NORM
    ckpt_save_path = cfg.OUTPUT.ROOT_DIR + cfg.OUTPUT.CKPT_DIR
    if not os.path.exists(ckpt_save_path):
        os.makedirs(ckpt_save_path)

    # create *_result.xlsx to save the result for later analysis
    name = (cfg.OUTPUT.LOG_NAME).split(".")[0] + ".xlsx"
    result_path = cfg.OUTPUT.ROOT_DIR + name

    wb = xl.Workbook()
    sheet = wb.worksheets[0]
    # three repeated (acc, mAP, r1, r5, r10, loss) groups: one per check epoch
    titles = [
        'size/M', 'speed/ms', 'final_planes', 'acc', 'mAP', 'r1', 'r5', 'r10',
        'loss', 'acc', 'mAP', 'r1', 'r5', 'r10', 'loss', 'acc', 'mAP', 'r1',
        'r5', 'r10', 'loss'
    ]
    sheet.append(titles)
    check_epochs = [40, 80, 120, 160, 200, 240, 280, 320, 360, epochs]
    values = []

    logger = logging.getLogger('MobileNetReID.train')

    # count parameter
    size = count_parameters(model)
    logger.info("the param number of the model is {:.2f} M".format(size))

    values.append(format(size, '.2f'))
    values.append(model.final_planes)

    logger.info("Start training")

    if pretrained:
        # resuming: epochs before start_epoch are skipped in the loop below
        start_epoch = model.start_epoch

    if parallel:
        model = nn.DataParallel(model)

    if use_gpu:
        model.to(device)

    # track the best model seen so far
    best_mAP, best_r1 = 0., 0.
    is_best = False
    avg_loss, avg_acc = RunningAverageMeter(), RunningAverageMeter()
    avg_time, global_avg_time = AverageMeter(), AverageMeter()
    global_avg_time.reset()
    for epoch in range(epochs):
        scheduler.step()

        if pretrained and epoch < start_epoch - 1:
            continue

        model.train()
        avg_loss.reset()
        avg_acc.reset()
        avg_time.reset()
        for i, batch in enumerate(train_loader):

            t0 = time.time()
            imgs, labels = batch

            if use_gpu:
                imgs = imgs.to(device)
                labels = labels.to(device)

            res = model(imgs)
            loss, acc = compute_loss_acc(use_neck, res, labels, loss_func)

            loss.backward()
            if grad_clip != 0:
                # fix: clip_grad_norm is deprecated; use in-place clip_grad_norm_
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

            optimizer.zero_grad()

            t1 = time.time()
            # per-image wall time for this batch
            avg_time.update((t1 - t0) / batch_size)
            avg_loss.update(loss)
            avg_acc.update(acc)

            # log the info
            if (i + 1) % log_iters == 0:

                logger.info(
                    "epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".
                    format(epoch + 1, i + 1, batch_num, avg_loss.avg,
                           avg_acc.avg))

        lr = optimizer.state_dict()['param_groups'][0]['lr']
        logger.info(
            "end epochs {}/{} with lr: {:.5f} and avg_time is {:.3f} ms".
            format(epoch + 1, epochs, lr, avg_time.avg * 1000))
        global_avg_time.update(avg_time.avg)

        # eval the model
        if (epoch + 1) % eval_period == 0 or (epoch + 1) == epochs:

            model.eval()
            metrics = R1_mAP(num_query, use_gpu=use_gpu, feat_norm=feat_norm)

            with torch.no_grad():

                for vi, batch in enumerate(val_loader):

                    imgs, labels, camids = batch

                    if use_gpu:
                        imgs = imgs.to(device)

                    feats = model(imgs)
                    metrics.update((feats, labels, camids))

                # compute cmc and mAP
                cmc, mAP = metrics.compute()
                logger.info("validation results at epoch:{}".format(epoch + 1))
                logger.info("mAP:{:.2%}".format(mAP))
                for r in [1, 5, 10]:
                    logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(
                        r, cmc[r - 1]))

                # determine whether cur model is the best
                if mAP > best_mAP:
                    is_best = True
                    best_mAP = mAP
                    logger.info("Get a new best mAP")
                if cmc[0] > best_r1:
                    is_best = True
                    best_r1 = cmc[0]
                    logger.info("Get a new best r1")

                # add the result to sheet
                if (epoch + 1) in check_epochs:
                    val = [avg_acc.avg, mAP, cmc[0], cmc[4], cmc[9]]
                    change = [format(v * 100, '.2f') for v in val]
                    change.append(format(avg_loss.avg, '.3f'))
                    values.extend(change)

        # we hope that eval_period == ckpt_period or eval_period == k * ckpt_period
        # whether to save the model
        if (epoch + 1) % ckpt_period == 0 or is_best:

            if parallel:
                # DataParallel wraps the net; save the underlying module
                torch.save(
                    model.module.state_dict(),
                    ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
            else:
                torch.save(
                    model.state_dict(),
                    ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))

            logger.info("checkpoint {} saved !".format(epoch + 1))

            if is_best:
                if parallel:
                    torch.save(model.module.state_dict(),
                               ckpt_save_path + "best_ckpt.pth")
                else:
                    torch.save(model.state_dict(),
                               ckpt_save_path + "best_ckpt.pth")
                logger.info("best checkpoint was saved")
                is_best = False

    values.insert(1, format(global_avg_time.avg * 1000, '.2f'))
    sheet.append(values)
    wb.save(result_path)

    logger.info("training is end, time for per imgs is {} ms".format(
        global_avg_time.avg * 1000))
Example #4
0
def train():
    """Search the CNetwork architecture (DARTS-style bi-level training).

    Each iteration first updates the network weights on a train batch,
    then updates the architecture parameters (alphas) on a validation
    batch with a separate Adam optimizer. Evaluates mAP/CMC every
    ``EVAL_PERIOD`` epochs, saves checkpoint + parsed genotype on
    schedule and on new bests, and appends a summary row to
    ``*_result.xlsx``.
    """
    # 1、make dataloader
    train_loader, val_loader, test_loader, num_query, num_class = darts_make_data_loader(
        cfg)

    # 2、make model
    model = CNetwork(num_class, cfg)

    # 3、make optimizer (network weights)
    optimizer = make_optimizer(cfg, model)
    # make architecture optimizer (alphas only)
    arch_optimizer = torch.optim.Adam(
        model._arch_parameters(),
        lr=cfg.SOLVER.ARCH_LR,
        betas=(0.5, 0.999),
        weight_decay=cfg.SOLVER.ARCH_WEIGHT_DECAY)

    # 4、make lr scheduler
    lr_scheduler = make_lr_scheduler(cfg, optimizer)
    # decay the architecture lr at fixed milestones
    arch_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        arch_optimizer, [80, 160], 0.1)

    # 5、make loss
    loss_fn = darts_make_loss(cfg)

    # get parameters
    device = cfg.MODEL.DEVICE
    use_gpu = device == "cuda"
    pretrained = cfg.MODEL.PRETRAINED != ""

    log_period = cfg.OUTPUT.LOG_PERIOD
    ckpt_period = cfg.OUTPUT.CKPT_PERIOD
    eval_period = cfg.OUTPUT.EVAL_PERIOD
    output_dir = cfg.OUTPUT.DIRS
    ckpt_save_path = output_dir + cfg.OUTPUT.CKPT_DIRS

    epochs = cfg.SOLVER.MAX_EPOCHS
    batch_size = cfg.SOLVER.BATCH_SIZE
    grad_clip = cfg.SOLVER.GRAD_CLIP

    batch_num = len(train_loader)
    log_iters = batch_num // log_period

    if not os.path.exists(ckpt_save_path):
        os.makedirs(ckpt_save_path)

    # create *_result.xlsx to save the result for later analysis
    name = (cfg.OUTPUT.LOG_NAME).split(".")[0] + ".xlsx"
    result_path = cfg.OUTPUT.DIRS + name

    wb = xl.Workbook()
    sheet = wb.worksheets[0]
    # three repeated (acc, mAP, r1, r5, r10, loss) groups: one per check epoch
    titles = [
        'size/M', 'speed/ms', 'final_planes', 'acc', 'mAP', 'r1', 'r5', 'r10',
        'loss', 'acc', 'mAP', 'r1', 'r5', 'r10', 'loss', 'acc', 'mAP', 'r1',
        'r5', 'r10', 'loss'
    ]
    sheet.append(titles)
    check_epochs = [40, 80, 120, 160, 200, 240, 280, 320, 360, epochs]
    values = []

    logger = logging.getLogger("CNet_Search.train")
    size = count_parameters(model)
    values.append(format(size, '.2f'))
    values.append(model.final_planes)

    logger.info("the param number of the model is {:.2f} M".format(size))

    logger.info("Starting Search CNetwork")

    best_mAP, best_r1 = 0., 0.
    is_best = False
    avg_loss, avg_acc = RunningAverageMeter(), RunningAverageMeter()
    avg_time, global_avg_time = AverageMeter(), AverageMeter()

    if use_gpu:
        model = model.to(device)

    if pretrained:
        logger.info("load self pretrained checkpoint to init")
        model.load_pretrained_model(cfg.MODEL.PRETRAINED)
    else:
        logger.info("use kaiming init to init the model")
        model.kaiming_init_()

    for epoch in range(epochs):

        lr_scheduler.step()
        lr = lr_scheduler.get_lr()[0]
        # architect lr.step
        arch_lr_scheduler.step()

        # if save epoch_num k, then run k+1 epoch next
        if pretrained and epoch < model.start_epoch:
            continue

        model.train()
        avg_loss.reset()
        avg_acc.reset()
        avg_time.reset()

        for i, batch in enumerate(train_loader):

            t0 = time.time()
            imgs, labels = batch
            # draw a fresh validation batch for the alpha update
            val_imgs, val_labels = next(iter(val_loader))

            if use_gpu:
                imgs = imgs.to(device)
                labels = labels.to(device)
                val_imgs = val_imgs.to(device)
                val_labels = val_labels.to(device)

            # 1、update the weights on the train batch
            optimizer.zero_grad()
            res = model(imgs)

            loss, acc = compute_loss_acc(res, labels, loss_fn)
            loss.backward()

            if grad_clip != 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

            # 2、update the alpha on the validation batch
            arch_optimizer.zero_grad()
            res = model(val_imgs)

            val_loss, val_acc = compute_loss_acc(res, val_labels, loss_fn)
            val_loss.backward()
            arch_optimizer.step()

            t1 = time.time()
            # per-image wall time for this iteration (both updates)
            avg_time.update((t1 - t0) / batch_size)
            avg_loss.update(loss)
            avg_acc.update(acc)

            # log info
            if (i + 1) % log_iters == 0:
                logger.info(
                    "epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".
                    format(epoch + 1, i + 1, batch_num, avg_loss.avg,
                           avg_acc.avg))

        logger.info(
            "end epochs {}/{} with lr: {:.5f} and avg_time is: {:.3f} ms".
            format(epoch + 1, epochs, lr, avg_time.avg * 1000))
        global_avg_time.update(avg_time.avg)

        # test the model
        if (epoch + 1) % eval_period == 0:

            model.eval()
            metrics = R1_mAP(num_query, use_gpu=use_gpu)

            with torch.no_grad():
                for vi, batch in enumerate(test_loader):
                    imgs, labels, camids = batch
                    if use_gpu:
                        imgs = imgs.to(device)

                    feats = model(imgs)
                    metrics.update((feats, labels, camids))

                # compute cmc and mAP
                cmc, mAP = metrics.compute()
                logger.info("validation results at epoch {}".format(epoch + 1))
                # fix: "{:2%}" is a width spec; ".2%" gives 2-decimal percent
                logger.info("mAP:{:.2%}".format(mAP))
                for r in [1, 5, 10]:
                    logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(
                        r, cmc[r - 1]))

                # determine whether current model is the best
                if mAP > best_mAP:
                    is_best = True
                    best_mAP = mAP
                    logger.info("Get a new best mAP")
                if cmc[0] > best_r1:
                    is_best = True
                    best_r1 = cmc[0]
                    logger.info("Get a new best r1")

                # add the result to sheet
                if (epoch + 1) in check_epochs:
                    val = [avg_acc.avg, mAP, cmc[0], cmc[4], cmc[9]]
                    change = [format(v * 100, '.2f') for v in val]
                    change.append(format(avg_loss.avg, '.3f'))
                    values.extend(change)

        # whether to save the model (weights + parsed genotype)
        if (epoch + 1) % ckpt_period == 0 or is_best:
            torch.save(model.state_dict(),
                       ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
            model._parse_genotype(file=ckpt_save_path +
                                  "genotype_{}.json".format(epoch + 1))
            logger.info("checkpoint {} was saved".format(epoch + 1))

            if is_best:
                torch.save(model.state_dict(),
                           ckpt_save_path + "best_ckpt.pth")
                model._parse_genotype(file=ckpt_save_path +
                                      "best_genotype.json")
                logger.info("best_checkpoint was saved")
                is_best = False

    values.insert(1, format(global_avg_time.avg * 1000, '.2f'))
    sheet.append(values)
    wb.save(result_path)

    logger.info("Ending Search CNetwork")
Example #5
0
def train(cfg):
	"""DARTS search loop: alternately update architecture and weights.

	Uses an ``Architect`` to update the alphas on validation batches and a
	plain optimizer step for the network weights, evaluates mAP/CMC every
	``eval_period`` epochs, and saves checkpoint + parsed genotype on
	schedule and whenever a new best mAP or rank-1 appears.
	"""
	use_gpu = cfg.device == 'cuda'
	# 1、make dataloader
	train_loader, val_loader, test_loader, num_query, num_class =  darts_make_data_loader(cfg)

	# 2、make model
	if cfg.model_name == 'ssnet':
		model = SSNetwork(num_class, cfg, use_gpu)

	elif cfg.model_name == 'fsnet':
		model = FSNetwork(num_class, cfg.in_planes, cfg.init_size, cfg.layers, use_gpu, cfg.pretrained) 

	# 3、make optimizer
	optimizer = darts_make_optimizer(cfg, model)

	# 4、make lr scheduler
	lr_scheduler = darts_make_lr_scheduler(cfg, optimizer)

	# 5、make loss 
	loss_func = darts_make_loss(cfg)
	model._set_loss(loss_func, compute_loss_acc)

	# 6、make architect (handles the alpha/bi-level update)
	architect = Architect(model, cfg)

	# get parameters
	log_period = cfg.log_period
	ckpt_period = cfg.ckpt_period
	eval_period = cfg.eval_period
	output_dir =  cfg.output_dir
	device = cfg.device 
	epochs = cfg.max_epochs
	ckpt_save_path = output_dir + cfg.ckpt_dir 

	use_gpu = device == 'cuda'
	batch_size = cfg.batch_size
	batch_num = len(train_loader)
	log_iters = batch_num // log_period 
	pretrained = cfg.pretrained is not None
	parallel = False
	use_neck = cfg.use_neck 

	if not os.path.exists(ckpt_save_path):
		os.makedirs(ckpt_save_path)

	logger = logging.getLogger("DARTS.train")
	size = count_parameters(model)
	logger.info("the param number of the model is {:.2f} M".format(size))

	logger.info("Start training")

	if pretrained:
		# resuming: epochs before start_epoch are skipped in the loop below
		start_epoch = model.start_epoch 
	if parallel:
		model = nn.DataParallel(model)
	if use_gpu:
		model = model.to(device)

	best_mAP, best_r1 = 0., 0.
	is_best = False
	avg_loss, avg_acc = RunningAverageMeter(), RunningAverageMeter()
	avg_time = AverageMeter()
	for epoch in range(epochs):
		lr_scheduler.step()
		lr = lr_scheduler.get_lr()[0]
		# architect lr.step
		architect.lr_scheduler.step()

		if pretrained and epoch < model.start_epoch :
			continue

		model.train()
		avg_loss.reset()
		avg_acc.reset()
		avg_time.reset()

		for i, batch in enumerate(train_loader):

			t0 = time.time()
			imgs, labels = batch
			# draw a fresh validation batch for the alpha update
			val_imgs, val_labels = next(iter(val_loader))

			if use_gpu:
				imgs = imgs.to(device)
				labels = labels.to(device)
				val_imgs = val_imgs.to(device)
				val_labels = val_labels.to(device)

			# 1、update alpha
			architect.step(imgs, labels, val_imgs, val_labels, lr, optimizer, unrolled = cfg.unrolled)

			optimizer.zero_grad()
			res = model(imgs)
			loss, acc = compute_loss_acc(use_neck, res, labels, loss_func)

			loss.backward()
			# fix: clip_grad_norm is deprecated; use in-place clip_grad_norm_
			nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)

			# 2、update weights
			optimizer.step()

			t1 = time.time()
			# per-image wall time for this iteration (both updates)
			avg_time.update((t1 - t0) / batch_size)
			avg_loss.update(loss)
			avg_acc.update(acc)

			# log info
			if (i+1) % log_iters == 0:
				logger.info("epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".format(
					epoch+1, i+1, batch_num, avg_loss.avg, avg_acc.avg))

		logger.info("end epochs {}/{} with lr: {:.5f} and avg_time is: {:.3f} ms".format(epoch+1, epochs, lr, avg_time.avg * 1000))

		# test the model
		if (epoch + 1) % eval_period == 0:

			model.eval()
			metrics = R1_mAP(num_query, use_gpu = use_gpu)

			with torch.no_grad():

				for vi, batch in enumerate(test_loader):

					imgs, labels, camids = batch

					if use_gpu:
						imgs = imgs.to(device)

					feats = model(imgs)
					metrics.update((feats, labels, camids))

				# compute cmc and mAP
				cmc, mAP = metrics.compute()
				logger.info("validation results at epoch {}".format(epoch + 1))
				# fix: "{:2%}" is a width spec; ".2%" gives 2-decimal percent
				logger.info("mAP:{:.2%}".format(mAP))
				for r in [1,5,10]:
					logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(r, cmc[r-1]))

				# determine whether current model is the best
				if mAP > best_mAP:
					is_best = True
					best_mAP = mAP
					logger.info("Get a new best mAP")
				if cmc[0] > best_r1:
					is_best = True
					best_r1 = cmc[0]
					logger.info("Get a new best r1")

		# whether to save the model (weights + parsed genotype)
		if (epoch + 1) % ckpt_period == 0 or is_best:

			if parallel:
				# DataParallel wraps the net; save the underlying module
				torch.save(model.module.state_dict(), ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
				model.module._parse_genotype(file = ckpt_save_path + "genotype_{}.json".format(epoch + 1))
			else:
				torch.save(model.state_dict(), ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
				model._parse_genotype(file = ckpt_save_path + "genotype_{}.json".format(epoch + 1))

			logger.info("checkpoint {} was saved".format(epoch + 1))

			if is_best:
				if parallel:
					torch.save(model.module.state_dict(), ckpt_save_path + "best_ckpt.pth")
					model.module._parse_genotype(file = ckpt_save_path + "best_genotype.json")
				else:
					torch.save(model.state_dict(), ckpt_save_path + "best_ckpt.pth")
					model._parse_genotype(file = ckpt_save_path + "best_genotype.json")

				logger.info("best_checkpoint was saved")
				is_best = False

	logger.info("training is end")
Example #6
0
File: train.py  Project: solicucu/ReID
def train():
	"""Train CDNetwork end to end using the module-level ``cfg``.

	Builds dataloaders/model/optimizer/scheduler/loss, logs parameter
	count and FLOPs, runs the epoch loop with running loss/acc meters,
	evaluates mAP/CMC every ``eval_period`` epochs (and at the check
	epochs), checkpoints on schedule and on new bests, and appends a
	timing/accuracy summary row to ``*_result.xlsx``.
	"""
	# 1、make dataloader
	train_loader, val_loader, num_query, num_class = make_data_loader(cfg)

	# 2、make model
	model = build_model(cfg, num_class)

	# 3、 make optimizer
	optimizer = make_optimizer(cfg, model)

	# 4、 make lr_scheduler
	scheduler = make_lr_scheduler(cfg, optimizer)

	# 5、make loss 
	loss_fn = make_loss(cfg, num_class)

	# get parameters 
	device = cfg.MODEL.DEVICE
	use_gpu = device == "cuda"
	pretrained = cfg.MODEL.PRETRAIN_PATH != ""
	parallel = cfg.MODEL.PARALLEL

	log_period = cfg.OUTPUT.LOG_PERIOD
	ckpt_period = cfg.OUTPUT.CKPT_PERIOD
	eval_period = cfg.OUTPUT.EVAL_PERIOD
	output_dir = cfg.OUTPUT.DIRS
	ckpt_save_path = output_dir + cfg.OUTPUT.CKPT_DIRS

	epochs = cfg.SOLVER.MAX_EPOCHS
	batch_size = cfg.SOLVER.BATCH_SIZE
	grad_clip = cfg.SOLVER.GRAD_CLIP

	batch_num = len(train_loader)
	log_iters = batch_num // log_period 

	if not os.path.exists(ckpt_save_path):
		os.makedirs(ckpt_save_path)

	# create *_result.xlsx to save the result for later analysis
	name = (cfg.OUTPUT.LOG_NAME).split(".")[0] + ".xlsx"
	result_path = cfg.OUTPUT.DIRS + name

	wb = xl.Workbook()
	sheet = wb.worksheets[0]
	# three repeated (acc, mAP, r1, r5, r10, loss) groups: one per check epoch
	titles = ['size/M','speed/ms','final_planes', 'acc', 'mAP', 'r1', 'r5', 'r10', 'loss',
			  'acc', 'mAP', 'r1', 'r5', 'r10', 'loss','acc', 'mAP', 'r1', 'r5', 'r10', 'loss']
	sheet.append(titles)
	check_epochs = [40, 80, 120, 160, 200, 240, 280, 320, 360, epochs]
	values = []

	logger = logging.getLogger("CDNet.train")
	size = count_parameters(model)
	values.append(format(size, '.2f'))
	values.append(model.final_planes)

	logger.info("the param number of the model is {:.2f} M".format(size))
	infer_size = infer_count_parameters(model)
	logger.info("the infer param number of the model is {:.2f}M".format(infer_size))

	# FLOPs are computed for a single image: [1, 3, H, W]
	shape = [1, 3]
	shape.extend(cfg.DATA.IMAGE_SIZE)

	flops, _ = get_model_infos(model, shape)
	logger.info("the total flops number of the model is {:.2f} M".format(flops))

	logger.info("Starting Training CDNetwork")

	best_mAP, best_r1 = 0., 0.
	is_best = False
	avg_loss, avg_acc = RunningAverageMeter(),RunningAverageMeter()
	avg_time, global_avg_time = AverageMeter(), AverageMeter()

	if parallel:
		model = nn.DataParallel(model)

	if use_gpu:
		model = model.to(device)

	for epoch in range(epochs):

		scheduler.step()
		lr = scheduler.get_lr()[0]
		# if save epoch_num k, then run k+1 epoch next
		if pretrained and epoch < model.start_epoch:
			continue

		# reset the records
		model.train()
		avg_loss.reset()
		avg_acc.reset()
		avg_time.reset()

		for i, batch in enumerate(train_loader):

			t0 = time.time()
			imgs, labels = batch 

			if use_gpu:
				imgs = imgs.to(device)
				labels = labels.to(device)

			res = model(imgs)

			loss, acc = compute_loss_acc(res, labels, loss_fn)
			loss.backward()

			if grad_clip != 0:
				nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

			optimizer.step()
			optimizer.zero_grad()

			t1 = time.time()
			# per-image wall time for this batch
			avg_time.update((t1 - t0) / batch_size)
			avg_loss.update(loss)
			avg_acc.update(acc)

			# log info
			if (i+1) % log_iters == 0:
				logger.info("epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".format(
					epoch+1, i+1, batch_num, avg_loss.avg, avg_acc.avg))

		logger.info("end epochs {}/{} with lr: {:.5f} and avg_time is: {:.3f} ms".format(epoch+1, epochs, lr, avg_time.avg * 1000))
		global_avg_time.update(avg_time.avg)

		# test the model
		if (epoch + 1) % eval_period == 0 or (epoch + 1) in check_epochs:

			model.eval()
			metrics = R1_mAP(num_query, use_gpu = use_gpu)

			with torch.no_grad():
				for vi, batch in enumerate(val_loader):

					imgs, labels, camids = batch
					if use_gpu:
						imgs = imgs.to(device)

					feats = model(imgs)
					metrics.update((feats, labels, camids))

				# compute cmc and mAP
				cmc, mAP = metrics.compute()
				logger.info("validation results at epoch {}".format(epoch + 1))
				# fix: "{:2%}" is a width spec; ".2%" gives 2-decimal percent
				logger.info("mAP:{:.2%}".format(mAP))
				for r in [1,5,10]:
					logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(r, cmc[r-1]))

				# determine whether current model is the best
				if mAP > best_mAP:
					is_best = True
					best_mAP = mAP
					logger.info("Get a new best mAP")
				if cmc[0] > best_r1:
					is_best = True
					best_r1 = cmc[0]
					logger.info("Get a new best r1")

				# add the result to sheet
				if (epoch + 1) in check_epochs:
					val = [avg_acc.avg, mAP, cmc[0], cmc[4], cmc[9]]
					change = [format(v * 100, '.2f') for v in val]
					change.append(format(avg_loss.avg, '.3f'))
					values.extend(change)

		# whether to save the model
		if (epoch + 1) % ckpt_period == 0 or is_best:
			torch.save(model.state_dict(), ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
			logger.info("checkpoint {} was saved".format(epoch + 1))

			if is_best:
				torch.save(model.state_dict(), ckpt_save_path + "best_ckpt.pth")
				logger.info("best_checkpoint was saved")
				is_best = False

	values.insert(1, format(global_avg_time.avg * 1000, '.2f'))
	values.append(format(infer_size, '.2f'))
	sheet.append(values)
	wb.save(result_path)
	logger.info("best_mAP:{:.2%}, best_r1:{:.2%}".format(best_mAP, best_r1))
	logger.info("Ending training CDNetwork")