示例#1
0
def train(cfg):
    """Train a ReID model end-to-end from a config node.

    Creates a fresh output directory (refusing to overwrite a previous
    run), dumps the resolved config, builds data loaders / model /
    solver, and delegates the training loop to ``Trainer``.

    Raises:
        KeyError: if ``cfg.OUTPUT_DIR`` already exists (type kept as
            KeyError for backward compatibility with existing callers).
    """
    # output: refuse to clobber a previous run's artifacts
    output_dir = cfg.OUTPUT_DIR
    if os.path.exists(output_dir):
        # Fix: pass a single formatted message instead of a tuple of
        # arguments, which rendered as ("Existing path: ", '...').
        raise KeyError("Existing path: {}".format(output_dir))
    os.makedirs(output_dir)

    # persist the resolved config next to the logs for reproducibility
    with open(os.path.join(output_dir, 'config.yaml'), 'w') as f_out:
        print(cfg, file=f_out)

    # logger
    logger = make_logger("project", output_dir, 'log')

    # device: restrict visible GPUs before any CUDA context is created
    num_gpus = 0
    if cfg.DEVICE == 'cuda':
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.DEVICE_ID
        num_gpus = len(cfg.DEVICE_ID.split(','))
        logger.info("Using {} GPUs.\n".format(num_gpus))
    cudnn.benchmark = True
    device = torch.device(cfg.DEVICE)

    # data
    train_loader, query_loader, gallery_loader, num_classes = make_loader(cfg)

    # model
    model = make_model(cfg, num_classes=num_classes)
    if num_gpus > 1:
        model = nn.DataParallel(model)

    # solver
    criterion = make_loss(cfg, num_classes)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_scheduler(cfg, optimizer)

    # do_train
    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      criterion=criterion,
                      logger=logger,
                      scheduler=scheduler,
                      device=device)

    trainer.run(start_epoch=0,
                total_epoch=cfg.SOLVER.MAX_EPOCHS,
                train_loader=train_loader,
                query_loader=query_loader,
                gallery_loader=gallery_loader,
                print_freq=cfg.SOLVER.PRINT_FREQ,
                eval_period=cfg.SOLVER.EVAL_PERIOD,
                out_dir=output_dir)

    print('Done.')
示例#2
0
def main(args):
    """Build data/model/loss, then either test or train.

    Optionally runs an initial phase (``args.freeze`` epochs) with the
    backbone's base parameters frozen before full training resumes from
    ``trainer_.epoch``.
    """
    ckpt_ = checkpoint.Checkpoint(args)
    # data loader
    dataloader_ = data.Data(args)
    # model build up
    model_ = model.Model(args, ckpt_)
    # loss setting
    loss_ = loss.Loss(args, ckpt_)
    # check module for visualization and gradient check
    check_ = check.check(model_)
    # class for training and testing
    trainer_ = trainer.Trainer(args, model_, loss_, dataloader_, ckpt_, check_)
    if args.test:
        trainer_.test()
        return

    # train with freeze first.
    # Fix: the freeze loops ran unconditionally (outside the `if`), the
    # hasattr check was re-evaluated once per parameter, and the inner
    # loop clobbered the outer loop variable `par`. Freeze only when
    # requested, each parameter group exactly once.
    if args.freeze > 0:
        print('freeze base_params for {} epochs'.format(args.freeze))
        for par in ckpt_.base_params:
            par.requires_grad = False
        if hasattr(model_.get_model(), 'base_params'):
            for base_par in model_.get_model().base_params:
                base_par.requires_grad = False

        optim_tmp = optimizer.make_optimizer(args, model_)
        for _ in range(args.freeze):
            trainer_.train(optim_tmp)

    # start training: unfreeze everything
    for par in model_.parameters():
        par.requires_grad = True

    for i in range(trainer_.epoch, args.epochs):
        trainer_.train(trainer_.optimizer)
        if args.test_every != 0 and (i + 1) % args.test_every == 0:
            trainer_.test()
示例#3
0
File: train.py  Project: manutdzou/ReID
def train(config_file, **kwargs):
    """Train a ReID baseline from a YACS config file.

    Any keyword arguments are flattened into (key, value) pairs and
    merged into the global ``cfg`` before it is frozen. Supports an
    optional center-loss branch selected via ``cfg.DATALOADER.SAMPLER``.
    """
    cfg.merge_from_file(config_file)
    if kwargs:
        opts = []
        for k, v in kwargs.items():
            opts.append(k)
            opts.append(v)
        cfg.merge_from_list(opts)
    cfg.freeze()

    #PersonReID_Dataset_Downloader('./datasets',cfg.DATASETS.NAMES)

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = make_logger("Reid_Baseline", output_dir, 'log')
    logger.info("Using {} GPUS".format(1))
    logger.info("Loaded configuration file {}".format(config_file))
    logger.info("Running with config:\n{}".format(cfg))

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    device = torch.device(cfg.DEVICE)
    epochs = cfg.SOLVER.MAX_EPOCHS
    method = cfg.DATALOADER.SAMPLER

    train_loader, val_loader, num_query, num_classes = data_loader(
        cfg, cfg.DATASETS.NAMES)

    model = getattr(models, cfg.MODEL.NAME)(num_classes, cfg.MODEL.LAST_STRIDE)

    # 'center' samplers need an extra criterion/optimizer pair for the
    # class centers
    if 'center' in method:
        loss_fn, center_criterion = make_loss(cfg)
        optimizer, optimizer_center = make_optimizer_with_center(
            cfg, model, center_criterion)
    else:
        loss_fn = make_loss(cfg)
        optimizer = make_optimizer(cfg, model)

    scheduler = make_scheduler(cfg, optimizer)

    logger.info("Start training")
    since = time.time()
    for epoch in range(epochs):
        count = 0
        running_loss = 0.0
        running_acc = 0
        # Hoisted out of the batch loop: moving the model and switching
        # train mode are idempotent; doing them per batch was wasteful.
        # Doing it per epoch also restores the device after the
        # checkpoint branch below moves the model to the CPU.
        model.to(device)
        model.train()
        for data in tqdm(train_loader, desc='Iteration', leave=False):
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            if 'center' in method:
                optimizer_center.zero_grad()

            scores, feats = model(images)
            loss = loss_fn(scores, feats, labels)

            loss.backward()
            optimizer.step()
            if 'center' in method:
                # undo the center-loss weighting so the centers are
                # updated with their true gradient
                for param in center_criterion.parameters():
                    param.grad.data *= (1. / cfg.SOLVER.CENTER_LOSS_WEIGHT)
                optimizer_center.step()

            count = count + 1
            running_loss += loss.item()
            running_acc += (scores.max(1)[1] == labels).float().mean().item()

        logger.info(
            "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
            .format(epoch + 1, count, len(train_loader), running_loss / count,
                    running_acc / count,
                    scheduler.get_lr()[0]))
        scheduler.step()

        if (epoch + 1) % checkpoint_period == 0:
            model.cpu()
            model.save(output_dir, epoch + 1)

        # Validation
        if (epoch + 1) % eval_period == 0:
            all_feats = []
            all_pids = []
            all_camids = []
            model.to(device)
            model.eval()
            for data in tqdm(val_loader,
                             desc='Feature Extraction',
                             leave=False):
                with torch.no_grad():
                    images, pids, camids = data
                    images = images.to(device)
                    feats = model(images)

                all_feats.append(feats)
                all_pids.extend(np.asarray(pids))
                all_camids.extend(np.asarray(camids))

            logger.info("start evaluation")
            cmc, mAP = evaluation(all_feats, all_pids, all_camids, num_query)
            logger.info("Validation Results - Epoch: {}".format(epoch + 1))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))

    time_elapsed = time.time() - since
    logger.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logger.info('-' * 10)
示例#4
0
def train(config_file, **kwargs):
    """One-shot ReID training driven by progressive sample promotion (PSP).

    Every ``train_step`` epochs the affinity matrix ``A`` over the fully
    annotated reference set is re-estimated, the trusted-neighbour budget
    ``top`` grows by ``cfg.DATALOADER.NUM_JUMP``, and the
    model/optimizer/scheduler are re-created from scratch before training
    continues on the enlarged pseudo-labeled set.
    """
    # 1. config
    cfg.merge_from_file(config_file)
    if kwargs:
        # kwargs arrive as {key: value}; YACS wants a flat [k1, v1, k2, v2]
        opts = []
        for k, v in kwargs.items():
            opts.append(k)
            opts.append(v)
        cfg.merge_from_list(opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = make_logger("Reid_Baseline", output_dir, 'log')
    logger.info("Using {} GPUS".format(1))
    logger.info("Loaded configuration file {}".format(config_file))
    logger.info("Running with config:\n{}".format(cfg))

    # checkpoint_period is currently unused: the model-save branch below
    # is commented out.
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    device = torch.device(cfg.DEVICE)
    epochs = cfg.SOLVER.MAX_EPOCHS

    # 2. datasets
    # Load the original dataset
    dataset_reference = init_dataset(cfg, cfg.DATASETS.NAMES +
                                     '_origin')  #'Market1501_origin'
    train_set_reference = ImageDataset(dataset_reference.train,
                                       train_transforms)
    # shuffle=False — presumably so the loader's sample order lines up
    # with row indices of A / dataset_reference.train; confirm in PSP.
    train_loader_reference = DataLoader(train_set_reference,
                                        batch_size=128,
                                        shuffle=False,
                                        num_workers=cfg.DATALOADER.NUM_WORKERS,
                                        collate_fn=train_collate_fn)

    # Load the one-shot dataset
    train_loader, val_loader, num_query, num_classes = data_loader(
        cfg, cfg.DATASETS.NAMES)

    # 3. load the model and optimizer
    model = getattr(models, cfg.MODEL.NAME)(num_classes)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_scheduler(cfg, optimizer)
    loss_fn = make_loss(cfg)
    logger.info("Start training")
    since = time.time()

    top = 0  # the choose of the nearest sample
    top_update = 0  # the first iteration train 80 steps and the following train 40

    # 4. Train and test
    for epoch in range(epochs):
        running_loss = 0.0
        running_acc = 0
        count = 1

        # get nearest samples and reset the model
        if top_update < 80:
            train_step = 80
        else:
            train_step = 40
        # A_store is (re)bound here; epoch 0 always enters this branch
        # (top_update == 0), so it is defined before first use below.
        if top_update % train_step == 0:
            print("top: ", top)
            A, path_labeled = PSP(model, train_loader_reference, train_loader,
                                  top, cfg)
            top += cfg.DATALOADER.NUM_JUMP
            # restart from a fresh model/optimizer for the new sample set
            model = getattr(models, cfg.MODEL.NAME)(num_classes)
            optimizer = make_optimizer(cfg, model)
            scheduler = make_scheduler(cfg, optimizer)
            A_store = A.clone()
        top_update += 1

        for data in tqdm(train_loader, desc='Iteration', leave=False):
            model.train()
            images, labels_batch, img_path = data
            # map batch image paths back to reference-dataset indices
            index, index_labeled = find_index_by_path(img_path,
                                                      dataset_reference.train,
                                                      path_labeled)
            # expand the batch with its affinity-linked neighbours
            images_relevant, GCN_index, choose_from_nodes, labels = load_relevant(
                cfg, dataset_reference.train, index, A_store, labels_batch,
                index_labeled)
            # if device:
            model.to(device)
            images = images_relevant.to(device)

            scores, feat = model(images)
            del images
            loss = loss_fn(scores, feat, labels.to(device), choose_from_nodes)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            count = count + 1
            running_loss += loss.item()
            # accuracy is measured only on the original batch's own nodes
            running_acc += (scores[choose_from_nodes].max(1)[1].cpu() ==
                            labels_batch).float().mean().item()

        scheduler.step()

        # for model save if you need
        # if (epoch+1) % checkpoint_period == 0:
        #     model.cpu()
        #     model.save(output_dir,epoch+1)

        # Validation
        if (epoch + 1) % eval_period == 0:
            all_feats = []
            all_pids = []
            all_camids = []
            for data in tqdm(val_loader,
                             desc='Feature Extraction',
                             leave=False):
                model.eval()
                with torch.no_grad():
                    images, pids, camids = data

                    model.to(device)
                    images = images.to(device)

                    feats = model(images)
                    del images
                # keep features on CPU to bound GPU memory during eval
                all_feats.append(feats.cpu())
                all_pids.extend(np.asarray(pids))
                all_camids.extend(np.asarray(camids))

            cmc, mAP = evaluation(all_feats, all_pids, all_camids, num_query)
            logger.info("Validation Results - Epoch: {}".format(epoch + 1))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10, 20]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))

    time_elapsed = time.time() - since
    logger.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logger.info('-' * 10)
示例#5
0
    def main(self):
        """Run the 2D/3D registration pipeline.

        Loads statistical shape models, input images, and camera
        geometries, wires up renderer/deformator/metrics/optimizer, then
        iterates optimizer populations until convergence.
        """
        parser = create_parser()
        args = parser.parse_args()

        verify_arguments(args)

        # Load model
        # NOTE(review): np.load(...)[()] unwraps a 0-d object array (a
        # dict saved with np.save); newer NumPy needs allow_pickle=True.
        models = [np.load(model)[()] for model in args.model]

        # Load input image
        input_images = []
        for i in args.input:
            input_images.append(itk_image.read(i))

        # Load camera parameters
        geometries = []
        for c, image in zip(args.camera, input_images):
            geometry = render.GeometryContext.load_from_file(c)
            if not args.ignore_image_information:
                geometry.pixel_spacing = image[1]
                geometry.image_size = image[0].shape[::-1]
            geometries.append(geometry)

        # Setup registration framework
        self.renderer = render.SurfaceRenderer()
        self.deformator = transform.SsmDeformator()
        self.calculators = [metric.make_metric(m) for m in args.metric]
        self.optimizer = optimizer.make_optimizer(args.optimizer)

        self.optimizer.set_hyper_parameters(
            **parse_parameter(args.optimizer_params))

        # Set initial guess: 6 pose DOF + shape DOF per model.
        # Fix: the original comprehension was a syntax error
        # ("for in self.models") and referenced an undefined `deformator`
        # and a nonexistent `self.models`; iterate the local `models`
        # with the deformator created above.
        total_dimension = sum(
            6 + self.deformator.get_using_dimension(model) for model in models
        )
        initial_guess = [0.0 for _ in range(total_dimension)]
        self.optimizer.set_initial_guess(initial_guess)

        # Setup extensional framework
        for f, f_conf in zip_longest(args.framework, args.framework_config):
            # Fix: the assertion was inverted — `assert f is None` fails
            # on every real framework. It must fire when the config list
            # is longer than the framework list, i.e. when f is None.
            assert f is not None, \
                'A number of frameworks and configuration file is a mismatch.'
            if f_conf is None or f_conf.lower() == 'none':
                continue

            with open(f_conf) as fp:
                # NOTE(review): yaml.load without a Loader is unsafe on
                # untrusted files and an error on modern PyYAML —
                # consider yaml.safe_load(fp).
                conf = yaml.load(fp)
                self.optimizer.add_framework(f, **conf)

        self.optimizer.add_framework(framework.TqdmReport())
        self.optimizer.add_framework(framework.MatplotReport(self.draw_status))
        self.optimizer.add_framework(HistoryCapture(self.history))

        proj_images = None

        faces = np.array([model['faces'] for model in models])
        with self.optimizer.setup():
            for population in self.optimizer.generate():
                if proj_images is None:
                    # NOTE(review): `proj_image` is not defined anywhere
                    # in this method — presumably derived from
                    # input_images; confirm against the full class.
                    proj_images = np.tile(np.expand_dims(
                        proj_image, axis=2), (1, 1, len(population))
                    )

                images, transformed_polys = render_population(
                    population, models, props
                )

                metrics = sum([c.calculate(proj_images, images)
                               for c in self.calculators])

                self.optimizer.update(metrics.tolist())
示例#6
0
                      args=args)


# ensure that the model lives on the cpu (for multi-gpu training)
if args.multi_gpu:
    with tf.device('/cpu:0'):
        model = make_model_impl()
else:
    model = make_model_impl()

print(model.summary())
# LSUV: layer-sequential unit-variance init, seeded with one training batch
if args.lsuv_init:
    model = LSUVinit(model, x_train[:args.batch_size, :, :, :])

optimizer = make_optimizer(args.optimizer,
                           lr=args.lr_values[0],
                           momentum=args.momentum)

# keep a handle on the unwrapped model — presumably for saving weights
# outside the multi_gpu_model wrapper; confirm against the rest of file
orig_model = model
if args.multi_gpu:
    model = multi_gpu_model(model)

model.compile(loss=args.loss, optimizer=optimizer, metrics=['accuracy'])


def check_params():
    for layer in model.layers:
        if layer.__class__.__name__ in [
                'BLU', 'PReLU', 'SoftExp', 'ScaledReLU'
        ]:
            for weight in layer.weights:
示例#7
0
def train(config_file, resume=False, **kwargs):
    """Cross-domain ReID training with optional resume.

    Trains on ``cfg.DATASETS.SOURCE`` and periodically evaluates on each
    dataset in ``cfg.DATASETS.TARGET``, averaging metrics over several
    seeded data-loader draws.
    """
    cfg.merge_from_file(config_file)
    if kwargs:
        opts = []
        for k, v in kwargs.items():
            opts.append(k)
            opts.append(v)
        cfg.merge_from_list(opts)
    cfg.freeze()

    # [PersonReID_Dataset_Downloader(cfg.DATASETS.STORE_DIR,dataset) for dataset in cfg.DATASETS.SOURCE]
    # [PersonReID_Dataset_Downloader(cfg.DATASETS.STORE_DIR,dataset) for dataset in cfg.DATASETS.TARGET]
    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = make_logger("Reid_Baseline", output_dir, 'log', resume)
    if not resume:
        logger.info("Using {} GPUS".format(1))
        logger.info("Loaded configuration file {}".format(config_file))
        logger.info("Running with config:\n{}".format(cfg))

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    device = torch.device(cfg.DEVICE)
    epochs = cfg.SOLVER.MAX_EPOCHS

    train_loader, _, _, num_classes = data_loader(cfg,
                                                  cfg.DATASETS.SOURCE,
                                                  merge=cfg.DATASETS.MERGE)

    model = getattr(models, cfg.MODEL.NAME)(num_classes, cfg.MODEL.LAST_STRIDE,
                                            cfg.MODEL.POOL)
    if resume:
        checkpoints = get_last_stats(output_dir)
        try:
            model_dict = torch.load(checkpoints[cfg.MODEL.NAME])
        except KeyError:
            # older checkpoints may be keyed by the class repr
            model_dict = torch.load(checkpoints[str(type(model))])
        model.load_state_dict(model_dict)
        if device:
            model.to(device)  # must be done before the optimizer generation
    optimizer = make_optimizer(cfg, model)
    scheduler = make_scheduler(cfg, optimizer)
    base_epo = 0
    if resume:
        optimizer.load_state_dict(torch.load(checkpoints['opt']))
        sch_dict = torch.load(checkpoints['sch'])
        scheduler.load_state_dict(sch_dict)
        base_epo = checkpoints['epo']

    loss_fn = make_loss(cfg)

    if not resume:
        logger.info("Start training")
    since = time.time()
    for epoch in range(epochs):
        count = 0
        running_loss = 0.0
        running_acc = 0
        for data in tqdm(train_loader, desc='Iteration', leave=False):
            model.train()
            images, labels, domains = data
            if device:
                model.to(device)
                images, labels, domains = images.to(device), labels.to(
                    device), domains.to(device)

            optimizer.zero_grad()

            scores, feats = model(images)
            loss = loss_fn(scores, feats, labels)

            loss.backward()
            optimizer.step()

            count = count + 1
            running_loss += loss.item()
            running_acc += (
                scores[0].max(1)[1] == labels).float().mean().item()

        logger.info(
            "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
            .format(epoch + 1 + base_epo, count, len(train_loader),
                    running_loss / count, running_acc / count,
                    scheduler.get_lr()[0]))
        scheduler.step()

        if (epoch + 1 + base_epo) % checkpoint_period == 0:
            model.cpu()
            model.save(output_dir, epoch + 1 + base_epo)
            torch.save(
                optimizer.state_dict(),
                os.path.join(output_dir,
                             'opt_epo' + str(epoch + 1 + base_epo) + '.pth'))
            torch.save(
                scheduler.state_dict(),
                os.path.join(output_dir,
                             'sch_epo' + str(epoch + 1 + base_epo) + '.pth'))

        # Validation
        if (epoch + base_epo + 1) % eval_period == 0:
            # Validation on Target Dataset
            for target in cfg.DATASETS.TARGET:
                mAPs = []
                cmcs = []
                # NOTE(review): `iteration` is not defined in this
                # function — it must come from module scope; confirm.
                for i in range(iteration):

                    set_seeds(i)

                    _, val_loader, num_query, _ = data_loader(cfg, (target, ),
                                                              merge=False)

                    all_feats = []
                    all_pids = []
                    all_camids = []

                    # Fix: `since` was reassigned here, clobbering the
                    # overall training timer so the final "Training
                    # complete in" line only measured the last eval.
                    for data in tqdm(val_loader,
                                     desc='Feature Extraction',
                                     leave=False):
                        model.eval()
                        with torch.no_grad():
                            images, pids, camids = data
                            if device:
                                model.to(device)
                                images = images.to(device)

                            feats = model(images)
                            # L2-normalize features before matching
                            feats /= feats.norm(dim=-1, keepdim=True)

                        all_feats.append(feats)
                        all_pids.extend(np.asarray(pids))
                        all_camids.extend(np.asarray(camids))

                    cmc, mAP = evaluation(all_feats, all_pids, all_camids,
                                          num_query)
                    mAPs.append(mAP)
                    cmcs.append(cmc)

                mAP = np.mean(np.array(mAPs))
                cmc = np.mean(np.array(cmcs), axis=0)

                mAP_std = np.std(np.array(mAPs))
                cmc_std = np.std(np.array(cmcs), axis=0)

                logger.info("Validation Results: {} - Epoch: {}".format(
                    target, epoch + 1 + base_epo))
                logger.info("mAP: {:.1%} (std: {:.3%})".format(mAP, mAP_std))
                for r in [1, 5, 10]:
                    logger.info(
                        "CMC curve, Rank-{:<3}:{:.1%} (std: {:.3%})".format(
                            r, cmc[r - 1], cmc_std[r - 1]))

            reset()

    time_elapsed = time.time() - since
    logger.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logger.info('-' * 10)
示例#8
0
def train(config_file1, config_file2, **kwargs):
    """One-shot ReID training with a GAN-augmentation stage.

    Starts under ``config_file1``. Once the trusted-neighbour budget
    ``top`` reaches a threshold, a GAN is retrained to synthesize extra
    images, the datasets are rebuilt, and training continues under
    ``config_file2``.
    """
    # 1. config
    cfg.merge_from_file(config_file1)
    if kwargs:
        opts = []
        for k, v in kwargs.items():
            opts.append(k)
            opts.append(v)
        cfg.merge_from_list(opts)
    # cfg is deliberately NOT frozen: config_file2 is merged in later
    # once the GAN stage rebuilds the datasets.
    #cfg.freeze()
    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = make_logger("Reid_Baseline", output_dir, 'log')
    #logger.info("Using {} GPUS".format(1))
    logger.info("Loaded configuration file {}".format(config_file1))
    logger.info("Running with config:\n{}".format(cfg))

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    # Fix: `device` was commented out here but is used below
    # (model.to(device), images.to(device)), which raises NameError at
    # runtime.
    device = torch.device(cfg.DEVICE)
    epochs = cfg.SOLVER.MAX_EPOCHS

    # 2. datasets
    # Load the original dataset
    dataset_reference = init_dataset(cfg, cfg.DATASETS.NAMES +
                                     '_origin')  #'Market1501_origin'
    train_set_reference = ImageDataset(dataset_reference.train,
                                       train_transforms)
    train_loader_reference = DataLoader(train_set_reference,
                                        batch_size=128,
                                        shuffle=False,
                                        num_workers=cfg.DATALOADER.NUM_WORKERS,
                                        collate_fn=train_collate_fn)
    # (these reference images are not fed through the network here, so
    # no transform would strictly be needed)

    # Load the one-shot dataset
    train_loader, val_loader, num_query, num_classes = data_loader(
        cfg, cfg.DATASETS.NAMES)

    # 3. load the model and optimizer
    model = getattr(models, cfg.MODEL.NAME)(num_classes)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_scheduler(cfg, optimizer)
    loss_fn = make_loss(cfg)
    logger.info("Start training")
    since = time.time()
    if torch.cuda.device_count() > 1:
        print("Use", torch.cuda.device_count(), 'gpus')
    elif torch.cuda.device_count() == 1:
        print("Use", torch.cuda.device_count(), 'gpu')
    model = nn.DataParallel(model)
    top = 0  # the choose of the nearest sample
    top_update = 0  # the first iteration train 80 steps and the following train 40
    train_time = 0  # how many times the GAN has been retrained so far
    bound = 1  # maximum number of GAN retrainings (revisit before raising)
    lock = False  # True right after a GAN rebuild: use PSP2 once
    train_compen = 0  # epochs to wait before the next PSP re-selection
    # 4. Train and test
    for epoch in range(epochs):
        running_loss = 0.0
        running_acc = 0
        count = 1
        # get nearest samples and reset the model
        if top_update < 80:
            train_step = 80
            # whether the first round after GAN regeneration really
            # needs 80 steps is an open question
        else:
            train_step = 40
        #if top_update % train_step == 0:
        if top_update % train_step == 0 and train_compen == 0:
            print("top: ", top)
            # the original experiments pushed top up to 41; the
            # threshold of 8 here is a compromise worth re-testing
            if top >= 8 and train_time < bound:
                train_compen = (top - 1) * 40 + 80
                train_time += 1
                # retrain the GAN...
                mode = 'train'
                retrain(mode)
                # ...and inject its generated images into the dataset
                produce()
                cfg.merge_from_file(config_file2)
                output_dir = cfg.OUTPUT_DIR
                if output_dir and not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                logger = make_logger("Reid_Baseline", output_dir, 'log')
                logger.info(
                    "Loaded configuration file {}".format(config_file2))
                logger.info("Running with config:\n{}".format(cfg))
                dataset_reference = init_dataset(
                    cfg, cfg.DATASETS.NAMES + '_origin')  #'Market1501_origin'
                train_set_reference = ImageDataset(dataset_reference.train,
                                                   train_transforms)
                train_loader_reference = DataLoader(
                    train_set_reference,
                    batch_size=128,
                    shuffle=False,
                    num_workers=cfg.DATALOADER.NUM_WORKERS,
                    collate_fn=train_collate_fn)
                dataset_ref = init_dataset(cfg, cfg.DATASETS.NAMES +
                                           '_ref')  #'Market1501_ref'
                train_set_ref = ImageDataset(dataset_ref.train,
                                             train_transforms)
                train_loader_ref = DataLoader(
                    train_set_ref,
                    batch_size=128,
                    shuffle=False,
                    num_workers=cfg.DATALOADER.NUM_WORKERS,
                    collate_fn=train_collate_fn)
                lock = True
            if lock:
                A, path_labeled = PSP2(model, train_loader_reference,
                                       train_loader, train_loader_ref, top,
                                       logger, cfg)
                lock = False
            else:
                A, path_labeled = PSP(model, train_loader_reference,
                                      train_loader, top, logger, cfg)

            top += cfg.DATALOADER.NUM_JUMP
            # restart from a fresh model/optimizer for the new sample set
            model = getattr(models, cfg.MODEL.NAME)(num_classes)
            model = nn.DataParallel(model)
            optimizer = make_optimizer(cfg, model)
            scheduler = make_scheduler(cfg, optimizer)
            A_store = A.clone()
        top_update += 1

        for data in tqdm(train_loader, desc='Iteration', leave=False):
            model.train()
            images, labels_batch, img_path = data
            index, index_labeled = find_index_by_path(img_path,
                                                      dataset_reference.train,
                                                      path_labeled)
            images_relevant, GCN_index, choose_from_nodes, labels = load_relevant(
                cfg, dataset_reference.train, index, A_store, labels_batch,
                index_labeled)
            model.to(device)
            images = images_relevant.to(device)

            scores, feat = model(images)
            del images
            loss = loss_fn(scores, feat, labels.to(device), choose_from_nodes)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            count = count + 1
            running_loss += loss.item()
            # accuracy is measured only on the original batch's own nodes
            running_acc += (scores[choose_from_nodes].max(1)[1].cpu() ==
                            labels_batch).float().mean().item()

        scheduler.step()

        # for model save if you need
        # if (epoch+1) % checkpoint_period == 0:
        #     model.cpu()
        #     model.save(output_dir,epoch+1)

        # Validation
        if (epoch + 1) % eval_period == 0:
            all_feats = []
            all_pids = []
            all_camids = []
            for data in tqdm(val_loader,
                             desc='Feature Extraction',
                             leave=False):
                model.eval()
                with torch.no_grad():
                    images, pids, camids = data

                    model.to(device)
                    images = images.to(device)

                    feats = model(images)
                    del images
                # keep features on CPU to bound GPU memory during eval
                all_feats.append(feats.cpu())
                all_pids.extend(np.asarray(pids))
                all_camids.extend(np.asarray(camids))

            cmc, mAP = evaluation(all_feats, all_pids, all_camids, num_query)
            logger.info("Validation Results - Epoch: {}".format(epoch + 1))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10, 20]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))
        if train_compen > 0:
            train_compen -= 1

    time_elapsed = time.time() - since
    logger.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logger.info('-' * 10)