def train_an_epoch(config, base, loaders):
    """Train ``base.model`` for one fixed-length "epoch" with the ID loss only.

    Each of the 200 iterations pulls one batch from ``loaders.train_iter``,
    computes the identity loss via ``base.compute_ide_loss`` and takes one
    ``base.optimizer`` step.

    Args:
        config: unused here; kept for a signature shared with other trainers.
        base: object exposing ``set_train``, ``device``, ``model``,
            ``compute_ide_loss`` and ``optimizer``.
        loaders: object exposing ``train_iter.next_one()``, which returns an
            ``(imgs, pids, <ignored>)`` triple.

    Returns:
        ``(meter.get_val(), meter.get_str())`` for the averaged ``ide_loss``
        and top-1 ``acc``.
    """
    base.set_train()
    meter = MultiItemAverageMeter()

    # An "epoch" is defined as 200 iterations, not a full pass over the data.
    for _ in range(200):
        imgs, pids, _ = loaders.train_iter.next_one()
        imgs, pids = imgs.to(base.device), pids.to(base.device)

        # The model returns (logits_list, embeddings_list); only logits are used.
        logits_list, _ = base.model(imgs)
        ide_loss, avg_logits = base.compute_ide_loss(logits_list, pids)
        source_acc = accuracy(avg_logits, pids, [1])[0]

        base.optimizer.zero_grad()
        ide_loss.backward()
        base.optimizer.step()

        # Log a detached value (like the other trainers) so the meter never
        # holds on to this iteration's autograd graph.
        meter.update({'ide_loss': ide_loss.data, 'acc': source_acc})

    return meter.get_val(), meter.get_str()
def train_stage1(config, base, loaders):
    """Stage-1 supervised training on the source domain.

    Runs 1032 iterations. Each iteration sums the identity-classifier loss and
    the identity-domain-classifier loss on a source batch, then steps the
    feature extractor and both classifiers.

    Args:
        config: unused here; kept for a signature shared with other stages.
        base: object holding the networks, their optimizers and the criteria.
        loaders: object whose ``source_loader.next_one()`` yields
            ``(imgs, pids, cids)``.

    Returns:
        ``(meter.get_val(), meter.get_str())`` for ``pid_loss`` and
        ``pid_domain_loss``.
    """
    base.set_train()
    loader = loaders.source_loader
    meter = MultiItemAverageMeter()
    # Zeroed and stepped together, in this order.
    optimizers = (
        base.feature_extractor_optimizer,
        base.identity_classifier_optimizer,
        base.identitydomain_classifier_optimizer,
    )

    for _ in range(1032):
        imgs, pids, cids = (t.to(base.device) for t in loader.next_one())

        feats = base.feature_extractor(imgs)
        id_scores = base.identity_classifier(feats)
        id_domain_scores = base.identitydomain_classifier(feats)

        pid_loss = base.source_pid_creiteron(id_scores, pids)
        pid_domain_loss = base.source_identity_domain_creiteron(
            id_domain_scores, pids)

        for opt in optimizers:
            opt.zero_grad()
        (pid_loss + pid_domain_loss).backward()
        for opt in optimizers:
            opt.step()

        meter.update({
            'pid_loss': pid_loss.data,
            'pid_domain_loss': pid_domain_loss.data,
        })

    return meter.get_val(), meter.get_str()
def train_an_epoch(config, base, loaders):
    """Train ``base.model`` for 200 iterations with ID + triplet losses.

    The model is expected to return ``(features, cls_score)``. The triplet
    criterion is fed the same batch as anchor, positive and negative sets.

    Args:
        config: unused here; kept for a signature shared with other trainers.
        base: object exposing ``set_train``, ``device``, ``model``,
            ``ide_creiteron``, ``triplet_creiteron`` and ``optimizer``.
        loaders: object exposing ``train_iter.next_one()``.

    Returns:
        ``(meter.get_val(), meter.get_str())`` for ``ide_loss``,
        ``triplet_loss`` and top-1 ``acc``.
    """
    base.set_train()
    meter = MultiItemAverageMeter()

    # A fixed 200 iterations count as one epoch.
    for _step in range(200):
        batch_imgs, batch_pids, _ = loaders.train_iter.next_one()
        batch_imgs = batch_imgs.to(base.device)
        batch_pids = batch_pids.to(base.device)

        feats, logits = base.model(batch_imgs)

        id_loss = base.ide_creiteron(logits, batch_pids)
        tri_loss = base.triplet_creiteron(
            feats, feats, feats, batch_pids, batch_pids, batch_pids)
        total = id_loss + tri_loss
        top1 = accuracy(logits, batch_pids, [1])[0]

        base.optimizer.zero_grad()
        total.backward()
        base.optimizer.step()

        meter.update(dict(
            ide_loss=id_loss.data,
            triplet_loss=tri_loss.data,
            acc=top1,
        ))

    return meter.get_val(), meter.get_str()
def train_multi_view(config, base, loaders):
    """Train on the target domain with global, local and graph ID losses.

    Runs 100 iterations. Each iteration computes
    ``total = global_ide + local_ide + config.lambda1 * graph_ide`` and steps
    the feature extractor, the graph module and the three classifiers.

    Args:
        config: must provide ``lambda1`` (weight of the graph ID loss).
        base: object holding the networks, their optimizers and the criteria.
        loaders: object whose ``target_loader`` is wrapped in ``IterLoader``.

    Returns:
        ``(meter.get_val(), meter.get_str())`` for the three ID losses.
    """
    base.set_train()
    # Wrap the loader once. Re-wrapping it inside the loop (as before)
    # restarted the underlying iterator on every step.
    target_iter = IterLoader(loaders.target_loader)
    meter = MultiItemAverageMeter()

    optimizers = (
        base.feature_extractor_optimizer,
        base.graph_optimizer,
        base.classifier2_optimizer,
        base.local_classifier_optimizer,
        base.classifier3_optimizer,
    )

    for _ in range(100):
        # Camera ids are not used by this stage.
        target_imgs, target_pids, _ = target_iter.next_one()
        target_imgs = target_imgs.to(base.device)
        target_pids = target_pids.to(base.device)

        (target_features, target_bn_features, target_local_features,
         target_bn_local_features) = base.feature_extractor(target_imgs)
        target_graph_global_features = base.graph(target_local_features,
                                                  target_features)

        target_cls_score = base.classifier2(target_bn_features)
        target_local_cls_score = base.local_classifier(
            target_bn_local_features)
        target_graph_cls_score = base.classifier3(target_graph_global_features)

        target_ide_loss = base.target_ide_creiteron(target_cls_score,
                                                    target_pids)
        target_local_ide_loss = base.compute_local_pid_loss(
            target_local_cls_score, target_pids)
        target_graph_ide_loss = base.target_ide_creiteron(
            target_graph_cls_score, target_pids)
        total_loss = (target_ide_loss + target_local_ide_loss
                      + config.lambda1 * target_graph_ide_loss)

        for opt in optimizers:
            opt.zero_grad()
        total_loss.backward()
        for opt in optimizers:
            opt.step()

        meter.update({
            'target_ide_loss': target_ide_loss.data,
            'target_local_ide_loss': target_local_ide_loss.data,
            'target_graph_ide_loss': target_graph_ide_loss.data
        })

    return meter.get_val(), meter.get_str()
def train_an_epoch(config, base, loaders, epoch=None):
    """Train ``base.model`` for ``config.steps`` iterations.

    Two backbone families are handled:

    * ``config.cnnbackbone`` containing ``'res'``: the model returns
      ``(features, cls_score)`` and is trained with ID + triplet losses.
    * ``config.cnnbackbone == 'osnetain'``: the model returns ``cls_score``
      and is trained with the ID loss only. It is called with
      ``fixed_cnn=True`` while ``epoch < 10`` and ``fixed_cnn=False`` after.

    Args:
        config: must provide ``steps`` and ``cnnbackbone``.
        base: object exposing ``set_train``, ``device``, ``model``,
            ``lr_scheduler``, ``ide_creiteron``, ``triplet_creiteron`` and
            ``optimizer``.
        loaders: object exposing ``train_iter.next_one()``.
        epoch: forwarded to ``base.lr_scheduler.step``. Required for the
            ``'osnetain'`` backbone, which needs it to decide ``fixed_cnn``.

    Returns:
        ``(meter.get_val(), meter.get_str())``.

    Raises:
        ValueError: if the backbone is ``'osnetain'`` and ``epoch`` is None
            (previously this surfaced as a ``TypeError`` from ``None < 10``
            inside the loop).
    """
    # The backbone and the fixed-CNN decision do not change within an epoch.
    is_resnet = 'res' in config.cnnbackbone
    is_osnet = not is_resnet and config.cnnbackbone == 'osnetain'
    if is_osnet and epoch is None:
        raise ValueError("epoch is required for the 'osnetain' backbone")
    fixed_cnn = is_osnet and epoch < 10

    base.set_train()
    meter = MultiItemAverageMeter()
    base.lr_scheduler.step(epoch)

    # NOTE(review): any other backbone falls through both branches, so batches
    # are consumed but nothing is trained. Confirm this is intended.
    for _ in range(config.steps):
        imgs, pids, _ = loaders.train_iter.next_one()
        imgs, pids = imgs.to(base.device), pids.to(base.device)

        if is_resnet:
            features, cls_score = base.model(imgs)
            ide_loss = base.ide_creiteron(cls_score, pids)
            triplet_loss = base.triplet_creiteron(features, features, features,
                                                  pids, pids, pids)
            loss = ide_loss + triplet_loss
            acc = accuracy(cls_score, pids, [1])[0]

            base.optimizer.zero_grad()
            loss.backward()
            base.optimizer.step()

            meter.update({'ide_loss': ide_loss.data,
                          'triplet_loss': triplet_loss.data,
                          'acc': acc})
        elif is_osnet:
            cls_score = base.model(imgs, fixed_cnn=fixed_cnn)
            ide_loss = base.ide_creiteron(cls_score, pids)
            acc = accuracy(cls_score, pids, [1])[0]

            base.optimizer.zero_grad()
            ide_loss.backward()
            base.optimizer.step()

            meter.update({'ide_loss': ide_loss.data, 'acc': acc})

    return meter.get_val(), meter.get_str()
def train_an_epoch(config, iter_n, loaders, base, current_epoch, train_gan,
                   train_reid, self_training, train_adaptation,
                   optimize_sl_enc, trg_labeled_loader=None):
    """Run ``iter_n`` joint training iterations over the enabled sub-tasks.

    Every iteration zeroes the shared optimizers. It then accumulates the
    gradients of each enabled sub-task (GAN, source re-id, target
    self-training, domain adaptation) and finally steps the optimizers that
    belong to the enabled sub-tasks.

    Args:
        config: forwarded to the per-task ``*_an_iter`` helpers.
        iter_n: number of iterations to run. Previously this was ignored in
            favour of a hard-coded 5.
        loaders: forwarded to the per-task helpers.
        base: object holding the networks, optimizers and
            ``lr_scheduler_step``.
        current_epoch: forwarded to ``base.lr_scheduler_step``.
        train_gan, train_reid, self_training, train_adaptation: task switches.
        optimize_sl_enc: whether to step ``base.sl_enc_optimizer``.
        trg_labeled_loader: labeled target loader. Required when
            ``self_training`` is set.

    Returns:
        ``meter.get_str()`` summarising the averaged losses.

    Raises:
        ValueError: if ``self_training`` is set without ``trg_labeled_loader``.
    """
    # Validate once instead of with an ``assert`` inside the loop, which
    # would be stripped under ``python -O``.
    if self_training and trg_labeled_loader is None:
        raise ValueError("self training requires labeled target loader")

    base.set_train()
    base.lr_scheduler_step(current_epoch)
    meter = MultiItemAverageMeter()

    for _ in range(iter_n):
        base.sl_enc_optimizer.zero_grad()
        base.gen_optimizer.zero_grad()
        base.dis_optimizer.zero_grad()
        base.il_enc_optimizer.zero_grad()
        results = {}

        if train_gan:
            gen_loss_without_feature, gen_loss_gan_feature, dis_loss, _ = \
                train_gan_an_iter(config, loaders, base)
            # The ordering below is deliberate. The feature-level GAN loss
            # reaches sl_enc, but its sl_enc grads are then cleared. The
            # discriminator grads are cleared before dis_loss.backward().
            gen_loss_gan_feature.backward(retain_graph=True)
            base.sl_enc_optimizer.zero_grad()
            gen_loss_without_feature.backward()
            base.dis_optimizer.zero_grad()
            dis_loss.backward()
            results['gen_loss_gan_feature'] = gen_loss_gan_feature.item()
            results['gen_loss_without_feature'] = \
                gen_loss_without_feature.item()
            results['dis_loss'] = dis_loss.item()

        if train_reid:
            cls_loss, triplet_loss, acc = train_reid_an_iter(
                config, loaders, base, True)
            reid_loss = cls_loss + triplet_loss
            reid_loss.backward()
            results['cls_loss'] = cls_loss.item()
            results['triplet_loss'] = triplet_loss.item()
            results['acc'] = acc

        if self_training:
            cls_loss, triplet_loss, acc = train_reid_an_iter(
                config, loaders, base, False, trg_labeled_loader)
            reid_loss = cls_loss + triplet_loss
            reid_loss.backward()
            # NOTE(review): these keys overwrite the train_reid entries when
            # both flags are set, so only the target values are logged.
            results['cls_loss'] = cls_loss.item()
            results['triplet_loss'] = triplet_loss.item()
            results['acc'] = acc

        if train_adaptation:
            (gen_loss_real, gen_loss_fake, gen_loss_cycle, dis_loss_real,
             dis_loss_fake, dis_loss_cycle) = train_adaptation_an_iter(
                 config, loaders, base)
            gen_loss = gen_loss_real + gen_loss_fake + gen_loss_cycle
            dis_loss = dis_loss_real + dis_loss_fake + dis_loss_cycle
            base.dom_gen_optimizer.zero_grad()
            gen_loss.backward()
            base.dom_dis_optimizer.zero_grad()
            dis_loss.backward()
            results['dom_gen_loss'] = gen_loss.item()
            results['dom_dis_loss'] = dis_loss.item()

        # Step only the optimizers of the enabled sub-tasks.
        if optimize_sl_enc:
            base.sl_enc_optimizer.step()
        if train_gan:
            base.gen_optimizer.step()
            base.dis_optimizer.step()
        if train_reid or self_training:
            base.il_enc_optimizer.step()
        if train_adaptation:
            base.dom_gen_optimizer.step()
            base.dom_dis_optimizer.step()

        meter.update(results)

    return meter.get_str()
def _stage2_to_device(batch, device):
    """Move an ``(imgs, pids, cids)`` batch to ``device``, in that order."""
    imgs, pids, cids = batch
    return imgs.to(device), pids.to(device), cids.to(device)


def _stage2_forward(base, source_imgs, target_imgs, source_pids):
    """Run the stage-2 forward pass shared by both optimisation phases.

    Module calls are issued in the same order as the original inline code, so
    any order-dependent module state (for example BatchNorm statistics or
    dropout RNG, if the modules use them) evolves identically.

    Returns:
        dict mapping output names to tensors.
    """
    source_features = base.feature_extractor(source_imgs)
    target_features = base.feature_extractor(target_imgs)
    source_pid_cls_score = base.identity_classifier(source_features)
    source_cid_cls_score = base.camera_classifier(source_features)
    multi_view_features = MultiViewReasoning()(
        source_features, source_features, source_pids, source_pids,
        source_pid_cls_score)
    multi_view_cls_score = base.identity_classifier(multi_view_features)
    source_pid_domain_cls_score = base.identitydomain_classifier(
        source_features)
    target_cid_cls_score = base.camera_classifier(target_features)
    target_pid_domain_cls_score = base.identitydomain_classifier(
        target_features)
    return {
        'source_features': source_features,
        'source_pid_cls_score': source_pid_cls_score,
        'source_cid_cls_score': source_cid_cls_score,
        'multi_view_features': multi_view_features,
        'multi_view_cls_score': multi_view_cls_score,
        'source_pid_domain_cls_score': source_pid_domain_cls_score,
        'target_cid_cls_score': target_cid_cls_score,
        'target_pid_domain_cls_score': target_pid_domain_cls_score,
    }


def train_stage2(config, base, loaders):
    """Stage-2 adversarial source/target training (808 iterations).

    Each iteration has two phases over the same source/target batches:

    1. Classifier phase: the loss *subtracts*
       ``config.lambda2 * source_dispency_loss`` and uses the camera-classifier
       loss. Only the identity, identity-domain and camera classifiers are
       stepped.
    2. Feature-extractor phase: the forward pass is recomputed with the
       updated classifiers. The loss *adds* the discrepancy term, the
       ``config.beta``-weighted single/multi-view feature loss and the camera
       feature-extractor loss. Only the feature extractor is stepped.

    Args:
        config: must provide ``lambda1``, ``lambda2``, ``lambda3`` and
            ``beta``.
        base: object holding the networks, their optimizers and the criteria.
        loaders: object with ``source_loader`` and ``target_loader``, each
            exposing ``next_one()``.

    Returns:
        ``(meter.get_val(), meter.get_str())``.
    """
    base.set_train()
    source_loader = loaders.source_loader
    target_loader = loaders.target_loader
    meter = MultiItemAverageMeter()

    for _ in range(808):
        source_imgs, source_pids, source_cids = _stage2_to_device(
            source_loader.next_one(), base.device)
        target_imgs, _target_pids, target_cids = _stage2_to_device(
            target_loader.next_one(), base.device)

        # ---- Phase 1: update the classifiers -------------------------------
        out = _stage2_forward(base, source_imgs, target_imgs, source_pids)
        source_pid_loss = base.source_pid_creiteron(
            out['source_pid_cls_score'], source_pids)
        multi_view_pid_loss = base.source_pid_creiteron(
            out['multi_view_cls_score'], source_pids)
        source_pid_domain_loss = base.source_identity_domain_creiteron(
            out['source_pid_domain_cls_score'], source_pids)
        target_pid_domain_loss = base.target_identity_domain_creiteron(
            out['target_pid_domain_cls_score'])
        source_dispency_loss = base.dispency_creiteron(
            out['source_pid_cls_score'], out['source_pid_domain_cls_score'])
        cid_classifier_loss = base.camera_classifier_creiteron(
            out['source_cid_cls_score'], out['target_cid_cls_score'],
            source_cids, target_cids)
        loss = (source_pid_loss + multi_view_pid_loss + source_pid_domain_loss
                + config.lambda1 * target_pid_domain_loss
                - config.lambda2 * source_dispency_loss
                + config.lambda3 * cid_classifier_loss)

        base.feature_extractor_optimizer.zero_grad()
        base.identity_classifier_optimizer.zero_grad()
        base.identitydomain_classifier_optimizer.zero_grad()
        base.camera_classifier_optimizer.zero_grad()
        loss.backward()
        # The feature extractor is deliberately not stepped in this phase.
        base.identity_classifier_optimizer.step()
        base.identitydomain_classifier_optimizer.step()
        base.camera_classifier_optimizer.step()

        # ---- Phase 2: update the feature extractor -------------------------
        out = _stage2_forward(base, source_imgs, target_imgs, source_pids)
        source_pid_loss = base.source_pid_creiteron(
            out['source_pid_cls_score'], source_pids)
        multi_view_pid_loss = base.source_pid_creiteron(
            out['multi_view_cls_score'], source_pids)
        single_multi_view_feature_loss = base.single_multi_creiteron(
            out['source_features'], out['multi_view_features'])
        source_pid_domain_loss = base.source_identity_domain_creiteron(
            out['source_pid_domain_cls_score'], source_pids)
        target_pid_domain_loss = base.target_identity_domain_creiteron(
            out['target_pid_domain_cls_score'])
        source_dispency_loss = base.dispency_creiteron(
            out['source_pid_cls_score'], out['source_pid_domain_cls_score'])
        cid_feature_extractor_loss = base.camera_feature_extractor_creiteron(
            out['source_cid_cls_score'], out['target_cid_cls_score'])
        loss = (source_pid_loss + multi_view_pid_loss + source_pid_domain_loss
                + config.beta * single_multi_view_feature_loss
                + config.lambda1 * target_pid_domain_loss
                + config.lambda2 * source_dispency_loss
                + config.lambda3 * cid_feature_extractor_loss)

        base.feature_extractor_optimizer.zero_grad()
        base.identity_classifier_optimizer.zero_grad()
        base.identitydomain_classifier_optimizer.zero_grad()
        base.camera_classifier_optimizer.zero_grad()
        loss.backward()
        base.feature_extractor_optimizer.step()

        meter.update({
            'source_pid_loss': source_pid_loss.data,
            'multi_view_pid_loss': multi_view_pid_loss.data,
            'single_multi_view_feature_loss':
                single_multi_view_feature_loss.data,
            'source_pid_domain_loss': source_pid_domain_loss.data,
            'target_pid_domain_loss': target_pid_domain_loss.data,
            'source_dispency_loss': source_dispency_loss.data,
            'cid_classifier_loss': cid_classifier_loss.data,
            # Detached like the other entries (previously logged with its graph).
            'cid_feature_extractor_loss': cid_feature_extractor_loss.data,
        })

    return meter.get_val(), meter.get_str()
def train_meta_learning(base, loaders):
    """Source/target meta-learning loop (400 iterations).

    Each iteration:

    1. Computes the source ID loss with ``base.classifier1``.
    2. Derives "fast weights" ``p - 0.0005 * dL/dp`` for the feature
       extractor and classifier1 parameters.
    3. Computes the target ID loss with ``base.classifier2``.
    4. Back-propagates ``source_ide_loss + target_ide_loss`` and steps the
       feature extractor and both classifiers.

    Args:
        base: object holding the networks, their optimizers and the criteria.
        loaders: object with ``source_loader`` and ``target_loader``, each
            wrapped in ``IterLoader``.

    Returns:
        ``(meter.get_val(), meter.get_str())`` for the source and target ID
        losses.
    """
    base.set_train()
    # Wrap each loader once. Re-wrapping inside the loop (as before)
    # restarted the underlying iterator on every step.
    source_iter = IterLoader(loaders.source_loader)
    target_iter = IterLoader(loaders.target_loader)
    meter = MultiItemAverageMeter()

    # Loop-invariant: the parameters being meta-differentiated and the
    # feature-extractor state_dict keys that name weights/biases.
    meta_params = list(
        nn.ModuleList([base.feature_extractor, base.classifier1]).parameters())
    feature_extractor_param_keys = [
        k for k in base.feature_extractor.state_dict()
        if k.split('.')[-1] in ('weight', 'bias')
    ]

    for _ in range(400):
        # Camera ids are not used by this loop.
        source_imgs, source_pids, _ = source_iter.next_one()
        source_imgs = source_imgs.to(base.device)
        source_pids = source_pids.to(base.device)

        _, source_bn_features, _, _ = base.feature_extractor(source_imgs)
        source_cls_score = base.classifier1(source_bn_features)
        source_ide_loss = base.source_ide_creiteron(source_cls_score,
                                                    source_pids)

        # Inner step. retain_graph keeps the graph for the outer backward.
        grads = torch.autograd.grad(source_ide_loss, meta_params,
                                    retain_graph=True)
        fast_weights = [p - 0.0005 * g for g, p in zip(grads, meta_params)]

        # The last fast weight is excluded; presumably it belongs to
        # classifier1 (a single bias-free layer?). TODO confirm.
        feature_extractor_param_dict = OrderedDict()
        for idx in range(len(fast_weights) - 1):
            feature_extractor_param_dict[
                feature_extractor_param_keys[idx]] = fast_weights[idx]
        # NOTE(review): this is a no-op. Module.state_dict() builds and
        # returns a new dict, so updating it never changes the module's
        # parameters, and the target pass below uses the ordinary weights.
        # Applying the fast weights while keeping the meta-gradient needs a
        # functional forward (e.g. torch.func.functional_call).
        base.feature_extractor.state_dict().update(
            feature_extractor_param_dict)

        target_imgs, target_pids, _ = target_iter.next_one()
        target_imgs = target_imgs.to(base.device)
        target_pids = target_pids.to(base.device)

        _, target_bn_features, _, _ = base.feature_extractor(target_imgs)
        target_cls_score = base.classifier2(target_bn_features)
        target_ide_loss = base.target_ide_creiteron(target_cls_score,
                                                    target_pids)

        total_loss = source_ide_loss + target_ide_loss
        base.feature_extractor_optimizer.zero_grad()
        base.classifier1_optimizer.zero_grad()
        base.classifier2_optimizer.zero_grad()
        total_loss.backward()
        base.feature_extractor_optimizer.step()
        base.classifier1_optimizer.step()
        base.classifier2_optimizer.step()

        meter.update({
            'source_ide_loss': source_ide_loss.data,
            'target_ide_loss': target_ide_loss.data
        })

    return meter.get_val(), meter.get_str()
def train_an_epoch(config, loaders, base, current_epoch, train_gan, train_reid, train_pixel, optimize_sl_enc):
    """Run 200 joint GAN / pixel / re-id training iterations.

    Each iteration zeroes the four shared optimizers and accumulates the
    gradients of every enabled sub-task. It then steps ``sl_enc`` (if
    ``optimize_sl_enc``), the generator and discriminator (if ``train_gan``)
    and ``il_enc`` (if ``train_reid``).

    Args:
        config: forwarded to the per-task helpers; ``weight_pixel_loss``
            scales the pixel loss.
        loaders: forwarded to the per-task helpers.
        base: object holding the networks, optimizers and
            ``lr_scheduler_step``.
        current_epoch: forwarded to ``base.lr_scheduler_step``.
        train_gan, train_reid, train_pixel: task switches. ``train_pixel``
            requires ``train_gan``, because it consumes the generated images.
        optimize_sl_enc: whether to step ``base.sl_enc_optimizer``.

    Returns:
        ``meter.get_str()`` summarising the averaged losses.
    """
    base.set_train()
    base.lr_scheduler_step(current_epoch)
    meter = MultiItemAverageMeter()

    for _ in range(200):
        for optimizer in (base.sl_enc_optimizer, base.gen_optimizer,
                          base.dis_optimizer, base.il_enc_optimizer):
            optimizer.zero_grad()
        results = {}

        if train_gan:
            (gen_wo_feat, gen_gan_feat, dis_loss,
             generated_images) = train_gan_an_iter(config, loaders, base)
            # Order matters: sl_enc grads from the feature-level GAN loss are
            # discarded, and dis grads are reset before dis_loss.backward().
            gen_gan_feat.backward(retain_graph=True)
            base.sl_enc_optimizer.zero_grad()
            gen_wo_feat.backward()
            base.dis_optimizer.zero_grad()
            dis_loss.backward()
            results.update(
                gen_loss_gan_feature=gen_gan_feat.data,
                gen_loss_without_feature=gen_wo_feat.data,
                dis_loss=dis_loss.data,
            )

        if train_pixel:
            assert train_gan
            pixel_loss = train_pixel_an_iter(config, loaders, base,
                                             generated_images)
            (config.weight_pixel_loss * pixel_loss).backward()
            results['pixel_loss'] = pixel_loss.data

        if train_reid:
            cls_loss, triplet_loss, acc = train_reid_an_iter(
                config, loaders, base)
            (cls_loss + triplet_loss).backward()
            results.update(
                cls_loss=cls_loss.data,
                triplet_loss=triplet_loss.data,
                acc=acc,
            )

        if optimize_sl_enc:
            base.sl_enc_optimizer.step()
        if train_gan:
            base.gen_optimizer.step()
            base.dis_optimizer.step()
        if train_reid:
            base.il_enc_optimizer.step()

        meter.update(results)

    return meter.get_str()