Example #1
def train_an_epoch(config, base, loaders):

    base.set_train()
    meter = MultiItemAverageMeter()

    ### we assume 200 iterations as an epoch
    for _ in range(200):

        ### load a batch of data
        imgs, pids, _ = loaders.train_iter.next_one()
        imgs, pids = imgs.to(base.device), pids.to(base.device)

        ### forward
        logits_list, embeddings_list = base.model(imgs)

        ### loss
        ide_loss, avg_logits = base.compute_ide_loss(logits_list, pids)
        source_acc = accuracy(avg_logits, pids, [1])[0]

        ### optimize
        base.optimizer.zero_grad()
        ide_loss.backward()
        base.optimizer.step()

        ### record
        meter.update({'ide_loss': ide_loss, 'acc': source_acc})

    return meter.get_val(), meter.get_str()
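
Every example in this listing accumulates its per-iteration scalars in a MultiItemAverageMeter and returns (meter.get_val(), meter.get_str()). The class itself is not shown anywhere on this page; the following is a minimal sketch, assuming it only needs to keep running averages keyed by name — the method names come from the calls in the examples, everything else is an assumption.

# Minimal sketch of MultiItemAverageMeter (an assumption, not the
# original class from these projects): running averages keyed by name.
import torch


class MultiItemAverageMeter:

    def __init__(self):
        self.sums = {}
        self.counts = {}

    def update(self, values):
        # values: dict mapping name -> Python number or scalar tensor
        for key, value in values.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            self.sums[key] = self.sums.get(key, 0.0) + float(value)
            self.counts[key] = self.counts.get(key, 0) + 1

    def get_val(self):
        # dict of running averages over all update() calls so far
        return {key: self.sums[key] / self.counts[key] for key in self.sums}

    def get_str(self):
        # human-readable one-line summary, e.g. "ide_loss: 1.2340; acc: 0.9000; "
        return ''.join('{}: {:.4f}; '.format(key, value)
                       for key, value in self.get_val().items())
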
Example #2
def train_stage1(config, base, loaders):

    base.set_train()
    source_loader = loaders.source_loader
    meter = MultiItemAverageMeter()
    for i in range(1032):
        source_imgs, source_pids, source_cids = source_loader.next_one()
        source_imgs, source_pids, source_cids = source_imgs.to(base.device), source_pids.to(base.device), \
                                                source_cids.to(base.device)

        source_features = base.feature_extractor(source_imgs)
        pid_cls_score = base.identity_classifier(source_features)
        pid_domain_cls_score = base.identitydomain_classifier(source_features)
        pid_loss = base.source_pid_creiteron(pid_cls_score, source_pids)
        pid_domain_loss = base.source_identity_domain_creiteron(
            pid_domain_cls_score, source_pids)

        loss = pid_loss + pid_domain_loss

        base.feature_extractor_optimizer.zero_grad()
        base.identity_classifier_optimizer.zero_grad()
        base.identitydomain_classifier_optimizer.zero_grad()

        loss.backward()

        base.feature_extractor_optimizer.step()
        base.identity_classifier_optimizer.step()
        base.identitydomain_classifier_optimizer.step()

        meter.update({
            'pid_loss': pid_loss.data,
            'pid_domain_loss': pid_domain_loss.data
        })

    return meter.get_val(), meter.get_str()
Example #3
def train_an_epoch(config, base, loaders):

    base.set_train()
    meter = MultiItemAverageMeter()

    ### we assume 200 iterations as an epoch
    for _ in range(200):

        ### load a batch of data
        imgs, pids, _ = loaders.train_iter.next_one()
        imgs, pids = imgs.to(base.device), pids.to(base.device)

        ### forward
        features, cls_score = base.model(imgs)

        ### loss
        ide_loss = base.ide_creiteron(cls_score, pids)
        triplet_loss = base.triplet_creiteron(features, features, features,
                                              pids, pids, pids)
        loss = ide_loss + triplet_loss
        acc = accuracy(cls_score, pids, [1])[0]

        ### optimize
        base.optimizer.zero_grad()
        loss.backward()
        base.optimizer.step()

        ### record
        meter.update({
            'ide_loss': ide_loss.data,
            'triplet_loss': triplet_loss.data,
            'acc': acc
        })

    return meter.get_val(), meter.get_str()
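
Several examples report accuracy(cls_score, pids, [1])[0]. That helper is not part of this listing; below is a sketch of the usual top-k classification accuracy, which is what the call pattern suggests (the projects' own helper may differ in scaling or return type).

import torch


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy sketch (assumed, not the projects' own helper).
    `output` is (batch, num_classes) logits, `target` is (batch,) class
    indices; returns one value per k, so accuracy(score, pids, [1])[0]
    is the top-1 accuracy of the batch."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()                                      # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
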
Example #4
def train_an_epoch(config, base, loaders, epoch=None):

	base.set_train()
	meter = MultiItemAverageMeter()

	### we treat config.steps iterations as an epoch
	base.lr_scheduler.step(epoch)
	for _ in range(config.steps):

		### load a batch of data
		imgs, pids, _ = loaders.train_iter.next_one()
		imgs, pids = imgs.to(base.device), pids.to(base.device)

		if 'res' in config.cnnbackbone:
			### forward
			features, cls_score = base.model(imgs)
			### loss
			ide_loss = base.ide_creiteron(cls_score, pids)
			triplet_loss = base.triplet_creiteron(features, features, features, pids, pids, pids)
			loss = ide_loss + triplet_loss
			acc = accuracy(cls_score, pids, [1])[0]
			### optimize
			base.optimizer.zero_grad()
			loss.backward()
			base.optimizer.step()
			### record
			meter.update({'ide_loss': ide_loss.data, 'triplet_loss': triplet_loss.data, 'acc': acc})
		elif config.cnnbackbone == 'osnetain':
			### forward
			if epoch < 10:
				cls_score = base.model(imgs, fixed_cnn=True)
			else:
				cls_score = base.model(imgs, fixed_cnn=False)
			### loss
			ide_loss = base.ide_creiteron(cls_score, pids)
			acc = accuracy(cls_score, pids, [1])[0]
			### optimize
			base.optimizer.zero_grad()
			ide_loss.backward()
			base.optimizer.step()
			### record
			meter.update({'ide_loss': ide_loss.data, 'acc': acc})

	return meter.get_val(), meter.get_str()
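
Batches are pulled with next_one() rather than a plain for-loop over a DataLoader, both via loaders.train_iter and via IterLoader(...) in Examples #5 and #8. A minimal sketch of such a wrapper, assuming it simply restarts the underlying loader whenever it is exhausted (the real class in these projects may behave differently, especially when it is re-created inside the loop):

class IterLoader:
    """Sketch of an endless-batch wrapper (an assumption, not the
    projects' original class): next_one() yields the next batch and
    silently restarts the underlying DataLoader when it runs out."""

    def __init__(self, loader):
        self.loader = loader
        self.iter = iter(loader)

    def next_one(self):
        try:
            return next(self.iter)
        except StopIteration:
            # the underlying loader finished one pass; start a new one
            self.iter = iter(self.loader)
            return next(self.iter)
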
Example #5
File: train.py Project: lhf12278/KG-FGIN
def train_multi_view(config, base, loaders):

    base.set_train()
    target_loader = loaders.target_loader
    meter = MultiItemAverageMeter()
    for i in range(100):
        target_imgs, target_pids, target_cids = IterLoader(
            target_loader).next_one()
        target_imgs, target_pids, target_cids = target_imgs.to(base.device), target_pids.to(base.device), \
                                                target_cids.to(base.device)
        target_features, target_bn_features, target_local_features, target_bn_local_features = \
            base.feature_extractor(target_imgs)
        target_graph_global_features = base.graph(target_local_features,
                                                  target_features)
        target_cls_score = base.classifier2(target_bn_features)
        target_local_cls_score = base.local_classifier(
            target_bn_local_features)
        target_graph_cls_score = base.classifier3(target_graph_global_features)
        target_ide_loss = base.target_ide_creiteron(target_cls_score,
                                                    target_pids)
        target_local_ide_loss = base.compute_local_pid_loss(
            target_local_cls_score, target_pids)
        target_graph_ide_loss = base.target_ide_creiteron(
            target_graph_cls_score, target_pids)

        total_loss = target_ide_loss + target_local_ide_loss + config.lambda1 * target_graph_ide_loss

        base.feature_extractor_optimizer.zero_grad()
        base.graph_optimizer.zero_grad()
        base.classifier2_optimizer.zero_grad()
        base.local_classifier_optimizer.zero_grad()
        base.classifier3_optimizer.zero_grad()
        total_loss.backward()
        base.feature_extractor_optimizer.step()
        base.graph_optimizer.step()
        base.classifier2_optimizer.step()
        base.local_classifier_optimizer.step()
        base.classifier3_optimizer.step()
        meter.update({
            'target_ide_loss': target_ide_loss.data,
            'target_local_ide_loss': target_local_ide_loss.data,
            'target_graph_ide_loss': target_graph_ide_loss.data
        })

    return meter.get_val(), meter.get_str()
Example #6
def train_an_epoch(config,
                   iter_n,
                   loaders,
                   base,
                   current_epoch,
                   train_gan,
                   train_reid,
                   self_training,
                   train_adaptation,
                   optimize_sl_enc,
                   trg_labeled_loader=None):

    # set train mode
    base.set_train()
    base.lr_scheduler_step(current_epoch)
    meter = MultiItemAverageMeter()

    # train loop
    for _ in range(5):

        # zero grad
        base.sl_enc_optimizer.zero_grad()
        base.gen_optimizer.zero_grad()
        base.dis_optimizer.zero_grad()
        base.il_enc_optimizer.zero_grad()

        #
        results = {}
        # gan
        if train_gan:
            gen_loss_without_feature, gen_loss_gan_feature, dis_loss, image_list = train_gan_an_iter(
                config, loaders, base)

            gen_loss_gan_feature.backward(retain_graph=True)
            base.sl_enc_optimizer.zero_grad()
            gen_loss_without_feature.backward()

            base.dis_optimizer.zero_grad()
            dis_loss.backward()

            results['gen_loss_gan_feature'] = gen_loss_gan_feature.item()
            results['gen_loss_without_feature'] = gen_loss_without_feature.item()
            results['dis_loss'] = dis_loss.item()

        if train_reid:
            cls_loss, triplet_loss, acc = train_reid_an_iter(
                config, loaders, base, True)
            reid_loss = cls_loss + triplet_loss
            reid_loss.backward()
            results['cls_loss'] = cls_loss.item()
            results['triplet_loss'] = triplet_loss.item()
            results['acc'] = acc

        if self_training:
            assert trg_labeled_loader is not None, "self training requires labeled target loader"
            cls_loss, triplet_loss, acc = train_reid_an_iter(
                config, loaders, base, False, trg_labeled_loader)
            reid_loss = cls_loss + triplet_loss
            reid_loss.backward()
            results['cls_loss'] = cls_loss.item()
            results['triplet_loss'] = triplet_loss.item()
            results['acc'] = acc

        if train_adaptation:
            gen_loss_real, gen_loss_fake, gen_loss_cycle, dis_loss_real, dis_loss_fake, dis_loss_cycle = \
                                         train_adaptation_an_iter(config,loaders,base)
            gen_loss = gen_loss_real + gen_loss_fake + gen_loss_cycle
            dis_loss = dis_loss_real + dis_loss_fake + dis_loss_cycle

            base.dom_gen_optimizer.zero_grad()
            gen_loss.backward()

            base.dom_dis_optimizer.zero_grad()
            dis_loss.backward()
            results['dom_gen_loss'] = gen_loss.item()
            results['dom_dis_loss'] = dis_loss.item()

        # optimize
        if optimize_sl_enc:
            base.sl_enc_optimizer.step()
        if train_gan:
            base.gen_optimizer.step()
            base.dis_optimizer.step()
        if train_reid or self_training:
            base.il_enc_optimizer.step()
        if train_adaptation:
            base.dom_gen_optimizer.step()
            base.dom_dis_optimizer.step()
        # record
        print("a new iter")
        meter.update(results)
    return meter.get_str()
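
Examples #6 and #9 call backward() on several losses that presumably share part of the same computation graph, which is why the first call passes retain_graph=True; without it, the saved tensors of the shared sub-graph are freed after the first backward pass and the second call raises a runtime error. A small standalone illustration of that rule (unrelated to the ReID models above):

import torch

x = torch.randn(4, requires_grad=True)
shared = (x * x).sum()              # shared sub-graph; saves tensors for backward
loss_a = shared * 3.0
loss_b = shared + 1.0

loss_a.backward(retain_graph=True)  # keep the saved tensors for the next pass
loss_b.backward()                   # reuses the shared sub-graph
print(x.grad)                       # gradients from both losses accumulate
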
Example #7
def train_stage2(config, base, loaders):

    base.set_train()
    source_loader = loaders.source_loader
    target_loader = loaders.target_loader
    meter = MultiItemAverageMeter()
    for i in range(808):
        source_imgs, source_pids, source_cids = source_loader.next_one()
        source_imgs, source_pids, source_cids = source_imgs.to(base.device), source_pids.to(base.device), \
                                                source_cids.to(base.device)
        target_imgs, target_pids, target_cids = target_loader.next_one()
        target_imgs, target_pids, target_cids = target_imgs.to(base.device), target_pids.to(base.device), \
                                                target_cids.to(base.device)

        source_features = base.feature_extractor(source_imgs)
        target_features = base.feature_extractor(target_imgs)
        source_pid_cls_score = base.identity_classifier(source_features)
        source_cid_cls_score = base.camera_classifier(source_features)
        multi_view_features = MultiViewReasoning().__call__(
            source_features, source_features, source_pids, source_pids,
            source_pid_cls_score)
        multi_view_cls_score = base.identity_classifier(multi_view_features)
        source_pid_domain_cls_score = base.identitydomain_classifier(
            source_features)
        target_cid_cls_score = base.camera_classifier(target_features)
        target_pid_domain_cls_score = base.identitydomain_classifier(
            target_features)
        source_pid_loss = base.source_pid_creiteron(source_pid_cls_score,
                                                    source_pids)
        multi_view_pid_loss = base.source_pid_creiteron(
            multi_view_cls_score, source_pids)
        source_pid_domain_loss = base.source_identity_domain_creiteron(
            source_pid_domain_cls_score, source_pids)
        target_pid_domain_loss = base.target_identity_domain_creiteron(
            target_pid_domain_cls_score)
        source_dispency_loss = base.dispency_creiteron(
            source_pid_cls_score, source_pid_domain_cls_score)
        cid_classifier_loss = base.camera_classifier_creiteron(
            source_cid_cls_score, target_cid_cls_score, source_cids,
            target_cids)

        loss = source_pid_loss + multi_view_pid_loss + source_pid_domain_loss + config.lambda1 * target_pid_domain_loss\
               - config.lambda2 * source_dispency_loss + config.lambda3 * cid_classifier_loss

        base.feature_extractor_optimizer.zero_grad()
        base.identity_classifier_optimizer.zero_grad()
        base.identitydomain_classifier_optimizer.zero_grad()
        base.camera_classifier_optimizer.zero_grad()

        loss.backward()

        base.identity_classifier_optimizer.step()
        base.identitydomain_classifier_optimizer.step()
        base.camera_classifier_optimizer.step()

        source_features = base.feature_extractor(source_imgs)
        target_features = base.feature_extractor(target_imgs)
        source_pid_cls_score = base.identity_classifier(source_features)
        source_cid_cls_score = base.camera_classifier(source_features)
        multi_view_features = MultiViewReasoning().__call__(
            source_features, source_features, source_pids, source_pids,
            source_pid_cls_score)
        multi_view_cls_score = base.identity_classifier(multi_view_features)
        source_pid_domain_cls_score = base.identitydomain_classifier(
            source_features)
        target_cid_cls_score = base.camera_classifier(target_features)
        target_pid_domain_cls_score = base.identitydomain_classifier(
            target_features)
        source_pid_loss = base.source_pid_creiteron(source_pid_cls_score,
                                                    source_pids)
        multi_view_pid_loss = base.source_pid_creiteron(
            multi_view_cls_score, source_pids)
        single_multi_view_feature_loss = base.single_multi_creiteron(
            source_features, multi_view_features)
        source_pid_domain_loss = base.source_identity_domain_creiteron(
            source_pid_domain_cls_score, source_pids)
        target_pid_domain_loss = base.target_identity_domain_creiteron(
            target_pid_domain_cls_score)
        source_dispency_loss = base.dispency_creiteron(
            source_pid_cls_score, source_pid_domain_cls_score)
        cid_feature_extractor_loss = base.camera_feature_extractor_creiteron(
            source_cid_cls_score, target_cid_cls_score)


        loss = source_pid_loss + multi_view_pid_loss + source_pid_domain_loss + \
               config.beta * single_multi_view_feature_loss + config.lambda1 * target_pid_domain_loss \
               + config.lambda2 * source_dispency_loss + config.lambda3 * cid_feature_extractor_loss
        base.feature_extractor_optimizer.zero_grad()
        base.identity_classifier_optimizer.zero_grad()
        base.identitydomain_classifier_optimizer.zero_grad()
        base.camera_classifier_optimizer.zero_grad()

        loss.backward()

        base.feature_extractor_optimizer.step()

        meter.update({
            'source_pid_loss': source_pid_loss.data,
            'multi_view_pid_loss': multi_view_pid_loss.data,
            'single_multi_view_feature_loss': single_multi_view_feature_loss.data,
            'source_pid_domain_loss': source_pid_domain_loss.data,
            'target_pid_domain_loss': target_pid_domain_loss.data,
            'source_dispency_loss': source_dispency_loss.data,
            'cid_classifier_loss': cid_classifier_loss.data,
            'cid_feature_extractor_loss': cid_feature_extractor_loss.data
        })

    return meter.get_val(), meter.get_str()
Example #8
File: train.py Project: lhf12278/KG-FGIN
def train_meta_learning(base, loaders):

    base.set_train()
    source_loader = loaders.source_loader
    target_loader = loaders.target_loader
    meter = MultiItemAverageMeter()
    for i in range(400):
        source_imgs, source_pids, source_cids = IterLoader(
            source_loader).next_one()
        source_imgs, source_pids, source_cids = source_imgs.to(base.device), source_pids.to(base.device), \
                                                source_cids.to(base.device)
        source_features, source_bn_features, source_local_features, source_bn_local_features = base.feature_extractor(
            source_imgs)
        source_cls_score = base.classifier1(source_bn_features)
        source_ide_loss = base.source_ide_creiteron(source_cls_score,
                                                    source_pids)
        grad = torch.autograd.grad(
            source_ide_loss,
            nn.ModuleList([base.feature_extractor,
                           base.classifier1]).parameters(),
            retain_graph=True)
        fast_weights = list(
            map(
                lambda p: p[1] - 0.0005 * p[0],
                zip(
                    grad,
                    nn.ModuleList([base.feature_extractor,
                                   base.classifier1]).parameters())))
        feature_extractor_param_keys = []
        for k, v in base.feature_extractor.state_dict().items():
            if k.split('.')[-1] == 'weight' or k.split('.')[-1] == 'bias':
                feature_extractor_param_keys.append(k)
        feature_extractor_param_dict = OrderedDict()
        for i in range(len(fast_weights) - 1):
            feature_extractor_param_dict[
                feature_extractor_param_keys[i]] = fast_weights[i]
        base.feature_extractor.state_dict().update(
            feature_extractor_param_dict)
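        # note: .update() here only changes the dict returned by state_dict();
        # the module's own parameters are left unchanged (see the note after
        # this example).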

        target_imgs, target_pids, target_cids = IterLoader(
            target_loader).next_one()
        target_imgs, target_pids, target_cids = target_imgs.to(base.device), target_pids.to(base.device), \
                                                target_cids.to(base.device)
        target_features, target_bn_features, _, _ = base.feature_extractor(
            target_imgs)
        target_cls_score = base.classifier2(target_bn_features)
        target_ide_loss = base.target_ide_creiteron(target_cls_score,
                                                    target_pids)
        total_loss = source_ide_loss + target_ide_loss
        base.feature_extractor_optimizer.zero_grad()
        base.classifier1_optimizer.zero_grad()
        base.classifier2_optimizer.zero_grad()
        total_loss.backward()
        base.feature_extractor_optimizer.step()
        base.classifier1_optimizer.step()
        base.classifier2_optimizer.step()
        meter.update({
            'source_ide_loss': source_ide_loss.data,
            'target_ide_loss': target_ide_loss.data
        })

    return meter.get_val(), meter.get_str()
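
A caveat on Example #8: in PyTorch, state_dict() returns a new mapping, so calling .update() on it (as above) does not write the fast weights back into feature_extractor. If the intent is to actually apply the MAML-style inner update before the target forward pass, one option is load_state_dict, as in this hypothetical helper (apply_fast_weights is not part of the project; note that copying via load_state_dict also detaches the inner step from the outer gradient):

import torch


def apply_fast_weights(module, param_keys, fast_weights):
    """Hypothetical helper: copy fast weights (aligned with param_keys,
    as built in Example #8) into `module` via load_state_dict, because
    mutating the dict returned by state_dict() has no effect on the
    module itself."""
    state = module.state_dict()
    for key, weight in zip(param_keys, fast_weights):
        state[key] = weight.detach()
    module.load_state_dict(state)
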
Example #9
def train_an_epoch(config, loaders, base, current_epoch, train_gan, train_reid,
                   train_pixel, optimize_sl_enc):

    # set train mode
    base.set_train()
    base.lr_scheduler_step(current_epoch)
    meter = MultiItemAverageMeter()

    # train loop
    for _ in range(200):

        # zero grad
        base.sl_enc_optimizer.zero_grad()
        base.gen_optimizer.zero_grad()
        base.dis_optimizer.zero_grad()
        base.il_enc_optimizer.zero_grad()

        #
        results = {}

        # gan
        if train_gan:
            gen_loss_without_feature, gen_loss_gan_feature, dis_loss, image_list = train_gan_an_iter(
                config, loaders, base)

            gen_loss_gan_feature.backward(retain_graph=True)
            base.sl_enc_optimizer.zero_grad()
            gen_loss_without_feature.backward()

            base.dis_optimizer.zero_grad()
            dis_loss.backward()

            results['gen_loss_gan_feature'] = gen_loss_gan_feature.data
            results['gen_loss_without_feature'] = gen_loss_without_feature.data
            results['dis_loss'] = dis_loss.data

        # pixel
        if train_pixel:
            assert train_gan
            pixel_loss = train_pixel_an_iter(config, loaders, base, image_list)
            (config.weight_pixel_loss * pixel_loss).backward()
            results['pixel_loss'] = pixel_loss.data

        # reid
        if train_reid:
            cls_loss, triplet_loss, acc = train_reid_an_iter(
                config, loaders, base)
            reid_loss = cls_loss + triplet_loss
            reid_loss.backward()
            results['cls_loss'] = cls_loss.data
            results['triplet_loss'] = triplet_loss.data
            results['acc'] = acc

        # optimize
        if optimize_sl_enc:
            base.sl_enc_optimizer.step()
        if train_gan:
            base.gen_optimizer.step()
            base.dis_optimizer.step()
        if train_reid:
            base.il_enc_optimizer.step()

        # record
        meter.update(results)

    return meter.get_str()
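
All of these training functions share the same contract: switch the networks to train mode, run a fixed number of iterations, and hand the averaged metrics back to the caller. A hypothetical driver loop (not taken from any of the projects above), assuming the (config, base, loaders) signature of Example #1 and a config field named total_train_epochs:

# Hypothetical outer loop; `config`, `base` and `loaders` are assumed to be
# constructed elsewhere by the project, as in the examples above.
for epoch in range(config.total_train_epochs):
    results, results_str = train_an_epoch(config, base, loaders)
    print('Epoch {:3d}: {}'.format(epoch, results_str))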