Beispiel #1
0
def main(args):
    """Train with hard and soft labels, evaluating and checkpointing periodically.

    The training loader is built by ``getTrainLoader`` from the soft-label file
    ``<cwd>/soft_label/soft_label_resnet50.txt``. The validation loader, query
    count, class count and training-set size come from ``make_data_loader``.
    After every ``args.n_save``-th epoch the model is evaluated (CMC rank
    1/5/10 and mAP are printed) and saved with ``save_model``.

    Args:
        args: parsed options. ``args.Epochs`` and ``args.n_save`` are read
            here, and ``args`` is forwarded to the project helpers.

    Raises:
        FileNotFoundError: if the soft-label file is missing. Previously this
            only printed a message and then failed later inside the loader.
    """
    path = os.path.join(os.getcwd(), 'soft_label', 'soft_label_resnet50.txt')
    if not os.path.isfile(path):
        raise FileNotFoundError('soft label file does not exist: {}'.format(path))

    train_loader = getTrainLoader(args, path)
    # Only the validation side of make_data_loader is used; its train loader
    # is replaced by the soft-label loader built above.
    _, val_loader, num_query, num_classes, train_size = make_data_loader(args)

    model = build_model(args, num_classes)
    optimizer = make_optimizer(args, model)
    scheduler = WarmupMultiStepLR(optimizer, [30, 55], 0.1, 0.01, 5, "linear")
    loss_func = make_loss(args)

    model.to(device)

    for epoch in range(args.Epochs):
        model.train()
        running_loss = 0.0
        running_klloss = 0.0
        running_softloss = 0.0
        running_corrects = 0.0
        for img, target, soft_target in tqdm(train_loader):
            # Put the batch on the same device as the model. The original used
            # .cuda(), which ignores `device`.
            img = img.to(device)
            target = target.to(device)
            soft_target = soft_target.to(device)
            score, _ = model(img)
            preds = torch.max(score.detach(), 1)[1]
            loss, klloss, softloss = loss_func(score, target, soft_target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_klloss += klloss.item()
            running_softloss += softloss.item()
            running_corrects += float(torch.sum(preds == target))

        scheduler.step()
        # NOTE(review): the loss sums add one value per batch but are divided by
        # the number of samples. That is a per-sample average only if loss_func
        # returns batch sums rather than batch means -- confirm.
        epoch_loss = running_loss / train_size
        epoch_klloss = running_klloss / train_size
        epoch_softloss = running_softloss / train_size
        epoch_acc = running_corrects / train_size
        print(
            "Epoch {}   Loss : {:.4f} KLLoss:{:.8f}  SoftLoss:{:.4f}  Acc:{:.4f}"
            .format(epoch, epoch_loss, epoch_klloss, epoch_softloss,
                    epoch_acc))

        if (epoch + 1) % args.n_save == 0:
            evaluator = Evaluator(model, val_loader, num_query)
            cmc, mAP = evaluator.run()
            print('---------------------------')
            print("CMC Curve:")
            for r in [1, 5, 10]:
                print("Rank-{} : {:.1%}".format(r, cmc[r - 1]))
            print("mAP : {:.1%}".format(mAP))
            print('---------------------------')
            save_model(args, model, optimizer, epoch)
Beispiel #2
0
def main(args):
    """Train a model with a single loss and evaluate/checkpoint periodically.

    Standard output is redirected through ``Logger`` into a timestamped file
    under ``args.log_path/args.log_description``. Loaders, model, optimizer and
    loss come from the project factories. Training runs for ``args.Epochs``
    epochs; after every ``args.n_save``-th epoch, CMC (rank 1/5/10) and mAP are
    printed and the model is stored with ``save_model``.
    """
    log_name = 'log' + time.strftime(".%m_%d_%H:%M:%S") + '.txt'
    log_file = os.path.join(args.log_path, args.log_description, log_name)
    sys.stdout = Logger(log_file)

    (train_loader, val_loader, num_query, num_classes,
     train_size) = make_data_loader(args)
    model = build_model(args, num_classes)
    print(model)
    optimizer = make_optimizer(args, model)
    scheduler = WarmupMultiStepLR(optimizer, [30, 55], 0.1, 0.01, 5, "linear")
    loss_func = make_loss(args)
    model.to(device)

    for epoch in range(args.Epochs):
        model.train()
        loss_sum, correct_sum = 0.0, 0.0
        for img, target in tqdm(train_loader):
            img, target = img.cuda(), target.cuda()
            score, _ = model(img)
            predicted = score.data.max(1)[1]
            loss = loss_func(score, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            correct_sum += float((predicted == target.data).sum())

        scheduler.step()
        print("Epoch {}   Loss : {:.6f}   Acc:{:.4f}".format(
            epoch, loss_sum / train_size, correct_sum / train_size))

        # Evaluate and checkpoint only on every n_save-th epoch.
        if (epoch + 1) % args.n_save != 0:
            continue
        cmc, mAP = Evaluator(model, val_loader, num_query).run()
        print('---------------------------')
        print("CMC Curve:")
        for rank in (1, 5, 10):
            print("Rank-{} : {:.1%}".format(rank, cmc[rank - 1]))
        print("mAP : {:.1%}".format(mAP))
        print('---------------------------')
        save_model(args, model, optimizer, epoch)
Beispiel #3
0
def main(args):
    """Train the image/text matching network and checkpoint on best t2i top-1.

    Builds the train/test loaders and the unique-image set via ``data_config`` /
    ``get_image_unique``. The parameters of the ``Loss`` module are handed to
    ``network_config`` (presumably so the optimizer trains them -- confirm). The
    model is trained from ``args.start_epoch`` up to ``args.num_epoches``. After
    each epoch, ``test`` is run, and a checkpoint is written whenever the t2i
    top-1 score improves.
    """
    # transform
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    cap_transform = None

    # data
    train_loader = data_config(args.image_dir, args.anno_dir, args.batch_size, 'train', 100,
                               train_transform, cap_transform=cap_transform)
    test_loader = data_config(args.image_dir, args.anno_dir, 64, 'test', 100, test_transform)
    unique_image = get_image_unique(args.image_dir, args.anno_dir, 64, 'test', 100, test_transform)

    # loss
    # The original built nn.DataParallel(compute_loss).cuda() and discarded the
    # wrapper; its only effect was moving compute_loss's parameters to the GPU
    # in place, which is what this explicit call does.
    compute_loss = Loss(args)
    compute_loss.cuda()

    # network
    network, optimizer = network_config(args, 'train', compute_loss.parameters(), args.resume, args.model_path)

    # lr_scheduler
    scheduler = WarmupMultiStepLR(optimizer, (20, 25, 35), 0.1, 0.01, 10, 'linear')

    ac_t2i_top1_best = 0.0
    best_epoch = 0
    for epoch in range(args.num_epoches - args.start_epoch):
        current_epoch = args.start_epoch + epoch
        network.train()
        # train for one epoch
        train_loss, train_time, image_precision, text_precision = train(
            current_epoch, train_loader, network, optimizer, compute_loss, args)

        # NOTE(review): is_best is never set to True, so save_checkpoint always
        # receives False even when a new best model is being saved -- confirm.
        is_best = False
        print('Train done for epoch-{}'.format(current_epoch))

        logging.info('Epoch:  [{}|{}], train_time: {:.3f}, train_loss: {:.3f}'.format(
            current_epoch, args.num_epoches, train_time, train_loss))
        logging.info('image_precision: {:.3f}, text_precision: {:.3f}'.format(image_precision, text_precision))
        scheduler.step()
        for param in optimizer.param_groups:
            print('lr:{}'.format(param['lr']))

        # Evaluate every epoch (the former `if epoch >= 0` guard was always true).
        (ac_top1_i2t, ac_top5_i2t, ac_top10_i2t,
         ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, test_time) = test(test_loader, network, args, unique_image)

        state = {'network': network.state_dict(), 'optimizer': optimizer.state_dict(),
                 'W': compute_loss.W, 'epoch': current_epoch}

        # Only checkpoint when the t2i top-1 score improves.
        if ac_top1_t2i > ac_t2i_top1_best:
            best_epoch = epoch
            ac_t2i_top1_best = ac_top1_t2i
            save_checkpoint(state, epoch, args.checkpoint_dir, is_best)

        logging.info('epoch:{}'.format(epoch))
        logging.info('top1_t2i: {:.3f}, top5_t2i: {:.3f}, top10_t2i: {:.3f}, top1_i2t: {:.3f}, top5_i2t: {:.3f}, top10_i2t: {:.3f}'.format(
            ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, ac_top1_i2t, ac_top5_i2t, ac_top10_i2t))

    logging.info('Best epoch:{}'.format(best_epoch))
    logging.info('Train done')
    logging.info(args.checkpoint_dir)
    logging.info(args.log_dir)
Beispiel #4
0
class BaseModel(object):
    """Source-domain trainer for a content encoder with two classifier heads.

    Holds a ``Baseline`` content encoder wrapped in ``DataParallel`` on the GPU,
    the losses it is trained with (label-smoothed cross-entropy, triplet,
    Smooth-L1), and two optimizer/scheduler pairs built by ``make_optimizer``
    (the second one with ``fix=True``).
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self._init_models()
        self._init_optimizers()

        print('---------- Networks initialized -------------')
        print_network(self.Content_Encoder)
        print('-----------------------------------------------')

    def _init_models(self):
        """Build the content encoder and the loss modules and move them to the GPU."""
        # -----------------Content_Encoder-------------------
        self.Content_Encoder = Baseline(self.cfg.DATASETS.NUM_CLASSES_S, 1, self.cfg.MODEL.PRETRAIN_PATH, 'bnneck',
                                        'after', self.cfg.MODEL.NAME, 'imagenet')
        # -----------------Criterion----------------- #
        self.xent = CrossEntropyLabelSmooth(num_classes=self.cfg.DATASETS.NUM_CLASSES_S).cuda()
        self.triplet = TripletLoss(0.3)
        self.Smooth_L1_loss = torch.nn.SmoothL1Loss(reduction='mean').cuda()
        # --------------------Cuda------------------- #
        self.Content_Encoder = torch.nn.DataParallel(self.Content_Encoder).cuda()

    def _init_optimizers(self):
        """Create both optimizers and their warm-up multi-step LR schedulers.

        ``Content_optimizer`` is used in the 'normal_c1_c2' phase and
        ``Content_optimizer_fix`` (built with ``fix=True``) in the other phases.
        """
        self.Content_optimizer = make_optimizer(self.cfg, self.Content_Encoder)
        self.Content_optimizer_fix = make_optimizer(self.cfg, self.Content_Encoder, fix=True)
        self.scheduler = WarmupMultiStepLR(self.Content_optimizer, (30, 55), 0.1, 1.0 / 3,
                                           500, "linear")
        self.scheduler_fix = WarmupMultiStepLR(self.Content_optimizer_fix, (30, 55), 0.1, 1.0 / 3,
                                               500, "linear")
        self.schedulers = []
        self.optimizers = []

    def reset_model_status(self):
        """Put the content encoder into training mode."""
        self.Content_Encoder.train()

    @staticmethod
    def _training_mode(epoch):
        """Return the training phase used at ``epoch``.

        Phases:
            'normal_c1_c2': epochs [0, 80) and [110, 170). Head c1 is trained on
                camera subset 1 and head c2 on subset 2.
            'reverse_c1_c2': epochs [80, 110) and [170, 210). The heads are
                swapped.
            'fix_c1_c2': epoch 210 onward. The heads stay swapped, and
                target-domain consistency / pseudo-label terms are added.
        """
        if epoch < 80 or 110 <= epoch < 170:
            return 'normal_c1_c2'
        if 80 <= epoch < 110 or 170 <= epoch < 210:
            return 'reverse_c1_c2'
        return 'fix_c1_c2'

    @staticmethod
    def _camera_subset_indices(camids, cam_ids, batch_size):
        """Return the batch indices whose camera id is in ``cam_ids``.

        Indices are concatenated in ``cam_ids`` order. Falls back to every index
        ``0 .. batch_size - 1`` when no camera matches, or when the first
        matching camera contributes only a single sample.
        """
        groups = []
        for c_id in cam_ids:
            hits = np.where(c_id == camids)[0]
            if len(hits) > 0:
                groups.append(hits)
        if not groups or len(groups[0]) == 1:
            groups = [np.asarray([a]) for a in range(batch_size)]
        return np.concatenate(groups)

    def two_classifier(self, epoch, train_loader_s, train_loader_t, writer, logger, rand_src_1, rand_src_2,
                       print_freq=1):
        """Run one training epoch over the source loader.

        Args:
            epoch: current epoch. It selects the phase (see ``_training_mode``)
                and is passed to both schedulers' ``step``.
            train_loader_s: source loader yielding ``(imgs, pids, camids)``.
            train_loader_t: target loader with the same layout. It is restarted
                whenever it is exhausted.
            writer: not used by this method.
            logger: receives a progress line every ``print_freq`` iterations.
            rand_src_1: camera ids whose source samples feed the first head input.
            rand_src_2: camera ids whose source samples feed the second head input.
            print_freq: logging interval, in iterations.
        """
        self.reset_model_status()
        self.epoch = epoch
        self.scheduler.step(epoch)
        self.scheduler_fix.step(epoch)
        target_iter = iter(train_loader_t)
        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()

        mode = self._training_mode(epoch)
        batch_size = self.cfg.SOLVER.IMS_PER_BATCH
        for i, inputs in enumerate(train_loader_s):
            data_time.update(time.time() - end)
            # Cycle the target loader. Only exhaustion is handled here; any
            # other loader error now propagates instead of being swallowed by a
            # bare `except:`.
            try:
                inputs_target = next(target_iter)
            except StopIteration:
                target_iter = iter(train_loader_t)
                inputs_target = next(target_iter)
            img_s, pid_s, camid_s = self._parse_data(inputs)
            img_t, pid_t, camid_t = self._parse_data(inputs_target)
            content_code_s, content_feat_s = self.Content_Encoder(img_s)
            pid_s_12 = np.asarray(pid_s.cpu())
            camid_s = np.asarray(camid_s.cpu())

            idx_1 = self._camera_subset_indices(camid_s, rand_src_1, batch_size)
            pid_1 = torch.tensor(pid_s_12[idx_1]).cuda()
            feat_1 = content_feat_s[idx_1]
            idx_2 = self._camera_subset_indices(camid_s, rand_src_2, batch_size)
            pid_2 = torch.tensor(pid_s_12[idx_2]).cuda()
            feat_2 = content_feat_s[idx_2]

            if mode == 'normal_c1_c2':
                class_1 = self.Content_Encoder(feat_1, mode='c1')
                class_2 = self.Content_Encoder(feat_2, mode='c2')
                ID_loss_1 = self.xent(class_1, pid_1)
                ID_loss_2 = self.xent(class_2, pid_2)
                ID_tri_loss = self.triplet(content_feat_s, pid_s)
                total_loss = ID_loss_1 + ID_loss_2 + ID_tri_loss[0]
                self.Content_optimizer.zero_grad()
                total_loss.backward()
                self.Content_optimizer.step()
                batch_time.update(time.time() - end)
                end = time.time()
                if (i + 1) % print_freq == 0:
                    logger.info('Epoch: [{}][{}/{}]\t'
                                'Time {:.3f} ({:.3f})\t'
                                'Data {:.3f} ({:.3f})\t'
                                'ID_loss: {:.3f}  ID_loss_1: {:.3f}  ID_loss_2: {:.3f}   tri_loss: {:.3f} '
                                .format(epoch, i + 1, len(train_loader_s),
                                        batch_time.val, batch_time.avg,
                                        data_time.val, data_time.avg,
                                        total_loss.item(), ID_loss_1.item(), ID_loss_2.item(), ID_tri_loss[0].item()
                                        ))
            elif mode == 'reverse_c1_c2':
                # Heads swapped relative to the normal phase.
                class_1 = self.Content_Encoder(feat_1, mode='c2')
                class_2 = self.Content_Encoder(feat_2, mode='c1')
                ID_loss_1 = self.xent(class_1, pid_1)
                ID_loss_2 = self.xent(class_2, pid_2)
                ID_tri_loss = self.triplet(content_feat_s, pid_s)
                total_loss = ID_loss_1 + ID_loss_2 + ID_tri_loss[0]
                self.Content_optimizer_fix.zero_grad()
                total_loss.backward()
                self.Content_optimizer_fix.step()
                batch_time.update(time.time() - end)
                end = time.time()
                if (i + 1) % print_freq == 0:
                    logger.info('Epoch: [{}][{}/{}]\t'
                                'Time {:.3f} ({:.3f})\t'
                                'Data {:.3f} ({:.3f})\t'
                                'ID_loss: {:.3f}  ID_loss_1: {:.3f}  ID_loss_2: {:.3f}   tri_loss: {:.3f}'
                                .format(epoch, i + 1, len(train_loader_s),
                                        batch_time.val, batch_time.avg,
                                        data_time.val, data_time.avg,
                                        total_loss.item(), ID_loss_1.item(), ID_loss_2.item(), ID_tri_loss[0].item()
                                        ))
            elif mode == 'fix_c1_c2':
                class_1 = self.Content_Encoder(feat_1, mode='c2')
                class_2 = self.Content_Encoder(feat_2, mode='c1')
                ID_loss_1 = self.xent(class_1, pid_1)
                ID_loss_2 = self.xent(class_2, pid_2)

                # Both heads on the target batch; a Smooth-L1 term pulls their
                # outputs together.
                content_code_t, content_feat_t = self.Content_Encoder(img_t)
                tar_class_1 = self.Content_Encoder(content_feat_t, mode='c1')
                tar_class_2 = self.Content_Encoder(content_feat_t, mode='c2')
                tar_L1_loss = self.Smooth_L1_loss(tar_class_1, tar_class_2)
                ID_tri_loss = self.triplet(content_feat_s, pid_s)
                arg_c1 = torch.argmax(tar_class_1, dim=1)
                arg_c2 = torch.argmax(tar_class_2, dim=1)
                # Pseudo-labels: target samples where both heads agree on the
                # arg-max class and the mean of the two arg-max scores exceeds
                # 0.8.
                # NOTE(review): the threshold is applied to raw head outputs;
                # confirm the heads emit probabilities.
                arg_idx = []
                fake_id = []
                for i_dx, data in enumerate(arg_c1):
                    if (data == arg_c2[i_dx]) and (((tar_class_1[i_dx][data] + tar_class_2[i_dx][arg_c2[i_dx]])/2) > 0.8):
                        arg_idx.append(i_dx)
                        fake_id.append(data)
                # This mode implies epoch >= 210, so the two former `if`s
                # (210 <= epoch < 220 and 220 <= epoch) are exhaustive.
                if epoch < 220:
                    # The pseudo-label loss is computed for logging only here.
                    if arg_idx:
                        ID_loss_fake = self.xent(content_code_t[arg_idx], torch.tensor(fake_id).cuda())
                    else:
                        ID_loss_fake = torch.tensor([0])
                    total_loss = ID_loss_1 + ID_loss_2 + 0.5 * tar_L1_loss + ID_tri_loss[0]
                else:
                    if arg_idx:
                        ID_loss_fake = self.xent(content_code_t[arg_idx], torch.tensor(fake_id).cuda())
                        total_loss = ID_loss_1 + ID_loss_2 + 0.08 * ID_loss_fake + ID_tri_loss[0] + 0.5 * tar_L1_loss
                    else:
                        ID_loss_fake = torch.tensor([0])
                        total_loss = ID_loss_1 + ID_loss_2 + ID_tri_loss[0] + 0.5 * tar_L1_loss

                self.Content_optimizer_fix.zero_grad()
                total_loss.backward()
                self.Content_optimizer_fix.step()
                batch_time.update(time.time() - end)
                end = time.time()
                if (i + 1) % print_freq == 0:
                    logger.info('Epoch: [{}][{}/{}]\t'
                                'Time {:.3f} ({:.3f})\t'
                                'Data {:.3f} ({:.3f})\t'
                                'ID_loss: {:.3f}  ID_loss_1: {:.3f}  ID_loss_2: {:.3f}  tar_L1_loss: {:.3f}  tri_loss: {:.3f}  ID_loss_fake:  {:.6f}'
                                .format(epoch, i + 1, len(train_loader_s),
                                        batch_time.val, batch_time.avg,
                                        data_time.val, data_time.avg,
                                        total_loss.item(), ID_loss_1.item(), ID_loss_2.item(), tar_L1_loss.item(),
                                        ID_tri_loss[0].item(), ID_loss_fake.item()))

    def _parse_data(self, inputs):
        """Split a ``(imgs, pids, camids)`` batch and move each part to the GPU."""
        imgs, pids, camids = inputs
        inputs = imgs.cuda()
        targets = pids.cuda()
        camids = camids.cuda()
        return inputs, targets, camids
Beispiel #5
0
def main():
    """Train a VeRi re-ID model on anchor/positive/negative image triples.

    Loads ``configs/baseline_veri_r101_a.yml``, per-label match matrices
    (pickles in a hard-coded directory) and the dataset. It then trains for
    ``cfg.SOLVER.MAX_EPOCHS`` epochs. Every anchor batch of size B is expanded
    to 3*B images:
      * the anchors;
      * a positive, taken at the smallest non-zero entry of the anchor's row in
        its label's match matrix;
      * a negative, from a random other label that has a training folder.
    A checkpoint is written every ``CHECKPOINT_PERIOD`` epochs, and mAP / CMC
    are logged every ``EVAL_PERIOD`` epochs.

    Raises:
        ValueError: if ``cfg.DATALOADER.SAMPLER`` is neither a triplet sampler
            nor 'softmax'.
    """
    torch.backends.cudnn.deterministic = True
    cudnn.benchmark = True

    config_file = 'configs/baseline_veri_r101_a.yml'
    if config_file != "":
        cfg.merge_from_file(config_file)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = setup_logger("reid_baseline", output_dir, if_train=True)
    logger.info("Saving model in the path :{}".format(cfg.OUTPUT_DIR))
    logger.info(config_file)

    if config_file != "":
        logger.info("Loaded configuration file {}".format(config_file))
        with open(config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID

    # Match matrices keyed by the first 3 characters of each file name;
    # 'featureMatrix.pkl' is keyed by its first 13 characters instead.
    gms_dir = 'D:/Python_SMU/Veri/verigms/gms/'
    pkl = {}
    for name in os.listdir(gms_dir):
        key = name[0:13] if name == 'featureMatrix.pkl' else name[0:3]
        # `with` closes the handle. The original wrote `f.close` without
        # parentheses, which never closed anything.
        with open(gms_dir + name, 'rb') as f:
            pkl[key] = pickle.load(f)

    # NOTE(review): these two pickles are not used below; they are kept so that
    # a missing file still fails here, as before.
    with open('cids.pkl', 'rb') as handle:
        _cids = pickle.load(handle)
    with open('index.pkl', 'rb') as handle:
        _index = pickle.load(handle)

    train_transforms, val_transforms, dataset, train_set, val_set = make_dataset(
        cfg, pkl_file='index.pkl')

    num_workers = cfg.DATALOADER.NUM_WORKERS
    num_classes = dataset.num_train_pids

    # Map characters [1:4] of each training image's file name to its dataset pid.
    pidx = {}
    for img_path, pid, _, _ in dataset.train:
        file_name = img_path.split('\\')[-1]
        pidx[file_name[1:4]] = pid

    if 'triplet' in cfg.DATALOADER.SAMPLER:
        train_loader = DataLoader(train_set,
                                  batch_size=cfg.SOLVER.IMS_PER_BATCH,
                                  sampler=RandomIdentitySampler(
                                      dataset.train, cfg.SOLVER.IMS_PER_BATCH,
                                      cfg.DATALOADER.NUM_INSTANCE),
                                  num_workers=num_workers,
                                  pin_memory=True,
                                  collate_fn=train_collate_fn)
    elif cfg.DATALOADER.SAMPLER == 'softmax':
        print('using softmax sampler')
        train_loader = DataLoader(train_set,
                                  batch_size=cfg.SOLVER.IMS_PER_BATCH,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=True,
                                  collate_fn=train_collate_fn)
    else:
        # Previously this only printed (and read a nonexistent cfg.SAMPLER),
        # then crashed later on the undefined train_loader.
        raise ValueError(
            'unsupported sampler! expected softmax or triplet but got {}'.format(
                cfg.DATALOADER.SAMPLER))

    print("train loader loaded successfully")

    val_loader = DataLoader(val_set,
                            batch_size=cfg.TEST.IMS_PER_BATCH,
                            shuffle=False,
                            num_workers=num_workers,
                            pin_memory=True,
                            collate_fn=train_collate_fn)
    print("val loader loaded successfully")

    if cfg.MODEL.PRETRAIN_CHOICE == 'finetune':
        model = make_model(cfg, num_class=576)
        model.load_param_finetune(cfg.MODEL.PRETRAIN_PATH)
        print('Loading pretrained model for finetuning......')
    else:
        model = make_model(cfg, num_class=num_classes)

    loss_func, center_criterion = make_loss(cfg, num_classes=num_classes)

    optimizer, optimizer_center = make_optimizer(cfg, model, center_criterion)
    scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS,
                                  cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
                                  cfg.SOLVER.WARMUP_EPOCHS,
                                  cfg.SOLVER.WARMUP_METHOD)

    print("model,optimizer, loss, scheduler loaded successfully")

    height, width = cfg.INPUT.SIZE_TRAIN

    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD

    device = "cuda"
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info('start training')

    if device:
        if torch.cuda.device_count() > 1:
            print('Using {} GPUs for training'.format(
                torch.cuda.device_count()))
            model = nn.DataParallel(model)
        model.to(device)

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    evaluator = R1_mAP_eval(len(dataset.query),
                            max_rank=50,
                            feat_norm=cfg.TEST.FEAT_NORM)
    # DataParallel hides the wrapped model's attributes behind `.module`.
    base_model = model.module if isinstance(model, nn.DataParallel) else model
    base_model.base._freeze_stages()
    logger.info('Freezing the stages number:{}'.format(cfg.MODEL.FROZEN))

    data_index = search(pkl)
    print("Ready for training")

    for epoch in range(1, epochs + 1):
        start_time = time.time()
        loss_meter.reset()
        acc_meter.reset()
        evaluator.reset()
        scheduler.step()
        model.train()
        for n_iter, (img, label, index, pid, cid) in enumerate(train_loader):
            optimizer.zero_grad()
            optimizer_center.zero_grad()
            # Use the real batch length. The last batch may be smaller than
            # train_loader.batch_size, which used to raise IndexError below.
            batch = len(label)
            trainX = torch.zeros((batch * 3, 3, height, width),
                                 dtype=torch.float32)
            trainY = torch.zeros(batch * 3, dtype=torch.int64)

            for i in range(batch):
                labelx = label[i]
                indexx = index[i]
                cidx = pid[i]
                if indexx > len(pkl[labelx]) - 1:
                    indexx = len(pkl[labelx]) - 1

                # Positive: sample at the smallest non-zero match entry.
                a = pkl[labelx][indexx]
                minpos = np.argmin(ma.masked_where(a == 0, a))
                pos_dic = train_set[data_index[cidx][1] + minpos]

                # Negative: a random label in 1..769 that differs from the
                # anchor's and has a training folder. `!=` replaces `is not`,
                # which tested object identity and was always true for ints
                # above 256, so the anchor's own label could be drawn.
                anchor_label = int(labelx)
                while True:
                    neg_label = random.choice(range(1, 770))
                    if neg_label != anchor_label and os.path.isdir(
                            os.path.join('D:/datasets/veri-split/train',
                                         strint(neg_label))):
                        break

                negative_label = strint(neg_label)
                neg_cid = pidx[negative_label]
                neg_index = random.choice(range(0, len(pkl[negative_label])))
                neg_dic = train_set[data_index[neg_cid][1] + neg_index]

                # Layout: [anchors | positives | negatives].
                trainX[i] = img[i]
                trainX[i + batch] = pos_dic[0]
                trainX[i + (batch * 2)] = neg_dic[0]
                trainY[i] = cidx
                trainY[i + batch] = pos_dic[3]
                trainY[i + (batch * 2)] = neg_dic[3]

            trainX = trainX.cuda()
            trainY = trainY.cuda()

            score, feat = model(trainX, trainY)
            loss = loss_func(score, feat, trainY)
            loss.backward()
            optimizer.step()
            if 'center' in cfg.MODEL.METRIC_LOSS_TYPE:
                # Undo the center-loss weight on the center parameters' grads.
                for param in center_criterion.parameters():
                    param.grad.data *= (1. / cfg.SOLVER.CENTER_LOSS_WEIGHT)
                optimizer_center.step()

            acc = (score.max(1)[1] == trainY).float().mean()
            loss_meter.update(loss.item(), img.shape[0])
            acc_meter.update(acc, 1)

            if (n_iter + 1) % log_period == 0:
                logger.info(
                    "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                    .format(epoch, (n_iter + 1), len(train_loader),
                            loss_meter.avg, acc_meter.avg,
                            scheduler.get_lr()[0]))
        end_time = time.time()
        time_per_batch = (end_time - start_time) / (n_iter + 1)
        logger.info(
            "Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]"
            .format(epoch, time_per_batch,
                    train_loader.batch_size / time_per_batch))

        if epoch % checkpoint_period == 0:
            torch.save(
                model.state_dict(),
                os.path.join(cfg.OUTPUT_DIR,
                             cfg.MODEL.NAME + '_{}.pth'.format(epoch)))

        if epoch % eval_period == 0:
            model.eval()
            for n_iter, (img, vid, camid, _, _) in enumerate(val_loader):
                with torch.no_grad():
                    img = img.to(device)
                    feat = model(img)
                    evaluator.update((feat, vid, camid))

            cmc, mAP, _, _, _, _, _ = evaluator.compute()
            logger.info("Validation Results - Epoch: {}".format(epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))
Beispiel #6
0
            model = nn.DataParallel(model)
        model.to(device)

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm='yes')
    model.base._freeze_stages()
    logger.info('Freezing the stages number:{}'.format(-1))
    # train
    for epoch in range(1, epochs + 1):
        start_time = time.time()
        loss_meter.reset()
        acc_meter.reset()
        evaluator.reset()
        scheduler.step()
        model.train()
        for n_iter, (img, vid) in enumerate(train_loader):
            optimizer.zero_grad()
            optimizer_center.zero_grad()
            img = img.to(device)
            target = vid.to(device)

            feat = model(img, target)
            loss,score = loss_func(feat, target)

            loss.backward()
            optimizer.step()
            acc = (score.max(1)[1] == target).float().mean()
            loss_meter.update(loss.item(), img.shape[0])
            acc_meter.update(acc, 1)
Beispiel #7
0
def _round_batch_size(batch_size, num_instance):
    """Return the largest multiple of ``num_instance`` that is <= ``batch_size``.

    ``train`` requires the batch size to be a multiple of ``num_instance``.

    Raises:
        ValueError: if ``batch_size`` is smaller than ``num_instance`` (the
            rounded batch size would be 0).
    """
    rounded = (batch_size // num_instance) * num_instance
    if rounded <= 0:
        raise ValueError(
            f"batch size {batch_size} is smaller than num_instance "
            f"{num_instance}; cannot form a non-empty batch")
    return rounded


def _checkpoint_sibling(path, tag):
    """Return the companion-checkpoint path for the model checkpoint ``path``.

    ``"model"`` is replaced by ``tag`` in the file name only, so a directory
    whose name happens to contain ``"model"`` (e.g. ``./models``) is kept.
    """
    directory, file_name = os.path.split(path)
    return os.path.join(directory, file_name.replace("model", tag))


def _move_optimizer_state_to_cuda(optimizer):
    """Move every tensor held in ``optimizer.state`` to CUDA, in place."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.cuda()


def train(args):
    """Train the model configured by ``args``, with periodic evaluation and checkpoints.

    Steps:
      * round ``args.batch_size`` down to a multiple of ``args.num_instance``;
      * build loaders, model, losses and optimizers from ``args``;
      * optionally move everything to CUDA (``amp`` O1 when ``args.amp``);
      * optionally load weights from ``args.pretrain``, and with ``args.resume``
        also the optimizer state, start epoch and global step;
      * train until ``args.max_epoch``, evaluating every ``args.eval_period``
        epochs and on the last epoch.

    Losses, learning rate and evaluation metrics are logged to TensorBoard in
    ``args.save_dir/tensorboard_log``. Checkpoints go to ``args.save_dir/ckpts``.

    Checkpoint layout: companion files share the model file's name with
    ``"model"`` replaced.
      * ``*_model_*`` holds ``state_dict``, ``epoch`` (number of completed
        epochs) and ``global_step``.
      * ``*_optimizer_*`` holds the optimizer state.
      * With ``args.center_loss``, ``*_center_param_*`` holds the center-loss
        module state and ``*_optimizer_center_*`` the center optimizer state.

    Setting ``args.detect_anomaly = True`` runs non-amp backward passes under
    ``torch.autograd.detect_anomaly()`` (slow; for debugging only).

    Raises:
        ValueError: if ``args.batch_size`` < ``args.num_instance``.
    """
    if args.batch_size % args.num_instance != 0:
        new_batch_size = _round_batch_size(args.batch_size, args.num_instance)
        print(
            f"given batch size is {args.batch_size} and num_instances is {args.num_instance}. "
            f"Batch size must be divisible by {args.num_instance}. "
            f"Batch size will be replaced with {new_batch_size}")
        args.batch_size = new_batch_size

    # prepare dataset
    train_loader, val_loader, num_query, train_data_len, num_classes = make_data_loader(
        args)

    model = build_model(args, num_classes)
    print("model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))
    loss_fn, center_criterion = make_loss(args, num_classes)
    optimizer, optimizer_center = make_optimizer(args, model, center_criterion)

    if args.cuda:
        model = model.cuda()
        if args.amp:
            if args.center_loss:
                model, [optimizer, optimizer_center] = \
                    amp.initialize(model, [optimizer, optimizer_center], opt_level="O1")
            else:
                model, optimizer = amp.initialize(model,
                                                  optimizer,
                                                  opt_level="O1")

        _move_optimizer_state_to_cuda(optimizer)
        if args.center_loss:
            center_criterion = center_criterion.cuda()
            _move_optimizer_state_to_cuda(optimizer_center)

    reid_evaluator = ReIDEvaluator(args, model, num_query)
    # Use the same evaluation mode everywhere so "best" metrics are comparable.
    eval_mode = "retrieval" if args.dataset_name == "cub200" else "reid"

    start_epoch = 0
    global_step = 0
    if args.pretrain != '':  # load pre-trained model
        weights = torch.load(args.pretrain)
        # load_state_dict copies into the model's own tensors; checkpoints are
        # later written from model.state_dict(), never from this loaded dict.
        model.load_state_dict(weights["state_dict"])
        if args.center_loss:
            center_criterion.load_state_dict(
                torch.load(_checkpoint_sibling(args.pretrain,
                                               'center_param'))["state_dict"])

        if args.resume:
            start_epoch = weights["epoch"]
            global_step = weights["global_step"]

            optimizer.load_state_dict(
                torch.load(_checkpoint_sibling(args.pretrain,
                                               'optimizer'))["state_dict"])
            if args.center_loss:
                optimizer_center.load_state_dict(
                    torch.load(_checkpoint_sibling(
                        args.pretrain, 'optimizer_center'))["state_dict"])
        print(f'Start epoch: {start_epoch}, Start step: {global_step}')

    scheduler = WarmupMultiStepLR(optimizer, args.steps, args.gamma,
                                  args.warmup_factor, args.warmup_step,
                                  "linear",
                                  -1 if start_epoch == 0 else start_epoch)

    best_epoch = 0
    best_rank1 = 0
    best_mAP = 0
    if args.resume:
        rank, mAP = reid_evaluator.evaluate(val_loader, mode=eval_mode)
        best_rank1 = rank[0]
        best_mAP = mAP
        # The checkpoint stored epoch + 1, which is the same value best_epoch
        # was given when that checkpoint was written, so no further +1.
        best_epoch = start_epoch

    batch_time = AverageMeter()
    total_losses = AverageMeter()

    model_save_dir = os.path.join(args.save_dir, 'ckpts')
    os.makedirs(model_save_dir, exist_ok=True)

    summary_writer = SummaryWriter(log_dir=os.path.join(
        args.save_dir, "tensorboard_log"),
                                   purge_step=global_step)

    def summary_loss(score, feat, labels, top_name='global'):
        """Sum the loss terms returned by ``loss_fn`` and log them to TensorBoard.

        An entry named "accuracy" is logged as a score and entries whose
        name contains "dist" are logged as histograms; neither is added to
        the returned loss. The mean softmax probability of the true class is
        also logged.
        """
        loss = 0.0
        losses = loss_fn(score, feat, labels)
        for loss_name, loss_val in losses.items():
            lowered = loss_name.lower()
            if lowered == "accuracy":
                summary_writer.add_scalar(f"Score/{top_name}/triplet",
                                          loss_val, global_step)
            elif "dist" in lowered:
                summary_writer.add_histogram(f"Distance/{loss_name}", loss_val,
                                             global_step)
            else:
                loss += loss_val
                summary_writer.add_scalar(f"losses/{top_name}/{loss_name}",
                                          loss_val, global_step)

        # Logging only, so no autograd graph is needed. gather picks
        # probs[i, labels[i]] for each sample i.
        with torch.no_grad():
            true_class_prob = torch.softmax(score, dim=1).gather(
                1, labels.unsqueeze(1)).mean()
        summary_writer.add_scalar(f"Score/{top_name}/X-entropy",
                                  true_class_prob, global_step)

        return loss

    def save_weights(file_name, eph, steps):
        """Write the model checkpoint and its companion optimizer/center files."""
        # state_dict() is taken now. A snapshot taken before training would
        # hold stale optimizer state (and, after loading a pretrain, stale
        # model weights).
        torch.save(
            {
                "state_dict": model.state_dict(),
                "epoch": eph + 1,
                "global_step": steps
            }, file_name)
        torch.save({"state_dict": optimizer.state_dict()},
                   _checkpoint_sibling(file_name, "optimizer"))
        if args.center_loss:
            # File names must match what the loading code above expects.
            torch.save({"state_dict": center_criterion.state_dict()},
                       _checkpoint_sibling(file_name, "center_param"))
            torch.save({"state_dict": optimizer_center.state_dict()},
                       _checkpoint_sibling(file_name, "optimizer_center"))

    detect_anomaly = getattr(args, "detect_anomaly", False)
    # Initialised so the epoch summary works even if an epoch has no batches.
    current_lr = optimizer.param_groups[0]['lr']

    try:
        # training start
        for epoch in range(start_epoch, args.max_epoch):
            model.train()
            t0 = time.time()
            for i, (inputs, labels, _, _) in enumerate(train_loader):
                if args.cuda:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                cls_scores, features = model(inputs, labels)

                # losses
                total_loss = summary_loss(cls_scores[0], features[0], labels,
                                          'global')
                if args.use_local_feat:
                    total_loss += summary_loss(cls_scores[1], features[1],
                                               labels, 'local')

                optimizer.zero_grad()
                if args.center_loss:
                    optimizer_center.zero_grad()

                # backward with global loss
                if args.amp:
                    optimizers = [optimizer]
                    if args.center_loss:
                        optimizers.append(optimizer_center)
                    with amp.scale_loss(total_loss, optimizers) as scaled_loss:
                        scaled_loss.backward()
                elif detect_anomaly:
                    with torch.autograd.detect_anomaly():
                        total_loss.backward()
                else:
                    total_loss.backward()

                # optimization
                optimizer.step()
                if args.center_loss:
                    # Rescale center gradients by 1/center_loss_weight
                    # (presumably cancels the weight applied in the loss;
                    # confirm in make_loss).
                    for param in center_criterion.parameters():
                        if param.grad is not None:
                            param.grad.data *= (1. / args.center_loss_weight)
                    optimizer_center.step()

                batch_time.update(time.time() - t0)
                total_losses.update(total_loss.item())

                # learning_rate
                current_lr = optimizer.param_groups[0]['lr']
                summary_writer.add_scalar("lr", current_lr, global_step)

                t0 = time.time()

                if (i + 1) % args.log_period == 0:
                    print(
                        f"Epoch: [{epoch}][{i+1}/{train_data_len}]  " +
                        f"Batch Time {batch_time.val:.3f} ({batch_time.mean:.3f})  "
                        +
                        f"Total_loss {total_losses.val:.3f} ({total_losses.mean:.3f})"
                    )
                global_step += 1

            print(
                f"Epoch: [{epoch}]\tEpoch Time {batch_time.sum:.3f} s\tLoss {total_losses.mean:.3f}\tLr {current_lr:.2e}"
            )

            is_last_epoch = (epoch + 1) == args.max_epoch
            eval_due = (args.eval_period > 0
                        and (epoch + 1) % args.eval_period == 0)
            if eval_due or is_last_epoch:
                rank, mAP = reid_evaluator.evaluate(val_loader, mode=eval_mode)

                rank_string = "    ".join(
                    f"Rank-{r:<3}: {rank[r-1]:.1%}"
                    for r in (1, 2, 4, 5, 8, 10, 16, 20))
                summary_writer.add_text("Recall@K", rank_string, global_step)
                summary_writer.add_scalar("Rank-1", rank[0], (epoch + 1))

                rank1 = rank[0]
                is_best = rank1 > best_rank1
                if is_best:
                    best_rank1 = rank1
                    best_mAP = mAP
                    best_epoch = epoch + 1

                # save_period <= 0 disables periodic saves instead of
                # raising ZeroDivisionError.
                save_due = (args.save_period > 0
                            and (epoch + 1) % args.save_period == 0)
                if save_due or is_last_epoch:
                    pth_file_name = os.path.join(
                        model_save_dir,
                        f"{args.backbone}_model_{epoch + 1}.pth.tar")
                    save_weights(pth_file_name, eph=epoch, steps=global_step)

                if is_best:
                    pth_file_name = os.path.join(
                        model_save_dir, f"{args.backbone}_model_best.pth.tar")
                    save_weights(pth_file_name, eph=epoch, steps=global_step)

            # end epoch
            batch_time.reset()
            total_losses.reset()
            torch.cuda.empty_cache()

            # update learning rate
            scheduler.step()

        print(f"Best rank-1 {best_rank1:.1%}, achived at epoch {best_epoch}")
        summary_writer.add_hparams(
            {
                "dataset_name": args.dataset_name,
                "triplet_dim": args.triplet_dim,
                "margin": args.margin,
                "base_lr": args.base_lr,
                "use_attn": args.use_attn,
                "use_mask": args.use_mask,
                "use_local_feat": args.use_local_feat
            }, {
                "mAP": best_mAP,
                "Rank1": best_rank1
            })
    finally:
        # Flush and release the event file even if training fails.
        summary_writer.close()