def train(model, train_loader, optimizer, criterion, epoch, log_writer, args):
    train_loss = lib.Metric('train_loss')
    train_accuracy = lib.Metric('train_accuracy')
    model.train()
    N = len(train_loader)
    start_time = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        lr_cur = adjust_learning_rate(args, optimizer, epoch, batch_idx, N, type=args.lr_scheduler)
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        train_loss.update(loss.detach())  # detach so the metric does not keep the autograd graph alive
        train_accuracy.update(accuracy(output, target))
        if (batch_idx + 1) % 20 == 0:
            memory = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
            used_time = time.time() - start_time
            eta = used_time / (batch_idx + 1) * (N - batch_idx - 1)  # time per batch * batches remaining
            eta = str(datetime.timedelta(seconds=int(eta)))
            training_state = '  '.join(['Epoch: {}', '[{} / {}]', 'eta: {}', 'lr: {:.9f}', 'max_mem: {:.0f}',
                                        'loss: {:.3f}', 'accuracy: {:.3f}'])
            training_state = training_state.format(epoch + 1, batch_idx + 1, N, eta, lr_cur, memory,
                                                   train_loss.avg.item(), 100. * train_accuracy.avg.item())
            print(training_state)

    if log_writer:
        log_writer.add_scalar('train/loss', train_loss.avg, epoch)
        log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch)
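These loops rely on helpers that the snippet does not define; below is a minimal sketch of what `lib.Metric` and `accuracy` are assumed to look like, inferred from the call sites above (not the original implementations):

import torch

class Metric:
    """Running average of a scalar tensor; .avg stays a tensor so .item() works."""
    def __init__(self, name):
        self.name = name
        self.sum = torch.tensor(0.0)
        self.n = 0

    def update(self, val):
        self.sum += val.detach().cpu()
        self.n += 1

    @property
    def avg(self):
        return self.sum / max(self.n, 1)


def accuracy(output, target):
    # top-1 accuracy as a CPU tensor in [0, 1]
    pred = output.max(1, keepdim=True)[1]
    return pred.eq(target.view_as(pred)).float().mean().cpu()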
Example #2
def train(epoch, net, trainloader, optimizer, npc, criterion, rlb, lr):
    train_loss = AverageMeter()
    net.train()
    adjust_learning_rate(optimizer, lr)
    for (inputs, _, indexes) in trainloader:
        optimizer.zero_grad()
        inputs, indexes = inputs.to(cfg.device), indexes.to(cfg.device)

        features = net(inputs)
        outputs = npc(features, indexes)
        loss = criterion(outputs, indexes, rlb)

        loss.backward()
        train_loss.update(loss.item(), inputs.size(0))

        optimizer.step()
    return train_loss.avg
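The `AverageMeter` used here (and again in Example #9) is not defined in the snippet; this is the usual implementation such scripts assume, with the `val`, `sum`, and `avg` fields matching the calls above (an assumed shape, not the original):

class AverageMeter:
    """Tracks the latest value, a weighted running sum, and the average."""
    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)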
Example #3
def train(model, train_sampler, train_loader, optimizer, criterion, epoch,
          log_writer, args, verbose):
    train_loss = lib.Metric('train_loss')
    train_accuracy = lib.Metric('train_accuracy')
    model.train()
    if args.distributed:
        train_sampler.set_epoch(epoch)
    N = len(train_loader)
    start_time = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        lr_cur = adjust_learning_rate(args,
                                      optimizer,
                                      epoch,
                                      batch_idx,
                                      N,
                                      type=args.lr_scheduler)
        if args.cuda:
            data = data.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # all_reduce sums across processes; divide by the process count to
        # recover the mean (assumes a single-node run, where the world size
        # equals args.ngpus_per_node)
        dist.all_reduce(loss)
        pred = output.max(1, keepdim=True)[1]
        acc = pred.eq(target.view_as(pred)).float().mean()
        dist.all_reduce(acc)
        train_loss.update(loss.detach() / args.ngpus_per_node)
        train_accuracy.update(acc.cpu() / args.ngpus_per_node)
        if (batch_idx + 1) % 20 == 0 and verbose:
            memory = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
            used_time = time.time() - start_time
            eta = used_time / (batch_idx + 1) * (N - batch_idx - 1)  # time per batch * batches remaining
            eta = str(datetime.timedelta(seconds=int(eta)))
            training_state = '  '.join([
                'Epoch: {}', '[{} / {}]', 'eta: {}', 'lr: {:.9f}',
                'max_mem: {:.0f}', 'loss: {:.3f}', 'accuracy: {:.3f}'
            ])
            training_state = training_state.format(
                epoch + 1, batch_idx + 1, N, eta, lr_cur, memory,
                train_loss.avg.item(), 100. * train_accuracy.avg.item())
            print(training_state)

    if log_writer and verbose:
        log_writer.add_scalar('train/loss', train_loss.avg, epoch)
        log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch)
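Examples #1 and #3 call `adjust_learning_rate(args, optimizer, epoch, batch_idx, N, type=args.lr_scheduler)` without showing it. Here is a per-iteration scheduler matching that signature and return value; the 'cosine' and 'step' branches and the step schedule are assumptions, not the original code:

import math

def adjust_learning_rate(args, optimizer, epoch, batch_idx, N, type='cosine'):
    # fraction of total training completed, at mini-batch granularity
    progress = (epoch + batch_idx / N) / args.epochs
    if type == 'cosine':
        lr = args.lr * 0.5 * (1.0 + math.cos(math.pi * progress))
    elif type == 'step':
        lr = args.lr * (0.1 ** (epoch // 30))  # decay interval is an assumption
    else:
        lr = args.lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr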
Example #4
def main():

    global args, best_prec1
    args = parser.parse_args()

    # init seed
    my_whole_seed = 222
    random.seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.manual_seed(my_whole_seed)
    torch.cuda.manual_seed_all(my_whole_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(my_whole_seed)

    for kk_time in range(args.seedstart, args.seedstart + 1):
        args.seed = kk_time
        args.result = args.result + str(args.seed)

        # create model
        model = models.__dict__[args.arch](low_dim=args.low_dim,
                                           multitask=args.multitask,
                                           showfeature=args.showfeature,
                                           domain=args.domain,
                                           args=args)
        model = torch.nn.DataParallel(model).cuda()
        print('Number of learnable params',
              get_learnable_para(model) / 1000000., " M")

        # Data loading code
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        aug = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])
        # aug = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.08, 1.), ratio=(3 / 4, 4 / 3)),
        #                           transforms.RandomHorizontalFlip(p=0.5),
        #                           get_color_distortion(s=1),
        #                           transforms.Lambda(lambda x: gaussian_blur(x)),
        #                           transforms.ToTensor(),
        #                           normalize])
        aug_test = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor(), normalize])

        # load dataset
        # import datasets.fundus_amd_syn_crossvalidation as medicaldata
        import datasets.fundus_amd_syn_crossvalidation_ind as medicaldata
        train_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug,
                                                 train=True,
                                                 args=args)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=4,
            drop_last=True if args.multiaug else False,
            worker_init_fn=lambda w: random.seed(my_whole_seed + w))  # random.seed() returns None; seed workers via a callable

        valid_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug_test,
                                                 train=False,
                                                 args=args)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=4,
            worker_init_fn=lambda w: random.seed(my_whole_seed + w))  # random.seed() returns None; seed workers via a callable

        # define lemniscate and loss function (criterion)
        ndata = train_dataset.__len__()

        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()

        if args.multitaskposrot:
            cls_criterion = nn.CrossEntropyLoss().cuda()
        else:
            cls_criterion = None

        if args.multitaskposrot:
            print("running multi task with miccai")
            criterion = BatchCriterion(1, 0.1, args.batch_size, args).cuda()
        elif args.synthesis:
            print("running synthesis")
            criterion = BatchCriterionFour(1, 0.1, args.batch_size,
                                           args).cuda()
        elif args.multiaug:
            print("running cvpr")
            criterion = BatchCriterion(1, 0.1, args.batch_size, args).cuda()
        else:
            criterion = nn.CrossEntropyLoss().cuda()

        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.weight_decay)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                lemniscate = checkpoint['lemniscate']
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        if args.evaluate:
            knn_num = 100
            auc, acc, precision, recall, f1score = kNN(args, model, lemniscate,
                                                       train_loader,
                                                       val_loader, knn_num,
                                                       args.nce_t, 2)
            f = open("savemodels/result.txt", "a+")
            f.write("auc: %.4f\n" % (auc))
            f.write("acc: %.4f\n" % (acc))
            f.write("pre: %.4f\n" % (precision))
            f.write("recall: %.4f\n" % (recall))
            f.write("f1score: %.4f\n" % (f1score))
            f.close()
            return

        # mkdir result folder and tensorboard
        os.makedirs(args.result, exist_ok=True)
        writer = SummaryWriter("runs/" + str(args.result.split("/")[-1]))
        writer.add_text('Text', str(args))

        # copy code
        import shutil, glob
        source = glob.glob("*.py")
        source += glob.glob("*/*.py")
        os.makedirs(args.result + "/code_file", exist_ok=True)
        for file in source:
            name = file.split("/")[0]
            if name == file:
                shutil.copy(file, args.result + "/code_file/")
            else:
                os.makedirs(args.result + "/code_file/" + name, exist_ok=True)
                shutil.copy(file, args.result + "/code_file/" + name)

        for epoch in range(args.start_epoch, args.epochs):
            lr = adjust_learning_rate(optimizer, epoch, args, [1000, 2000])
            writer.add_scalar("lr", lr, epoch)

            # train for one epoch
            loss = train(train_loader, model, lemniscate, criterion,
                         cls_criterion, optimizer, epoch, writer)
            writer.add_scalar("train_loss", loss, epoch)

            # save checkpoint
            if epoch % 200 == 0 or (epoch in [1600, 1800, 2000]):
                auc, acc, precision, recall, f1score = kNN(
                    args, model, lemniscate, train_loader, val_loader, 100,
                    args.nce_t, 2)
                # save to txt
                writer.add_scalar("test_auc", auc, epoch)
                writer.add_scalar("test_acc", acc, epoch)
                writer.add_scalar("test_precision", precision, epoch)
                writer.add_scalar("test_recall", recall, epoch)
                writer.add_scalar("test_f1score", f1score, epoch)
                with open(args.result + "/result.txt", "a+") as f:
                    f.write("epoch " + str(epoch) + "\n")
                    f.write("auc: %.4f\n" % auc)
                    f.write("acc: %.4f\n" % acc)
                    f.write("pre: %.4f\n" % precision)
                    f.write("recall: %.4f\n" % recall)
                    f.write("f1score: %.4f\n" % f1score)
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'lemniscate': lemniscate,
                        'optimizer': optimizer.state_dict(),
                    },
                    filename=args.result + "/fold" + str(args.seedstart) +
                    "-epoch-" + str(epoch) + ".pth.tar")
Example #5
def main():

    global args, best_prec1
    args = parser.parse_args()

    my_whole_seed = 222
    random.seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.manual_seed(my_whole_seed)
    torch.cuda.manual_seed_all(my_whole_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(my_whole_seed)

    result_base = args.result
    for kk_time in range(args.seedstart, args.seedend):
        args.seed = kk_time
        # build each fold's path from the fixed base so suffixes do not accumulate
        args.result = result_base + str(args.seed)

        # create model
        from models.resnet_sup import resnet18, resnet50, resnet34
        model = resnet18()

        # pretrain_dict = torch.load("resnet18-5c106cde.pth")
        # model_dict = model.state_dict()
        # pretrained_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
        # pretrained_dict.pop("fc.weight")
        # pretrained_dict.pop("fc.bias")
        # model_dict.update(pretrained_dict)
        # model.load_state_dict(model_dict)

        model = torch.nn.DataParallel(model).cuda()
        model_weights = torch.load(
            "exp/fundus_dr/DR_miccai_repeat0/fold0-epoch-800.pth.tar")
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in model_weights["state_dict"].items() if k in model_dict
        }
        pretrained_dict.pop("module.fc.weight")
        pretrained_dict.pop("module.fc.bias")
        model.load_state_dict(pretrained_dict, strict=False)

        # Data loading code
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        aug = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            # transforms.RandomGrayscale(p=0.2),
            # transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
        # aug = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.08, 1.), ratio=(3 / 4, 4 / 3)),
        #                           transforms.RandomHorizontalFlip(p=0.5),
        #                           get_color_distortion(s=1),
        #                           transforms.Lambda(lambda x: gaussian_blur(x)),
        #                           transforms.ToTensor(),
        #                           normalize])
        # aug = transforms.Compose([transforms.RandomRotation(60),
        #                           transforms.RandomResizedCrop(224, scale=(0.6, 1.)),
        #                           transforms.RandomGrayscale(p=0.2),
        #                           transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        #                           transforms.RandomHorizontalFlip(),
        #                           transforms.ToTensor(),
        #                             normalize])
        aug_test = transforms.Compose(
            [transforms.Resize(224),
             transforms.ToTensor(), normalize])

        # dataset
        import datasets.fundus_amd_syn_crossvalidation as medicaldata
        train_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug,
                                                 train=True,
                                                 args=args)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=4,
            drop_last=True if args.multiaug else False,
            worker_init_fn=lambda w: random.seed(my_whole_seed + w))  # random.seed() returns None; seed workers via a callable

        valid_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug_test,
                                                 train=False,
                                                 args=args)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=4,
            worker_init_fn=lambda w: random.seed(my_whole_seed + w))  # random.seed() returns None; seed workers via a callable

        criterion = nn.CrossEntropyLoss().cuda()
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.weight_decay)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                model_dict = model.state_dict()

                pretrained_dict = {
                    k: v
                    for k, v in checkpoint["state_dict"].items()
                    if k in model_dict
                }
                pretrained_dict.pop("module.fc.weight")
                pretrained_dict.pop("module.fc.bias")
                # pretrained_dict = {k: v for k, v in checkpoint["net"].items() if k in model_dict}
                # pretrained_dict.pop("module.conv1.weight")
                # pretrained_dict.pop("module.conv1.bias")

                model_dict.update(pretrained_dict)
                model.load_state_dict(model_dict)
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))

            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        # mkdir result folder and tensorboard
        os.makedirs(args.result, exist_ok=True)
        writer = SummaryWriter("runs/" + str(args.result.split("/")[-1]))
        writer.add_text('Text', str(args))

        # copy code
        import shutil, glob
        source = glob.glob("*.py")
        source += glob.glob("*/*.py")
        os.makedirs(args.result + "/code_file", exist_ok=True)
        for file in source:
            name = file.split("/")[0]
            if name == file:
                shutil.copy(file, args.result + "/code_file/")
            else:
                os.makedirs(args.result + "/code_file/" + name, exist_ok=True)
                shutil.copy(file, args.result + "/code_file/" + name)

        for epoch in range(args.start_epoch, args.epochs):
            lr = adjust_learning_rate(optimizer, epoch, args,
                                      [500, 1000, 1500])
            writer.add_scalar("lr", lr, epoch)

            # train for one epoch
            loss = train(train_loader, model, criterion, optimizer)
            writer.add_scalar("train_loss", loss, epoch)

            gap_int = 200
            if epoch % gap_int == 0:
                loss_val, auc, acc, precision, recall, f1score = supervised_evaluation(
                    model, val_loader)
                writer.add_scalar("test_auc", auc, epoch)
                writer.add_scalar("test_acc", acc, epoch)
                writer.add_scalar("test_precision", precision, epoch)
                writer.add_scalar("test_recall", recall, epoch)
                writer.add_scalar("test_f1score", f1score, epoch)

                # save to txt
                with open(args.result + "/result.txt", "a+") as f:
                    f.write("epoch " + str(epoch) + "\n")
                    f.write("auc: %.4f\n" % auc)
                    f.write("acc: %.4f\n" % acc)
                    f.write("pre: %.4f\n" % precision)
                    f.write("recall: %.4f\n" % recall)
                    f.write("f1score: %.4f\n" % f1score)

            # save checkpoint
            if epoch in [1000, 2000, 3000]:
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    },
                    filename=args.result + "/epoch-" + str(epoch) + ".pth.tar")
Example #6
    def train(self,
              nets,
              criterions,
              optimizers,
              train_loader,
              test_loader,
              logs=None,
              **kwargs):
        import time
        import os

        print("manual seed : %d" % self.args.manualSeed)

        for epoch in range(self.args.trainer.start_epoch,
                           self.args.trainer.epochs + 1):
            print("epoch %d" % epoch)
            start_time = time.time()

            for optimizer, model_args in zip(optimizers, self.args.models):
                utils.adjust_learning_rate(optimizer, epoch,
                                           model_args.optim.gammas,
                                           model_args.optim.schedule,
                                           model_args.optim.args.lr)

            kwargs.update({
                "_trainer": self,
                "_train_loader": train_loader,
                "_test_loader": test_loader,
                "_nets": nets,
                "_criterions": criterions,
                "_optimizers": optimizers,
                "_epoch": epoch,
                "_logs": logs,
                "_args": self.args
            })

            # train for one epoch
            self.train_on_dataset(train_loader, nets, criterions, optimizers,
                                  epoch, logs, **kwargs)
            # evaluate on validation set
            self.validate_on_dataset(test_loader, nets, criterions, epoch,
                                     logs, **kwargs)

            # print log
            for i, log in enumerate(logs.net):
                print(
                    "  net{0}    loss :train={1:.3f}, test={2:.3f}    acc :train={3:.3f}, test ={4:.3f}"
                    .format(i, log["epoch_log"][epoch]["train_loss"],
                            log["epoch_log"][epoch]["test_loss"],
                            log["epoch_log"][epoch]["train_accuracy"],
                            log["epoch_log"][epoch]["test_accuracy"]))

            if epoch % self.args.trainer.saving_interval == 0:
                ckpt_dir = os.path.join(self.args.trainer.base_dir,
                                        "checkpoint")
                utils.save_checkpoint(nets, optimizers, epoch, ckpt_dir)

            logs.save(self.args.trainer.base_dir + r"log/")

            elapsed_time = time.time() - start_time
            print("  elapsed_time:{0:.3f}[sec]".format(elapsed_time))

            if "_callback" in kwargs:
                kwargs["_callback"](**kwargs)

        return
Example #7
def train(data_dir, model_path=None, vis_port=None, init=None):
    # loading meta data
    # -----------------------------------------------------------------------------------------------------
    configs = './net/config.json'
    with open(configs) as f:
        cfg = json.load(f)

    # create dataset
    # -----------------------------------------------------------------------------------------------------
    trainloader, validloader = build_data_loader(cfg)
    anchors = None
    # create summary writer
    if not os.path.exists(config.log_dir):
        os.mkdir(config.log_dir)
    summary_writer = SummaryWriter(config.log_dir)
    if vis_port:
        vis = visual(port=vis_port)

    # start training
    # -----------------------------------------------------------------------------------------------------
    model = SiameseAlexNet()
    model = model.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=config.lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)
    # load model weight
    # -----------------------------------------------------------------------------------------------------
    start_epoch = 1
    if model_path and init:
        print("init training with checkpoint %s" % model_path + '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(model_path)
        if 'model' in checkpoint.keys():
            model.load_state_dict(checkpoint['model'])
        else:
            model_dict = model.state_dict()
            model_dict.update(checkpoint)
            model.load_state_dict(model_dict)
        del checkpoint
        torch.cuda.empty_cache()
        print("inited checkpoint")
    elif model_path and not init:
        print("loading checkpoint %s" % model_path + '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(model_path)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()
        print("loaded checkpoint")
    elif not model_path and config.pretrained_model:
        print("init with pretrained checkpoint %s" % config.pretrained_model +
              '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(config.pretrained_model)
        # change name and load parameters
        checkpoint = {
            k.replace('features.features', 'featureExtract'): v
            for k, v in checkpoint.items()
        }
        model_dict = model.state_dict()
        model_dict.update(checkpoint)
        model.load_state_dict(model_dict)

    # freeze layers
    def freeze_layers(model):
        print(
            '------------------------------------------------------------------------------------------------'
        )
        for layer in model.featureExtract[:10]:
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.Conv2d):
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.MaxPool2d):
                continue
            elif isinstance(layer, nn.ReLU):
                continue
            else:
                raise KeyError('unexpected layer while freezing the first 10 layers')
        print("fixed layers:")
        print(model.featureExtract[:10])

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    for epoch in range(start_epoch, config.epoch + 1):
        train_loss = []
        model.train()
        if config.fix_former_3_layers:
            if torch.cuda.device_count() > 1:
                freeze_layers(model.module)
            else:
                freeze_layers(model)
        loss_temp_cls = 0
        loss_temp_reg = 0
        for i, data in enumerate(tqdm(trainloader)):
            exemplar_imgs, instance_imgs, regression_target, conf_target, delta_weight, gt = data
            # conf_target (8,1125) (8,225x5)
            regression_target = regression_target.cuda()
            conf_target = conf_target.cuda()

            pred_score, pred_regression = model(exemplar_imgs.cuda(),
                                                instance_imgs.cuda())

            pred_conf = pred_score.reshape(
                -1, 2, config.anchor_num * config.score_size *
                config.score_size).permute(0, 2, 1)
            pred_offset = pred_regression.reshape(
                -1, 4, config.anchor_num * config.score_size *
                config.score_size).permute(0, 2, 1)
            cls_loss = rpn_cross_entropy_balance(pred_conf,
                                                 conf_target,
                                                 config.num_pos,
                                                 config.num_neg,
                                                 anchors,
                                                 ohem_pos=config.ohem_pos,
                                                 ohem_neg=config.ohem_neg)
            reg_loss = rpn_smoothL1(pred_offset,
                                    regression_target,
                                    conf_target,
                                    config.num_pos,
                                    ohem=config.ohem_reg)
            loss = cls_loss + config.lamb * reg_loss
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
            optimizer.step()

            step = (epoch - 1) * len(trainloader) + i
            summary_writer.add_scalar('train/cls_loss', cls_loss.data, step)
            summary_writer.add_scalar('train/reg_loss', reg_loss.data, step)
            train_loss.append(loss.detach().cpu())
            loss_temp_cls += cls_loss.detach().cpu().numpy()
            loss_temp_reg += reg_loss.detach().cpu().numpy()
            # if vis_port:
            #     vis.plot_error({'rpn_cls_loss': cls_loss.detach().cpu().numpy().ravel()[0],
            #                     'rpn_regress_loss': reg_loss.detach().cpu().numpy().ravel()[0]}, win=0)
            if (i + 1) % config.show_interval == 0:
                tqdm.write(
                    "[epoch %2d][iter %4d] cls_loss: %.4f, reg_loss: %.4f lr: %.2e"
                    % (epoch, i, loss_temp_cls / config.show_interval,
                       loss_temp_reg / config.show_interval,
                       optimizer.param_groups[0]['lr']))
                loss_temp_cls = 0
                loss_temp_reg = 0

        train_loss = np.mean(train_loss)

        valid_loss = []
        model.eval()
        with torch.no_grad():  # validation needs no gradients
            for i, data in enumerate(tqdm(validloader)):
                exemplar_imgs, instance_imgs, regression_target, conf_target, delta_weight, gt = data

                regression_target = regression_target.cuda()
                conf_target = conf_target.cuda()

                pred_score, pred_regression = model(exemplar_imgs.cuda(),
                                                    instance_imgs.cuda())

                pred_conf = pred_score.reshape(
                    -1, 2, config.anchor_num * config.score_size *
                    config.score_size).permute(0, 2, 1)
                pred_offset = pred_regression.reshape(
                    -1, 4, config.anchor_num * config.score_size *
                    config.score_size).permute(0, 2, 1)
                cls_loss = rpn_cross_entropy_balance(pred_conf,
                                                     conf_target,
                                                     config.num_pos,
                                                     config.num_neg,
                                                     anchors,
                                                     ohem_pos=config.ohem_pos,
                                                     ohem_neg=config.ohem_neg)
                reg_loss = rpn_smoothL1(pred_offset,
                                        regression_target,
                                        conf_target,
                                        config.num_pos,
                                        ohem=config.ohem_reg)
                loss = cls_loss + config.lamb * reg_loss
                valid_loss.append(loss.cpu())
        valid_loss = np.mean(valid_loss)
        print("EPOCH %d valid_loss: %.4f, train_loss: %.4f" %
              (epoch, valid_loss, train_loss))
        summary_writer.add_scalar('valid/loss', valid_loss,
                                  (epoch + 1) * len(trainloader))
        # decay before saving, so a resumed checkpoint starts at epoch+1's learning rate
        adjust_learning_rate(optimizer, config.gamma)
        if epoch % config.save_interval == 0:
            if not os.path.exists('./data/models/'):
                os.makedirs("./data/models/")
            save_name = "./data/models/siamrpn_{}.pth".format(epoch)
            new_state_dict = model.state_dict()
            if torch.cuda.device_count() > 1:
                new_state_dict = OrderedDict()
                for k, v in model.state_dict().items():
                    namekey = k[7:]  # remove `module.`
                    new_state_dict[namekey] = v
            torch.save(
                {
                    'epoch': epoch,
                    'model': new_state_dict,
                    'optimizer': optimizer.state_dict(),
                }, save_name)
            print('save model: {}'.format(save_name))
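Examples #7 and #8 decay the learning rate once per epoch with `adjust_learning_rate(optimizer, config.gamma)`; the helper is not shown, but the call shape implies a simple multiplicative decay, sketched here:

def adjust_learning_rate(optimizer, gamma):
    # scale every param group's learning rate by the per-epoch decay factor
    for param_group in optimizer.param_groups:
        param_group['lr'] *= gamma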
Example #8
def train(data_dir, model_path=None, vis_port=None, init=None):
    # loading meta data
    # -----------------------------------------------------------------------------------------------------
    meta_data_path = os.path.join(data_dir, "meta_data.pkl")
    with open(meta_data_path, 'rb') as f:
        meta_data = pickle.load(f)
    all_videos = [x[0] for x in meta_data]

    # split train/valid dataset
    # -----------------------------------------------------------------------------------------------------
    train_videos, valid_videos = train_test_split(
        all_videos,
        test_size=1 - config.train_ratio,
        random_state=config.seed)

    # define transforms
    train_z_transforms = transforms.Compose([ToTensor()])
    train_x_transforms = transforms.Compose([ToTensor()])
    valid_z_transforms = transforms.Compose([ToTensor()])
    valid_x_transforms = transforms.Compose([ToTensor()])

    # open lmdb
    db = lmdb.open(data_dir + '.lmdb', readonly=True, map_size=int(200e9))

    # create dataset
    # -----------------------------------------------------------------------------------------------------
    train_dataset = ImagnetVIDDataset(db, train_videos, data_dir,
                                      train_z_transforms, train_x_transforms)
    anchors = train_dataset.anchors
    # dic_num = {}
    # ind_random = list(range(len(train_dataset)))
    # import random
    # random.shuffle(ind_random)
    # for i in tqdm(ind_random):
    #     exemplar_img, instance_img, regression_target, conf_target = train_dataset[i+1000]

    valid_dataset = ImagnetVIDDataset(db,
                                      valid_videos,
                                      data_dir,
                                      valid_z_transforms,
                                      valid_x_transforms,
                                      training=False)
    # create dataloader
    trainloader = DataLoader(
        train_dataset,
        batch_size=config.train_batch_size * torch.cuda.device_count(),
        shuffle=True,
        pin_memory=True,
        num_workers=config.train_num_workers * torch.cuda.device_count(),
        drop_last=True)
    validloader = DataLoader(
        valid_dataset,
        batch_size=config.valid_batch_size * torch.cuda.device_count(),
        shuffle=False,
        pin_memory=True,
        num_workers=config.valid_num_workers * torch.cuda.device_count(),
        drop_last=True)

    # create summary writer
    if not os.path.exists(config.log_dir):
        os.mkdir(config.log_dir)
    summary_writer = SummaryWriter(config.log_dir)
    if vis_port:
        vis = visual(port=vis_port)

    # start training
    # -----------------------------------------------------------------------------------------------------
    model = SiameseAlexNet()
    model = model.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=config.lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)
    # load model weight
    # -----------------------------------------------------------------------------------------------------
    start_epoch = 1
    if model_path and init:
        print("init training with checkpoint %s" % model_path + '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(model_path)
        if 'model' in checkpoint.keys():
            model.load_state_dict(checkpoint['model'])
        else:
            model_dict = model.state_dict()
            model_dict.update(checkpoint)
            model.load_state_dict(model_dict)
        del checkpoint
        torch.cuda.empty_cache()
        print("inited checkpoint")
    elif model_path and not init:
        print("loading checkpoint %s" % model_path + '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(model_path)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()
        print("loaded checkpoint")
    elif not model_path and config.pretrained_model:
        print("init with pretrained checkpoint %s" % config.pretrained_model +
              '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(config.pretrained_model)
        # change name and load parameters
        checkpoint = {
            k.replace('features.features', 'featureExtract'): v
            for k, v in checkpoint.items()
        }
        model_dict = model.state_dict()
        model_dict.update(checkpoint)
        model.load_state_dict(model_dict)

    # freeze layers
    def freeze_layers(model):
        print(
            '------------------------------------------------------------------------------------------------'
        )
        for layer in model.featureExtract[:10]:
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.Conv2d):
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.MaxPool2d):
                continue
            elif isinstance(layer, nn.ReLU):
                continue
            else:
                raise KeyError('unexpected layer while freezing the first 10 layers')
        print("fixed layers:")
        print(model.featureExtract[:10])

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    for epoch in range(start_epoch, config.epoch + 1):
        train_loss = []
        model.train()
        if config.fix_former_3_layers:
            if torch.cuda.device_count() > 1:
                freeze_layers(model.module)
            else:
                freeze_layers(model)
        loss_temp_cls = 0
        loss_temp_reg = 0
        for i, data in enumerate(tqdm(trainloader)):
            exemplar_imgs, instance_imgs, regression_target, conf_target = data
            # conf_target (8,1125) (8,225x5)
            regression_target = regression_target.cuda()
            conf_target = conf_target.cuda()

            pred_score, pred_regression = model(exemplar_imgs.cuda(),
                                                instance_imgs.cuda())

            pred_conf = pred_score.reshape(
                -1, 2, config.anchor_num * config.score_size *
                config.score_size).permute(0, 2, 1)
            pred_offset = pred_regression.reshape(
                -1, 4, config.anchor_num * config.score_size *
                config.score_size).permute(0, 2, 1)
            cls_loss = rpn_cross_entropy_balance(pred_conf,
                                                 conf_target,
                                                 config.num_pos,
                                                 config.num_neg,
                                                 anchors,
                                                 ohem_pos=config.ohem_pos,
                                                 ohem_neg=config.ohem_neg)
            reg_loss = rpn_smoothL1(pred_offset,
                                    regression_target,
                                    conf_target,
                                    config.num_pos,
                                    ohem=config.ohem_reg)
            loss = cls_loss + config.lamb * reg_loss
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
            optimizer.step()

            step = (epoch - 1) * len(trainloader) + i
            summary_writer.add_scalar('train/cls_loss', cls_loss.data, step)
            summary_writer.add_scalar('train/reg_loss', reg_loss.data, step)
            train_loss.append(loss.detach().cpu())
            loss_temp_cls += cls_loss.detach().cpu().numpy()
            loss_temp_reg += reg_loss.detach().cpu().numpy()
            # if vis_port:
            #     vis.plot_error({'rpn_cls_loss': cls_loss.detach().cpu().numpy().ravel()[0],
            #                     'rpn_regress_loss': reg_loss.detach().cpu().numpy().ravel()[0]}, win=0)
            if (i + 1) % config.show_interval == 0:
                tqdm.write(
                    "[epoch %2d][iter %4d] cls_loss: %.4f, reg_loss: %.4f lr: %.2e"
                    % (epoch, i, loss_temp_cls / config.show_interval,
                       loss_temp_reg / config.show_interval,
                       optimizer.param_groups[0]['lr']))
                loss_temp_cls = 0
                loss_temp_reg = 0
                if vis_port:
                    anchors_show = train_dataset.anchors
                    exem_img = exemplar_imgs[0].cpu().numpy().transpose(
                        1, 2, 0)
                    inst_img = instance_imgs[0].cpu().numpy().transpose(
                        1, 2, 0)

                    # show detected box with max score
                    topk = config.show_topK
                    vis.plot_img(exem_img.transpose(2, 0, 1),
                                 win=1,
                                 name='exemple')
                    cls_pred = conf_target[0]
                    gt_box = get_topk_box(cls_pred, regression_target[0],
                                          anchors_show)[0]

                    # show gt_box
                    img_box = add_box_img(inst_img, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=2,
                                 name='instance')

                    # show anchor with max score
                    cls_pred = F.softmax(pred_conf, dim=2)[0, :, 1]
                    scores, index = torch.topk(cls_pred, k=topk)
                    img_box = add_box_img(inst_img, anchors_show[index.cpu()])
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=3,
                                 name='anchor_max_score')

                    cls_pred = F.softmax(pred_conf, dim=2)[0, :, 1]
                    topk_box = get_topk_box(cls_pred,
                                            pred_offset[0],
                                            anchors_show,
                                            topk=topk)
                    img_box = add_box_img(inst_img, topk_box)
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=4,
                                 name='box_max_score')

                    # show anchor and detected box with max iou
                    iou = compute_iou(anchors_show, gt_box).flatten()
                    index = np.argsort(iou)[-topk:]
                    img_box = add_box_img(inst_img, anchors_show[index])
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=5,
                                 name='anchor_max_iou')

                    # detected box
                    regress_offset = pred_offset[0].cpu().detach().numpy()
                    topk_offset = regress_offset[index, :]
                    anchors_det = anchors_show[index, :]
                    pred_box = box_transform_inv(anchors_det, topk_offset)
                    img_box = add_box_img(inst_img, pred_box)
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=6,
                                 name='box_max_iou')

        train_loss = np.mean(train_loss)

        valid_loss = []
        model.eval()
        with torch.no_grad():  # validation needs no gradients
            for i, data in enumerate(tqdm(validloader)):
                exemplar_imgs, instance_imgs, regression_target, conf_target = data

                regression_target = regression_target.cuda()
                conf_target = conf_target.cuda()

                pred_score, pred_regression = model(exemplar_imgs.cuda(),
                                                    instance_imgs.cuda())

                pred_conf = pred_score.reshape(
                    -1, 2, config.anchor_num * config.score_size *
                    config.score_size).permute(0, 2, 1)
                pred_offset = pred_regression.reshape(
                    -1, 4, config.anchor_num * config.score_size *
                    config.score_size).permute(0, 2, 1)
                cls_loss = rpn_cross_entropy_balance(pred_conf,
                                                     conf_target,
                                                     config.num_pos,
                                                     config.num_neg,
                                                     anchors,
                                                     ohem_pos=config.ohem_pos,
                                                     ohem_neg=config.ohem_neg)
                reg_loss = rpn_smoothL1(pred_offset,
                                        regression_target,
                                        conf_target,
                                        config.num_pos,
                                        ohem=config.ohem_reg)
                loss = cls_loss + config.lamb * reg_loss
                valid_loss.append(loss.cpu())
        valid_loss = np.mean(valid_loss)
        print("EPOCH %d valid_loss: %.4f, train_loss: %.4f" %
              (epoch, valid_loss, train_loss))
        summary_writer.add_scalar('valid/loss', valid_loss,
                                  (epoch + 1) * len(trainloader))
        # decay before saving, so a resumed checkpoint starts at epoch+1's learning rate
        adjust_learning_rate(optimizer, config.gamma)
        if epoch % config.save_interval == 0:
            if not os.path.exists('./data/models/'):
                os.makedirs("./data/models/")
            save_name = "./data/models/siamrpn_{}.pth".format(epoch)
            new_state_dict = model.state_dict()
            if torch.cuda.device_count() > 1:
                new_state_dict = OrderedDict()
                for k, v in model.state_dict().items():
                    namekey = k[7:]  # remove `module.`
                    new_state_dict[namekey] = v
            torch.save(
                {
                    'epoch': epoch,
                    'model': new_state_dict,
                    'optimizer': optimizer.state_dict(),
                }, save_name)
            print('save model: {}'.format(save_name))
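The `visual(port=vis_port)` object used for plotting above is not defined in the snippet; judging by the `plot_img` calls, it is presumably a thin wrapper over visdom, roughly as below (the `plot_error` method seen in the commented-out code would be analogous):

import visdom

class visual:
    def __init__(self, port=8097):
        self.vis = visdom.Visdom(port=port)

    def plot_img(self, img, win=1, name='img'):
        # img is a CHW array, as produced by the transpose(2, 0, 1) calls above
        self.vis.image(img, win=win, opts=dict(title=name))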
Example #9
def train(round, epoch, net, trainloader, optimizer, npc, criterion,
          ANs_discovery, lr, writer):

    # tracking variables
    train_loss = AverageMeter()
    data_time = AverageMeter()
    batch_time = AverageMeter()

    # switch the model to train mode
    net.train()
    # adjust learning rate
    adjust_learning_rate(optimizer, lr)

    end = time.time()
    optimizer.zero_grad()
    for batch_idx, (inputs, _, indexes) in enumerate(trainloader):
        data_time.update(time.time() - end)
        inputs, indexes = inputs.to(cfg.device), indexes.to(cfg.device)

        features = net(inputs)
        outputs = npc(features, indexes)
        loss = criterion(outputs, indexes, ANs_discovery) / cfg.iter_size

        loss.backward()
        train_loss.update(loss.item() * cfg.iter_size, inputs.size(0))

        if (batch_idx + 1) % cfg.iter_size == 0:  # step once per iter_size accumulated batches
            optimizer.step()
            optimizer.zero_grad()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % cfg.display_freq != 0:
            continue

        writer.add_scalar('Train/Learning_Rate', lr,
                          epoch * len(trainloader) + batch_idx)
        writer.add_scalar('Train/Loss', train_loss.val,
                          epoch * len(trainloader) + batch_idx)

        elapsed_time, estimated_time = time_progress(batch_idx + 1,
                                                     len(trainloader),
                                                     batch_time.sum)
        logger.info(
            'Round: {round} Epoch: {epoch}/{tot_epochs} '
            'Progress: {elps_iters}/{tot_iters} ({elps_time}/{est_time}) '
            'Data: {data_time.sum:.3f} LR: {learning_rate:.5f} '
            'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f})'.format(
                round=round,
                epoch=epoch,
                tot_epochs=cfg.max_epoch,
                elps_iters=batch_idx,
                tot_iters=len(trainloader),
                elps_time=elapsed_time,
                est_time=estimated_time,
                data_time=data_time,
                learning_rate=lr,
                train_loss=train_loss))
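`time_progress` is another undefined helper; given its arguments (iterations done, total iterations, accumulated batch time) and the log format above, a plausible sketch:

import datetime

def time_progress(done_iters, total_iters, elapsed_seconds):
    # elapsed time plus a linear extrapolation to the full epoch
    def fmt(seconds):
        return str(datetime.timedelta(seconds=int(seconds)))
    estimated = elapsed_seconds / max(done_iters, 1) * total_iters
    return fmt(elapsed_seconds), fmt(estimated)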
Example #10
def main():

    global args, best_prec1
    args = parser.parse_args()

    my_whole_seed = 111
    random.seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.manual_seed(my_whole_seed)
    torch.cuda.manual_seed_all(my_whole_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(my_whole_seed)

    result_base = args.result
    for kk_time in range(args.seedstart, args.seedend):
        args.seed = kk_time
        # build each fold's path from the fixed base so suffixes do not accumulate
        args.result = result_base + str(args.seed)

        # create model
        model = models.__dict__[args.arch](low_dim=args.low_dim,
                                           multitask=args.multitask,
                                           showfeature=args.showfeature,
                                           args=args)
        #
        # from models.Gresnet import ResNet18
        # model = ResNet18(low_dim=args.low_dim, multitask=args.multitask)
        model = torch.nn.DataParallel(model).cuda()

        # Data loading code
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        aug = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])
        # aug = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.08, 1.), ratio=(3 / 4, 4 / 3)),
        #                           transforms.RandomHorizontalFlip(p=0.5),
        #                           get_color_distortion(s=1),
        #                           transforms.Lambda(lambda x: gaussian_blur(x)),
        #                           transforms.ToTensor(),
        #                           normalize])
        # aug = transforms.Compose([transforms.RandomRotation(60),
        #                           transforms.RandomResizedCrop(224, scale=(0.6, 1.)),
        #                           transforms.RandomGrayscale(p=0.2),
        #                           transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        #                           transforms.RandomHorizontalFlip(),
        #                           transforms.ToTensor(),
        #                             normalize])
        aug_test = transforms.Compose(
            [transforms.Resize(224),
             transforms.ToTensor(), normalize])

        # dataset
        import datasets.fundus_kaggle_dr as medicaldata
        train_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug,
                                                 train=True,
                                                 args=args)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=8,
            drop_last=True if args.multiaug else False,
            worker_init_fn=lambda w: random.seed(my_whole_seed + w))  # random.seed() returns None; seed workers via a callable

        valid_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug_test,
                                                 train=False,
                                                 test_type="amd",
                                                 args=args)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=8,
            worker_init_fn=lambda w: random.seed(my_whole_seed + w))  # random.seed() returns None; seed workers via a callable
        valid_dataset_gon = medicaldata.traindataset(root=args.data,
                                                     transform=aug_test,
                                                     train=False,
                                                     test_type="gon",
                                                     args=args)
        val_loader_gon = torch.utils.data.DataLoader(
            valid_dataset_gon,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=8,
            worker_init_fn=lambda w: random.seed(my_whole_seed + w))  # random.seed() returns None; seed workers via a callable
        valid_dataset_pm = medicaldata.traindataset(root=args.data,
                                                    transform=aug_test,
                                                    train=False,
                                                    test_type="pm",
                                                    args=args)
        val_loader_pm = torch.utils.data.DataLoader(
            valid_dataset_pm,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=8,
            worker_init_fn=lambda w: random.seed(my_whole_seed + w))  # random.seed() returns None; seed workers via a callable

        # define lemniscate and loss function (criterion)
        ndata = train_dataset.__len__()

        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        local_lemniscate = None

        if args.multitaskposrot:
            print("running multi task with positive")
            criterion = BatchCriterionRot(1, 0.1, args.batch_size, args).cuda()
        elif args.domain:
            print("running domain with four types--unify ")
            from lib.BatchAverageFour import BatchCriterionFour
            # criterion = BatchCriterionTriple(1, 0.1, args.batch_size, args).cuda()
            criterion = BatchCriterionFour(1, 0.1, args.batch_size,
                                           args).cuda()
        elif args.multiaug:
            print("running multi task")
            criterion = BatchCriterion(1, 0.1, args.batch_size, args).cuda()
        else:
            criterion = nn.CrossEntropyLoss().cuda()

        if args.multitask:
            cls_criterion = nn.CrossEntropyLoss().cuda()
        else:
            cls_criterion = None

        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.weight_decay)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                lemniscate = checkpoint['lemniscate']
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        if args.evaluate:
            knn_num = 100
            auc, acc, precision, recall, f1score = kNN(args, model, lemniscate,
                                                       train_loader,
                                                       val_loader, knn_num,
                                                       args.nce_t, 2)
            return

        # create the result folder and the TensorBoard writer
        os.makedirs(args.result, exist_ok=True)
        writer = SummaryWriter("runs/" + str(args.result.split("/")[-1]))
        writer.add_text('Text', str(args))

        # copy code
        import shutil, glob
        source = glob.glob("*.py")
        source += glob.glob("*/*.py")
        os.makedirs(args.result + "/code_file", exist_ok=True)
        for file in source:
            name = file.split("/")[0]
            if name == file:
                shutil.copy(file, args.result + "/code_file/")
            else:
                os.makedirs(args.result + "/code_file/" + name, exist_ok=True)
                shutil.copy(file, args.result + "/code_file/" + name)

        for epoch in range(args.start_epoch, args.epochs):
            lr = adjust_learning_rate(optimizer, epoch, args, [100, 200])
            writer.add_scalar("lr", lr, epoch)

            # train for one epoch
            loss = train(train_loader, model, lemniscate, local_lemniscate,
                         criterion, cls_criterion, optimizer, epoch, writer)
            writer.add_scalar("train_loss", loss, epoch)

            # gap_int = 10
            # if (epoch) % gap_int == 0:
            #     knn_num = 100
            #     auc, acc, precision, recall, f1score = kNN(args, model, lemniscate, train_loader, val_loader, knn_num, args.nce_t, 2)
            #     writer.add_scalar("test_auc", auc, epoch)
            #     writer.add_scalar("test_acc", acc, epoch)
            #     writer.add_scalar("test_precision", precision, epoch)
            #     writer.add_scalar("test_recall", recall, epoch)
            #     writer.add_scalar("test_f1score", f1score, epoch)
            #
            #     auc, acc, precision, recall, f1score = kNN(args, model, lemniscate, train_loader, val_loader_gon,
            #                                                knn_num, args.nce_t, 2)
            #     writer.add_scalar("gon/test_auc", auc, epoch)
            #     writer.add_scalar("gon/test_acc", acc, epoch)
            #     writer.add_scalar("gon/test_precision", precision, epoch)
            #     writer.add_scalar("gon/test_recall", recall, epoch)
            #     writer.add_scalar("gon/test_f1score", f1score, epoch)
            #     auc, acc, precision, recall, f1score = kNN(args, model, lemniscate, train_loader, val_loader_pm,
            #                                                knn_num, args.nce_t, 2)
            #     writer.add_scalar("pm/test_auc", auc, epoch)
            #     writer.add_scalar("pm/test_acc", acc, epoch)
            #     writer.add_scalar("pm/test_precision", precision, epoch)
            #     writer.add_scalar("pm/test_recall", recall, epoch)
            #     writer.add_scalar("pm/test_f1score", f1score, epoch)

            # save checkpoint
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'lemniscate': lemniscate,
                    'optimizer': optimizer.state_dict(),
                },
                filename=args.result + "/fold" + str(args.seedstart) +
                "-epoch-" + str(epoch) + ".pth.tar")
Beispiel #11
0
def train(data_dir, model_path=None, vis_port=None, init=None):
    # loading meta data
    # -----------------------------------------------------------------------------------------------------
    meta_data_path = os.path.join(data_dir, "meta_data.pkl")
    meta_data = pickle.load(
        open(meta_data_path, 'rb')
    )  # meta_data[0] = ('ILSVRC2015_train_00001000', {0: ['000000', '000001', '000002',...]}),
    all_videos = [x[0] for x in meta_data]

    # split train/valid dataset
    # -----------------------------------------------------------------------------------------------------
    train_videos, valid_videos = train_test_split(all_videos,
                                                  test_size=1 -
                                                  config.train_ratio,
                                                  random_state=config.seed)
    print("after split:train_videos {0},valid_videos {1}".format(
        len(train_videos), len(valid_videos)))
    # define transforms
    train_z_transforms = transforms.Compose([ToTensor()])
    train_x_transforms = transforms.Compose([ToTensor()])
    valid_z_transforms = transforms.Compose([ToTensor()])
    valid_x_transforms = transforms.Compose([ToTensor()])

    # open lmdb
    # db = lmdb.open(data_dir + '_lmdb', readonly=True, map_size=int(1024*1024*1024))  # 200e9, in bytes
    db_path = data_dir + '_Lmdb'
    # create dataset
    # -----------------------------------------------------------------------------------------------------
    train_dataset = ImagnetVIDDataset(db_path, train_videos, data_dir,
                                      train_z_transforms, train_x_transforms)
    # test __getitem__
    # train_dataset.__getitem__(1)
    # exit(0)

    anchors = train_dataset.anchors  # (1805,4) = (19*19*5,4)
    # dic_num = {}
    # ind_random = list(range(len(train_dataset)))
    # import random
    # random.shuffle(ind_random)
    # for i in tqdm(ind_random):
    #     exemplar_img, instance_img, regression_target, conf_target = train_dataset[i+1000]

    valid_dataset = ImagnetVIDDataset(db_path,
                                      valid_videos,
                                      data_dir,
                                      valid_z_transforms,
                                      valid_x_transforms,
                                      training=False)
    # create dataloader
    trainloader = DataLoader(train_dataset,
                             batch_size=config.train_batch_size,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=config.train_num_workers,
                             drop_last=True)
    validloader = DataLoader(valid_dataset,
                             batch_size=config.valid_batch_size,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=config.valid_num_workers,
                             drop_last=True)  # note: this also drops the last validation batch, slightly biasing valid_loss

    # create summary writer
    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    summary_writer = SummaryWriter(config.log_dir)
    if vis_port:
        vis = visual(port=vis_port)

    # start training
    # -----------------------------------------------------------------------------------------------------
    # model = SiameseAlexNet()
    model = SiamFPN50()
    model.init_weights()  # weight initialization
    if config.CUDA:
        model = model.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=config.lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)
    # load model weight
    # -----------------------------------------------------------------------------------------------------
    start_epoch = 1
    if model_path and init:
        print("init training with checkpoint %s" % model_path + '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(model_path)
        if 'model' in checkpoint.keys():
            model.load_state_dict(checkpoint['model'])
        else:
            model_dict = model.state_dict()
            model_dict.update(checkpoint)
            model.load_state_dict(model_dict)
        del checkpoint
        torch.cuda.empty_cache()
        print("inited checkpoint")
    elif model_path and not init:
        print("loading checkpoint %s" % model_path + '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(model_path)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()
        print("loaded checkpoint")
    elif not model_path and config.pretrained_model:
        print("init with pretrained checkpoint %s" % config.pretrained_model +
              '\n')
        print(
            '------------------------------------------------------------------------------------------------ \n'
        )
        checkpoint = torch.load(config.pretrained_model)
        # change name and load parameters
        checkpoint = {
            k.replace('features.features', 'featureExtract'): v
            for k, v in checkpoint.items()
        }
        model_dict = model.state_dict()
        model_dict.update(checkpoint)
        model.load_state_dict(model_dict)

    # freeze layers
    def freeze_layers(model):
        print(
            '------------------------------------------------------------------------------------------------'
        )
        for layer in model.featureExtract[:10]:
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.Conv2d):
                for k, v in layer.named_parameters():
                    v.requires_grad = False
            elif isinstance(layer, nn.MaxPool2d):
                continue
            elif isinstance(layer, nn.ReLU):
                continue
            else:
                raise KeyError('error in fixing former 3 layers')
        # print("fixed layers:")
        # print(model.featureExtract[:10])
        '''
        fixed layers:
        Sequential(
        (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(2, 2))
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): ReLU(inplace)
        (4): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1))
        (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        (7): ReLU(inplace)
        (8): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1))
        (9): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        '''
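        # Note (added): requires_grad = False alone does not stop BatchNorm from
        # updating its running mean/var in train mode; that is why the
        # BatchNorm2d branch above also switches those layers to eval().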

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)  # requires that the model is already on .cuda()
    # if isinstance(model, torch.nn.DataParallel):  # multi-GPU training; AttributeError: 'DataParallel' object has no attribute 'xxxx'
    #     model = model.module
    for epoch in range(start_epoch, config.EPOCH + 1):
        train_loss = []
        model.train()
        if config.fix_former_3_layers:  # tentatively freeze the first three conv blocks (v.requires_grad = False)
            if torch.cuda.device_count() > 1:
                freeze_layers(model.module)
            else:
                freeze_layers(model)
        loss_temp_cls = 0
        loss_temp_reg = 0
        loss_temp = 0
        # for i, data in enumerate(tqdm(trainloader)):  # can't pickle Transaction objects
        for k, data in enumerate(tqdm(trainloader)):  # known issue: the loader sometimes exits before all batches are consumed
            # print("done")
            # return
            # (8,3,127,127)\(8,3,271,271)\(8,1805,4)\(8,1805)
            # 8 is the batch_size, 1805 = 19 * 19 * 5, 5 = anchors_num
            # exemplar_imgs, instance_imgs, regression_target, conf_target = data
            exemplar_imgs, instance_imgs, regression_targets, conf_targets = data

            # conf_target (8,1125) (8,225x5)
            if config.CUDA:
                # Problem: regression_targets is a list, so .cuda() cannot be called
                # on it directly; consider flattening it to shape (N, 4) later
                # regression_targets, conf_targets = torch.tensor(regression_targets).cuda(), torch.tensor(conf_targets).cuda()
                exemplar_imgs, instance_imgs = exemplar_imgs.cuda(
                ), instance_imgs.cuda()

            # # single-level loss computation (kept for reference)
            # # (8,10,19,19)\(8,20,19,19)
            # pred_score, pred_regression = model(exemplar_imgs, instance_imgs)
            # # (8,1805,2)
            # pred_conf = pred_score.reshape(-1, 2, config.anchor_num * config.score_size * config.score_size).permute(0,2,1)
            # # (8,1805,4)
            # pred_offset = pred_regression.reshape(-1, 4,config.anchor_num * config.score_size * config.score_size).permute(0,2,1)

            # cls_loss = rpn_cross_entropy_balance(pred_conf, conf_target, config.num_pos, config.num_neg, anchors,
            #                                      ohem_pos=config.ohem_pos, ohem_neg=config.ohem_neg)
            # reg_loss = rpn_smoothL1(pred_offset, regression_target, conf_target, config.num_pos, ohem=config.ohem_reg)
            # loss = cls_loss + config.lamb * reg_loss
            # FPN pyramid-based loss computation
            # try:
            #     output = model(input)
            # except RuntimeError as exception:
            #     if "out of memory" in str(exception):
            #         print("WARNING: out of memory")
            #         if hasattr(torch.cuda, 'empty_cache'):
            #             torch.cuda.empty_cache()
            #     else:
            #         raise exception

            pred_scores, pred_regressions = model(exemplar_imgs, instance_imgs)
            # FEATURE_MAP_SIZE, FPN_ANCHOR_NUM
            '''
            when batch_size = 2, anchor_num = 3
            torch.Size([N, 6, 37, 37])
            torch.Size([N, 6, 19, 19])
            torch.Size([N, 6, 10, 10])
            torch.Size([N, 6, 6, 6])

            torch.Size([N, 12, 37, 37])
            torch.Size([N, 12, 19, 19])
            torch.Size([N, 12, 10, 10])
            torch.Size([N, 12, 6, 6])
            '''
            loss = 0
            cls_loss_sum = 0
            reg_loss_sum = 0
            for i in range(len(pred_scores)):
                if i != 1:
                    continue  # for now only level 1 (the 19x19 map) contributes to the loss; other levels are skipped
                pred_score = pred_scores[i]
                pred_regression = pred_regressions[i]
                anchors_num = config.FPN_ANCHOR_NUM * config.FEATURE_MAP_SIZE[
                    i] * config.FEATURE_MAP_SIZE[i]
                pred_conf = pred_score.reshape(-1, 2,
                                               anchors_num).permute(0, 2, 1)
                pred_offset = pred_regression.reshape(-1, 4,
                                                      anchors_num).permute(
                                                          0, 2, 1)

                conf_target = conf_targets[i]
                regression_target = regression_targets[i].type(
                    torch.FloatTensor)  # cast to float to match pred_offset
                if config.CUDA:
                    conf_target = conf_target.cuda()
                    regression_target = regression_target.cuda()
                # binary classification loss (cross-entropy)
                cls_loss = rpn_cross_entropy_balance(pred_conf,
                                                     conf_target,
                                                     config.num_pos,
                                                     config.num_neg,
                                                     anchors[i],
                                                     ohem_pos=config.ohem_pos,
                                                     ohem_neg=config.ohem_neg)
                # regression loss (smooth L1); suspected bug: the regression loss evaluates to 0
                reg_loss = rpn_smoothL1(pred_offset,
                                        regression_target,
                                        conf_target,
                                        config.num_pos,
                                        ohem=config.ohem_reg)

                _loss = cls_loss + reg_loss * config.lamb_reg  # note: unlike validation below, cls_loss is not scaled by config.lamb_cls

                loss += _loss  # the per-level losses are simply summed for now; weighting may be added later
                # accumulate raw cls_loss / reg_loss for TensorBoard logging
                cls_loss_sum = cls_loss_sum + cls_loss
                reg_loss_sum = reg_loss_sum + reg_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
            optimizer.step()

            step = (epoch - 1) * len(trainloader) + k
            summary_writer.add_scalar('train/cls_loss', cls_loss_sum.data,
                                      step)
            summary_writer.add_scalar('train/reg_loss', reg_loss_sum.data,
                                      step)
            loss = loss.detach().cpu()
            train_loss.append(loss)
            loss_temp_cls += cls_loss_sum.detach().cpu().numpy()
            loss_temp_reg += reg_loss_sum.detach().cpu().numpy()
            loss_temp += loss.numpy()
            # if vis_port:
            #     vis.plot_error({'rpn_cls_loss': cls_loss.detach().cpu().numpy().ravel()[0],
            #                     'rpn_regress_loss': reg_loss.detach().cpu().numpy().ravel()[0]}, win=0)

            # print("Epoch {0} batch {1} training_loss:{2}".format(epoch, k+1, loss))

            if (k + 1) % config.show_interval == 0:
                tqdm.write(
                    "[epoch %2d][iter %4d] loss: %.4f, cls_loss: %.4f, reg_loss: %.4f lr: %.2e"
                    % (epoch, k + 1, loss_temp / config.show_interval,
                       loss_temp_cls / config.show_interval, loss_temp_reg /
                       config.show_interval, optimizer.param_groups[0]['lr']))
                loss_temp_cls = 0
                loss_temp_reg = 0
                loss_temp = 0
                # visualization
                if vis_port:
                    anchors_show = train_dataset.anchors
                    exem_img = exemplar_imgs[0].cpu().numpy().transpose(
                        1, 2, 0)
                    inst_img = instance_imgs[0].cpu().numpy().transpose(
                        1, 2, 0)

                    # show detected box with max score
                    topk = config.show_topK
                    vis.plot_img(exem_img.transpose(2, 0, 1),
                                 win=1,
                                 name='exemple')
                    cls_pred = conf_target[0]
                    gt_box = get_topk_box(cls_pred, regression_target[0],
                                          anchors_show)[0]

                    # show gt_box
                    img_box = add_box_img(inst_img, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=2,
                                 name='instance')

                    # show anchor with max score
                    cls_pred = F.softmax(pred_conf, dim=2)[0, :, 1]
                    scores, index = torch.topk(cls_pred, k=topk)
                    img_box = add_box_img(inst_img, anchors_show[index.cpu()])
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=3,
                                 name='anchor_max_score')

                    cls_pred = F.softmax(pred_conf, dim=2)[0, :, 1]
                    topk_box = get_topk_box(cls_pred,
                                            pred_offset[0],
                                            anchors_show,
                                            topk=topk)
                    img_box = add_box_img(inst_img, topk_box)
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=4,
                                 name='box_max_score')

                    # show anchor and detected box with max iou
                    iou = compute_iou(anchors_show, gt_box).flatten()
                    index = np.argsort(iou)[-topk:]
                    img_box = add_box_img(inst_img, anchors_show[index])
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=5,
                                 name='anchor_max_iou')

                    # detected box
                    regress_offset = pred_offset[0].cpu().detach().numpy()
                    topk_offset = regress_offset[index, :]
                    anchors_det = anchors_show[index, :]
                    pred_box = box_transform_inv(anchors_det, topk_offset)
                    img_box = add_box_img(inst_img, pred_box)
                    img_box = add_box_img(img_box, gt_box, color=(255, 0, 0))
                    vis.plot_img(img_box.transpose(2, 0, 1),
                                 win=6,
                                 name='box_max_iou')
        train_loss = np.mean(train_loss)
        # print("done")

        # validation
        valid_loss = []
        # no gradient computation, to save GPU memory (validation is pure inference; only outputs are needed)
        with torch.no_grad():
            model.eval()
            for j, data in enumerate(tqdm(validloader)):
                exemplar_imgs, instance_imgs, regression_targets, conf_targets = data
                if config.CUDA:
                    exemplar_imgs, instance_imgs = exemplar_imgs.cuda(
                    ), instance_imgs.cuda()

                pred_scores, pred_regressions = model(exemplar_imgs,
                                                      instance_imgs)
                loss = 0
                for i in range(len(pred_scores)):
                    if i != 1:
                        continue  # for now only level 1 (the 19x19 map) contributes to the loss; other levels are skipped
                    pred_score = pred_scores[i]
                    pred_regression = pred_regressions[i]
                    anchors_num = config.FPN_ANCHOR_NUM * config.FEATURE_MAP_SIZE[
                        i] * config.FEATURE_MAP_SIZE[i]
                    pred_conf = pred_score.reshape(-1, 2, anchors_num).permute(
                        0, 2, 1)
                    pred_offset = pred_regression.reshape(-1, 4,
                                                          anchors_num).permute(
                                                              0, 2, 1)

                    conf_target = conf_targets[i]
                    regression_target = regression_targets[i].type(
                        torch.FloatTensor)  # cast to float to match pred_offset
                    if config.CUDA:
                        conf_target = conf_target.cuda()
                        regression_target = regression_target.cuda()
                    # binary classification loss (cross-entropy)
                    cls_loss = rpn_cross_entropy_balance(
                        pred_conf,
                        conf_target,
                        config.num_pos,
                        config.num_neg,
                        anchors[i],
                        ohem_pos=config.ohem_pos,
                        ohem_neg=config.ohem_neg)
                    # regression loss (smooth L1); suspected bug: the regression loss evaluates to 0
                    reg_loss = rpn_smoothL1(pred_offset,
                                            regression_target,
                                            conf_target,
                                            config.num_pos,
                                            ohem=config.ohem_reg)

                    _loss = cls_loss * config.lamb_cls + reg_loss * config.lamb_reg
                    loss += _loss  # the per-level losses are simply summed for now; weighting may be added later
                valid_loss.append(loss.detach().cpu())
        valid_loss = np.mean(valid_loss)

        print("EPOCH %d valid_loss: %.4f, train_loss: %.4f" %
              (epoch, valid_loss, train_loss))
        summary_writer.add_scalar('valid/loss', valid_loss,
                                  (epoch + 1) * len(trainloader))
        # adjust the learning rate
        adjust_learning_rate(
            optimizer, config.gamma
        )  # adjust before saving so that the next load resumes with epoch+1's lr
        # save the trained model
        if epoch % config.save_interval == 0:
            if not os.path.exists('./data/models/'):
                os.makedirs("./data/models/")

            save_name = "./data/models/otb_siamfpn_{}_trainloss_{:.4f}_validloss_{:.4f}.pth".format(
                epoch, train_loss, valid_loss)
            new_state_dict = model.state_dict()
            if torch.cuda.device_count() > 1:
                new_state_dict = OrderedDict()
                for k, v in model.state_dict().items():
                    namekey = k[7:]  # remove `module.`
                    new_state_dict[namekey] = v
            torch.save(
                {
                    'epoch': epoch,
                    'model': new_state_dict,
                    'optimizer': optimizer.state_dict(),
                }, save_name)
            print('save model: {}'.format(save_name))

        # clear the CUDA cache
        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()
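
# `adjust_learning_rate(optimizer, config.gamma)` above is not shown in this
# example; a minimal sketch consistent with that call site (multiplicative
# decay by `gamma` once per epoch -- an assumption, not the original) could be:
def adjust_learning_rate(optimizer, gamma):
    # multiply every parameter group's learning rate by gamma
    for param_group in optimizer.param_groups:
        param_group['lr'] *= gamma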