コード例 #1
0
ファイル: baseline.py プロジェクト: zhudi512/PDA-Net
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.benchmark = True

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    else:
        log_dir = osp.dirname(args.resume)
        sys.stdout = Logger(osp.join(log_dir, 'log_test.txt'))
    # print("==========\nArgs:{}\n==========".format(args))

    # Create data loaders
    if args.height is None or args.width is None:
        args.height, args.width = (256, 128)
    dataset, train_loader, val_loader, test_loader = \
        get_data(args.dataset, args.split, args.data_dir, args.height,
                 args.width, args.batch_size, args.workers,
                 args.combine_trainval, args.np_ratio)

    # Create model
    base_model = models.create(args.arch, cut_at_pooling=True)
    embed_model = EltwiseSubEmbed(use_batch_norm=True, use_classifier=True,
                                      num_features=2048, num_classes=2)
    model = SiameseNet(base_model, embed_model)
    model = nn.DataParallel(model).cuda()

    # Evaluator
    evaluator = CascadeEvaluator(
        torch.nn.DataParallel(base_model).cuda(),
        embed_model,
        embed_dist_fn=lambda x: F.softmax(Variable(x), dim=1).data[:, 0])

    # Load from checkpoint
    best_mAP = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']
        model.load_state_dict(checkpoint)

        print("Test the loaded model:")
        top1, mAP = evaluator.evaluate(test_loader, dataset.query, dataset.gallery, rerank_topk=100, dataset=args.dataset)
        best_mAP = mAP

    if args.evaluate:
        return

    # Criterion
    criterion = nn.CrossEntropyLoss().cuda()
    # Optimizer
    param_groups = [
        {'params': model.module.base_model.parameters(), 'lr_mult': 1.0},
        {'params': model.module.embed_model.parameters(), 'lr_mult': 1.0}]
    optimizer = torch.optim.SGD(param_groups, args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # Trainer
    trainer = SiameseTrainer(model, criterion)

    # Schedule learning rate
    def adjust_lr(epoch):
        lr = args.lr * (0.1 ** (epoch // args.step_size))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    for epoch in range(0, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer, base_lr=args.lr)

        if epoch % args.eval_step==0:
            mAP = evaluator.evaluate(val_loader, dataset.val, dataset.val, top1=False)
            is_best = mAP > best_mAP
            best_mAP = max(mAP, best_mAP)
            save_checkpoint({
                'state_dict': model.state_dict()
            }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

            print('\n * Finished epoch {:3d}  mAP: {:5.1%}  best: {:5.1%}{}\n'.
                  format(epoch, mAP, best_mAP, ' *' if is_best else ''))

    # Final test
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    evaluator.evaluate(test_loader, dataset.query, dataset.gallery, dataset=args.dataset)
コード例 #2
0
def main():
    opt = Options().parse()
    dataset, train_loader, test_loader = get_data(opt.dataset, opt.dataroot,
                                                  opt.height, opt.width,
                                                  opt.batch_size, opt.workers,
                                                  opt.pose_aug)

    dataset_size = len(dataset.trainval) * 4
    print('#training images = %d' % dataset_size)

    model = FDGANModel(opt)
    visualizer = Visualizer(opt)

    evaluator = CascadeEvaluator(
        torch.nn.DataParallel(model.net_E.module.base_model).cuda(),
        model.net_E.module.embed_model,
        embed_dist_fn=lambda x: F.softmax(Variable(x), dim=1).data[:, 0])
    if opt.stage != 1:
        print('Test with baseline model:')
        top1, mAP = evaluator.evaluate(test_loader,
                                       dataset.query,
                                       dataset.gallery,
                                       rerank_topk=100,
                                       dataset=opt.dataset)
        message = '\n Test with baseline model:  mAP: {:5.1%}  top1: {:5.1%}\n'.format(
            mAP, top1)
        visualizer.print_reid_results(message)

    total_steps = 0
    best_mAP = 0
    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
        model.reset_model_status()

        for i, data in enumerate(train_loader):
            iter_start_time = time.time()
            visualizer.reset()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)
            model.optimize_parameters()

            if total_steps % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(model.get_current_visuals(),
                                                   epoch, save_result)

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch,
                        float(epoch_iter) / dataset_size, opt, errors)

        if epoch % opt.save_step == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save(epoch)

        if epoch % opt.eval_step == 0 and opt.stage != 1:
            mAP = evaluator.evaluate(val_loader,
                                     dataset.val,
                                     dataset.val,
                                     top1=False)
            is_best = mAP > best_mAP
            best_mAP = max(mAP, best_mAP)
            if is_best:
                model.save('best')
            message = '\n * Finished epoch {:3d}  mAP: {:5.1%}  best: {:5.1%}{}\n'.format(
                epoch, mAP, best_mAP, ' *' if is_best else '')
            visualizer.print_reid_results(message)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))
        model.update_learning_rate()

    # Final test
    if opt.stage != 1:
        print('Test with best model:')
        checkpoint = load_checkpoint(
            osp.join(opt.checkpoints, opt.name,
                     '%s_net_%s.pth' % ('best', 'E')))
        model.net_E.load_state_dict(checkpoint)
        top1, mAP = evaluator.evaluate(test_loader,
                                       dataset.query,
                                       dataset.gallery,
                                       rerank_topk=100,
                                       dataset=opt.dataset)
        message = '\n Test with best model:  mAP: {:5.1%}  top1: {:5.1%}\n'.format(
            mAP, top1)
        visualizer.print_reid_results(message)
コード例 #3
0
ファイル: main.py プロジェクト: leule/DCDS
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.benchmark = True

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    # Create data loaders
    assert args.num_instances > 1, "num_instances should be greater than 1"
    assert args.batch_size % args.num_instances == 0, \
        'num_instances should divide batch_size'
    if args.height is None or args.width is None:
        args.height, args.width = (144, 56) if args.arch == 'inception' else \
                                  (384, 128)
    dataset, num_classes, train_loader, val_loader, test_loader = \
        get_data(args.dataset, args.split, args.data_dir, args.height,
                 args.width, args.batch_size, args.num_instances, args.workers,
                 args.combine_trainval)

    # Create model
    # Hacking here to let the classifier be the last feature embedding layer
    # Net structure: avgpool -> FC(1024) -> FC(args.features)
    base_model = models.create(args.arch,
                               num_features=1024,
                               cut_at_pooling=True,
                               dropout=args.dropout,
                               num_classes=args.features)

    grp_num = args.grp_num
    embed_model = [
        VNetEmbed(instances_num=args.num_instances,
                  feat_num=(2048 / grp_num),
                  num_classes=2,
                  drop_ratio=args.dropout).cuda() for i in range(grp_num)
    ]

    base_model = nn.DataParallel(base_model).cuda()

    model = VNetExtension(
        instances_num=args.num_instances,  # 
        base_model=base_model,
        embed_model=embed_model,
        alpha=args.alpha)

    if args.retrain:
        if args.evaluate_from:
            print('loading trained model...')
            checkpoint = load_checkpoint(args.evaluate_from)
            model.load_state_dict(checkpoint['state_dict'])

        else:
            print('loading base part of pretrained model...')
            checkpoint = load_checkpoint(args.retrain)
            #copy_state_dict(checkpoint['state_dict'], base_model, strip='base.module.', replace='module.')
            copy_state_dict(checkpoint['state_dict'],
                            base_model,
                            strip='base_model.',
                            replace='')
            print('loading embed part of pretrained model...')
            if grp_num > 1:
                for i in range(grp_num):
                    copy_state_dict(checkpoint['state_dict'],
                                    embed_model[i],
                                    strip='embed_model.bn_' + str(i) + '.',
                                    replace='bn.')
                    copy_state_dict(checkpoint['state_dict'],
                                    embed_model[i],
                                    strip='embed_model.classifier_' + str(i) +
                                    '.',
                                    replace='classifier.')
            else:
                copy_state_dict(checkpoint['state_dict'],
                                embed_model[0],
                                strip='module.embed_model.',
                                replace='')

    # Distance metric
    metric = DistanceMetric(algorithm=args.dist_metric)

    # Load from checkpoint
    start_epoch = best_top1 = 0
    best_mAP = 0

    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']
        best_top1 = checkpoint['best_top1']
        print("=> Start epoch {}  best top1 {:.1%}".format(
            start_epoch, best_top1))

    # Evaluator
    evaluator = CascadeEvaluator(
        base_model,
        embed_model,
        embed_dist_fn=lambda x: F.softmax(x, dim=1).data[:, 0])
    #embed_dist_fn=lambda x: F.softmax(x))# here we are performing softmax normalization, this function take N,2 vector and after normalizing both column it return the
    #first column

    if args.evaluate:
        metric.train(model, train_loader)
        if args.evaluate_from:
            print('loading trained model...')
            checkpoint = load_checkpoint(args.evaluate_from)
            model.load_state_dict(checkpoint['state_dict'])
        print("Test:")
        evaluator.evaluate(test_loader,
                           dataset.query,
                           dataset.gallery,
                           args.alpha,
                           metric,
                           rerank_topk=args.rerank,
                           dataset=args.dataset)
        return

    # Criterion
    criterion = nn.CrossEntropyLoss().cuda()

    criterion2 = TripletLoss(margin=args.margin).cuda()

    #criterion = nn.BCELoss().cuda()

    # base lr rate and embed lr rate

    new_params = [z for z in model.embed]
    param_groups = [
                    {'params': model.base.module.base.parameters(), 'lr_mult': 1.0}] + \
                   [{'params': new_params[i].parameters(), 'lr_mult': 10.0} for i in range(grp_num)]

    # Optimizer

    optimizer = torch.optim.Adam(param_groups,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Trainer
    trainer = DCDSBase(model, criterion, criterion2, args.alpha, grp_num)

    # Schedule learning rate
    def adjust_lr(epoch):
        step_size = args.ss if args.arch == 'inception' else 20
        lr = args.lr * (0.1**(epoch // step_size))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)
        return lr

    # Start training
    for epoch in range(start_epoch, args.epochs):
        lr = adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer, lr, warm_up=False)
        top1, mAP = evaluator.evaluate(val_loader,
                                       dataset.val,
                                       dataset.val,
                                       args.alpha,
                                       rerank_topk=args.rerank,
                                       second_stage=True,
                                       dataset=args.dataset)

        is_best = top1 > best_top1
        best_mAP = max(mAP, best_mAP)
        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'best_top1': best_top1,
            },
            is_best,
            fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

        print('\n * Finished epoch {:3d}  mAP: {:5.1%}  best: {:5.1%}{}\n'.
              format(epoch, mAP, best_mAP, ' *' if is_best else ''))

    # Final test
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    metric.train(model, train_loader)
    evaluator.evaluate(test_loader,
                       dataset.query,
                       dataset.gallery,
                       args.alpha,
                       metric,
                       rerank_topk=args.rerank,
                       dataset=args.dataset)
コード例 #4
0
ファイル: train.py プロジェクト: zhudi512/PDA-Net
def main():
    opt = Options().parse()
    source_dataset, source_train_loader, source_test_loader, target_dataset, target_train_loader, target_test_loader = get_cross_data(
        opt)

    model = PDANetModel(opt)
    writer = SummaryWriter()

    evaluator = CascadeEvaluator(
        torch.nn.DataParallel(model.net_E.module.base_model).cuda(),
        model.net_E.module.embed_model,
        embed_dist_fn=lambda x: F.softmax(Variable(x), dim=1).data[:, 0])

    if opt.stage == 0:
        print('Test with baseline model:')
        top1, mAP = evaluator.evaluate(target_test_loader,
                                       target_dataset.query,
                                       target_dataset.gallery,
                                       second_stage=False,
                                       rerank_topk=100,
                                       dataset=opt.dataset)
        message = '\n Test with baseline model:  mAP: {:5.1%}  top1: {:5.1%}\n'.format(
            mAP, top1)
        print(message)

    total_steps = 0
    best_mAP = 0
    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
        model.reset_model_status()

        target_iter = enumerate(target_train_loader)

        for i, source_data in enumerate(source_train_loader):
            """ Load Target Data at the same time"""
            try:
                _, target_data = next(target_iter)
            except:
                target_iter = enumerate(target_train_loader)
                _, target_data = next(target_iter)

            iter_start_time = time.time()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size

            model.set_inputs(target_data, source_data)
            model.optimize_cross_parameters()

            if total_steps % 10000 == 0:
                top1, mAP = evaluator.evaluate(target_test_loader,
                                               target_dataset.query,
                                               target_dataset.gallery,
                                               second_stage=False,
                                               rerank_topk=100,
                                               dataset=opt.dataset)
                writer.add_scalar('tgt_rank1', top1, int(total_steps / 10000))
                writer.add_scalar('tgt_mAP', mAP, int(total_steps / 10000))

            #Display visual results
            if total_steps % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0

                #TB visualization
                visual_data = model.get_tf_cross_visuals()
                for key, value in visual_data.items():
                    writer.add_image(key, make_grid(value, nrow=16),
                                     total_steps)

            #Plot curves
            if total_steps % opt.print_freq == 0:
                #                 errors = model.get_current_errors()
                errors = model.get_current_cross_errors()
                #TB scalar
                for key, value in errors.items():
                    writer.add_scalar(key, value, total_steps)

                t = (time.time() - iter_start_time) / opt.batch_size

        if epoch % opt.save_step == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save(epoch)

        if epoch % opt.eval_step == 0 and opt.stage != 1:

            top1, mAP = evaluator.evaluate(target_test_loader,
                                           target_dataset.query,
                                           target_dataset.gallery,
                                           second_stage=False,
                                           rerank_topk=100,
                                           dataset=opt.dataset)

            writer.add_scalar('tar_rank1', top1, epoch)
            writer.add_scalar('tar_mAP', mAP, epoch)

            is_best = mAP > best_mAP
            best_mAP = max(mAP, best_mAP)
            if is_best:
                model.save('best')
            message = '\n * Finished epoch {:3d}  mAP: {:5.1%}  best: {:5.1%}{}\n'.format(
                epoch, mAP, best_mAP, ' *' if is_best else '')
            print(message)

            #=========source test=========
            top1, mAP = evaluator.evaluate(source_test_loader,
                                           source_dataset.query,
                                           source_dataset.gallery,
                                           second_stage=False,
                                           rerank_topk=100,
                                           dataset=opt.dataset)

            writer.add_scalar('src_rank1', top1, epoch)
            writer.add_scalar('src_mAP', mAP, epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))
        model.update_learning_rate()

    # Final test
    if opt.stage != 1:
        print('Test with best model:')
        checkpoint = load_checkpoint(
            osp.join(opt.checkpoints, opt.name,
                     '%s_net_%s.pth' % ('best', 'E')))
        model.net_E.load_state_dict(checkpoint)
        top1, mAP = evaluator.evaluate(test_loader,
                                       dataset.query,
                                       dataset.gallery,
                                       rerank_topk=100,
                                       second_stage=False,
                                       dataset=opt.dataset)
        message = '\n Test with best model:  mAP: {:5.1%}  top1: {:5.1%}\n'.format(
            mAP, top1)
        print(message)
コード例 #5
0
ファイル: baseline.py プロジェクト: azuxmioy/ST-ReIDNet
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.benchmark = True

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    else:
        log_dir = osp.dirname(args.resume)
        sys.stdout = Logger(osp.join(log_dir, 'log_test.txt'))
    # print("==========\nArgs:{}\n==========".format(args))

    # Create data loaders
    if args.height is None or args.width is None:
        args.height, args.width = (256, 128)
    dataset, train_loader, val_loader, test_loader = \
        get_data(args.dataset, args.split, args.data_dir, args.height,
                 args.width, args.batch_size, args.workers,
                 args.combine_trainval, args.np_ratio,
                args.emb_type, args.inst_mode, args.eraser)

    if args.combine_trainval:
        emb_size = dataset.num_trainval_ids
    else:
        emb_size = dataset.num_train_ids

    # Create model
    if (args.emb_type == 'Single'):
        model = SingleNet(args.arch,
                          emb_size,
                          pretraind=True,
                          use_bn=args.use_bn,
                          test_bn=args.test_bn,
                          last_stride=args.last_stride)
    elif (args.emb_type == 'Siamese'):
        model = SiameseNet(args.arch,
                           emb_size,
                           pretraind=True,
                           use_bn=args.use_bn,
                           test_bn=args.test_bn,
                           last_stride=args.last_stride)
    else:
        raise ValueError('unrecognized model')
    model = nn.DataParallel(model).cuda()

    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']
        model.load_state_dict(checkpoint)

    # Evaluator

    evaluator = CascadeEvaluator(torch.nn.DataParallel(model).cuda(),
                                 emb_size=emb_size)

    # Load from checkpoint
    best_mAP = 0
    if args.resume:
        print("Test the loaded model:")
        top1, mAP = evaluator.evaluate(test_loader,
                                       dataset.query,
                                       dataset.gallery,
                                       dataset=args.dataset)
        best_mAP = mAP
    if args.evaluate:
        return

    # Criterion
    if args.soft_margin:
        tri_criterion = TripletLoss(margin='soft').cuda()
    else:
        tri_criterion = TripletLoss(margin=args.margin).cuda()

    if (args.emb_type == 'Single'):
        if args.label_smoothing:
            cla_criterion = CrossEntropyLabelSmooth(emb_size,
                                                    epsilon=0.1).cuda()
        else:
            cla_criterion = torch.nn.CrossEntropyLoss().cuda()
    elif (args.emb_type == 'Siamese'):
        cla_criterion = torch.nn.CrossEntropyLoss().cuda()

    # Optimizer
    param_groups = [{
        'params': model.module.base_model.parameters(),
        'lr_mult': 0.1
    }, {
        'params': model.module.classifier.parameters(),
        'lr_mult': 1.0
    }]

    if (args.opt_name == 'SGD'):
        optimizer = getattr(torch.optim,
                            args.opt_name)(param_groups,
                                           lr=args.lr,
                                           weight_decay=args.weight_decay,
                                           momentum=args.momentum)
    else:
        optimizer = getattr(torch.optim,
                            args.opt_name)(param_groups,
                                           lr=args.lr,
                                           weight_decay=args.weight_decay)

    # Trainer

    if (args.emb_type == 'Single'):
        trainer = TripletTrainer(model, tri_criterion, cla_criterion,
                                 args.lambda_tri, args.lambda_cla)
    elif (args.emb_type == 'Siamese'):
        trainer = SiameseTrainer(model, tri_criterion, cla_criterion,
                                 args.lambda_tri, args.lambda_cla)

    #TODO:Warmup lr
    # Schedule learning rate
    def adjust_lr(epoch):

        lr = args.lr * (0.1**(epoch // args.step_size))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    for epoch in range(0, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer, base_lr=args.lr)

        if epoch % args.eval_step == 0:
            #mAP = evaluator.evaluate(val_loader, dataset.val, dataset.val, top1=False, dataset=args.dataset)
            mAP = evaluator.evaluate(test_loader,
                                     dataset.query,
                                     dataset.gallery,
                                     top1=False,
                                     dataset=args.dataset)
            is_best = mAP > best_mAP
            best_mAP = max(mAP, best_mAP)
            save_checkpoint({'state_dict': model.state_dict()},
                            is_best,
                            fpath=osp.join(args.logs_dir,
                                           'checkpoint.pth.tar'))

            print('\n * Finished epoch {:3d}  mAP: {:5.1%}  best: {:5.1%}{}\n'.
                  format(epoch, mAP, best_mAP, ' *' if is_best else ''))

    # Final test
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
    model.load_state_dict(checkpoint['state_dict'])
    evaluator.evaluate(test_loader,
                       dataset.query,
                       dataset.gallery,
                       dataset=args.dataset)
コード例 #6
0
ファイル: train_visual.py プロジェクト: zhudi512/PDA-Net
def main():
    opt = Options().parse()
    source_dataset, source_train_loader, source_test_loader, target_dataset, target_train_loader, target_test_loader = get_cross_data(
        opt)

    #     dataset, train_loader, test_loader = get_data(opt.dataset, opt.dataroot, opt.height, opt.width, opt.batch_size, opt.workers, opt.pose_aug)

    dataset_size = len(source_dataset.trainval) * 4
    print('#souce training images = %d' % dataset_size)

    model = CPDGANModel(opt)
    #     visualizer = Visualizer(opt)
    writer = SummaryWriter()

    evaluator = CascadeEvaluator(
        torch.nn.DataParallel(model.net_E.module.base_model).cuda(),
        model.net_E.module.embed_model,
        embed_dist_fn=lambda x: F.softmax(Variable(x), dim=1).data[:, 0])

    if opt.stage == 0:
        print('Test with baseline model:')
        top1, mAP = evaluator.evaluate(target_test_loader,
                                       target_dataset.query,
                                       target_dataset.gallery,
                                       second_stage=False,
                                       rerank_topk=100,
                                       dataset=opt.dataset)
        message = '\n Test with baseline model:  mAP: {:5.1%}  top1: {:5.1%}\n'.format(
            mAP, top1)
        print(message)
#         visualizer.print_reid_results(message)

    total_steps = 0
    best_mAP = 0
    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
        model.reset_model_status()

        target_iter = enumerate(target_train_loader)

        for i, source_data in enumerate(source_train_loader):
            """ Load Target Data """
            try:
                _, target_data = next(target_iter)
            except:
                target_iter = enumerate(target_train_loader)
                _, target_data = next(target_iter)
#             print("target datatype",target_data[0].keys())

            iter_start_time = time.time()
            #             visualizer.reset()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size

            model.set_inputs(target_data, source_data)
            model.forward_test_cross()

            #Display visual results
            if total_steps % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0

                #TB visualization
                visual_data = model.get_test_cross_visuals()
                for key, value in visual_data.items():
                    writer.add_image(key, make_grid(value, nrow=24),
                                     total_steps)
                    save_image(value,
                               './save_images/{}_{}.png'.format(
                                   total_steps, key),
                               nrow=24)

#                 visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

#             #Plot curves
            if total_steps % opt.print_freq == 0:
                # #                 errors = model.get_current_errors()
                #                 errors = model.get_current_cross_errors()
                #                 #TB scalar
                #                 for key, value in errors.items():
                #                     writer.add_scalar(key, value, total_steps)

                t = (time.time() - iter_start_time) / opt.batch_size
                message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
                print(message)
#                 print_current_errors(epoch, epoch_iter, errors, t)
#                 visualizer.print_current_errors(epoch, epoch_iter, errors, t)
#                 if opt.display_id > 0:
#                     visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)

#         if epoch % opt.save_step == 0:
#             print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
#             model.save(epoch)

#         if epoch % opt.eval_step == 0 and opt.stage!=1:

#             top1, mAP = evaluator.evaluate(target_test_loader, target_dataset.query, target_dataset.gallery, second_stage=False, rerank_topk=100, dataset=opt.dataset)
# #             mAP = evaluator.evaluate(val_loader, dataset.val, dataset.val, top1=False)

#             writer.add_scalar('tar_rank1', top1, epoch)
#             writer.add_scalar('tar_mAP', mAP, epoch)

#             is_best = mAP > best_mAP
#             best_mAP = max(mAP, best_mAP)
#             if is_best:
#                 model.save('best')
#             message = '\n * Finished epoch {:3d}  mAP: {:5.1%}  best: {:5.1%}{}\n'.format(epoch, mAP, best_mAP, ' *' if is_best else '')
#             print(message)
# #             visualizer.print_reid_results(message)

#             #=========source test=========
#             top1, mAP = evaluator.evaluate(source_test_loader, source_dataset.query, source_dataset.gallery, second_stage=False, rerank_topk=100, dataset=opt.dataset)
# #             mAP = evaluator.evaluate(val_loader, dataset.val, dataset.val, top1=False)

#             writer.add_scalar('src_rank1', top1, epoch)
#             writer.add_scalar('src_mAP', mAP, epoch)

#         print('End of epoch %d / %d \t Time Taken: %d sec' %
#               (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
#         model.update_learning_rate()

# Final test
    if opt.stage != 1:
        print('Test with best model:')
        checkpoint = load_checkpoint(
            osp.join(opt.checkpoints, opt.name,
                     '%s_net_%s.pth' % ('best', 'E')))
        model.net_E.load_state_dict(checkpoint)
        top1, mAP = evaluator.evaluate(test_loader,
                                       dataset.query,
                                       dataset.gallery,
                                       rerank_topk=100,
                                       second_stage=False,
                                       dataset=opt.dataset)
        message = '\n Test with best model:  mAP: {:5.1%}  top1: {:5.1%}\n'.format(
            mAP, top1)
        print(message)