def main():
    args = get_args()
    train_loader, test_loader = load_data(args)

    if args.dataset_mode == "CIFAR10":
        num_classes = 10
    elif args.dataset_mode == "CIFAR100":
        num_classes = 100
    elif args.dataset_mode == "MNIST":
        num_classes = 10
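    else:
        # added guard (not in the original): num_classes is otherwise undefined
        raise ValueError("unknown dataset_mode: " + str(args.dataset_mode))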

    if args.load_pretrained:
        model = MobileNetV3(model_mode="LARGE", num_classes=num_classes).to(device)
        filename = "best_model_"
        checkpoint = torch.load('./checkpoint/' + filename + 'ckpt.t7')
        model.load_state_dict(checkpoint['model'])
        epoch = checkpoint['epoch']
        acc = checkpoint['acc']
        max_test_acc = acc
        print("Load Model Accuracy: ", acc, "Load Model end epoch: ", epoch)
    else:
        model = MobileNetV3(model_mode="LARGE", num_classes=num_classes).to(device)
        epoch = 1
        max_test_acc = 0
    # if device == "cuda":
    #     model = nn.DataParallel(model)
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=5e-4, momentum=0.9)
    criterion = nn.CrossEntropyLoss().to(device)

    if not os.path.isdir("reporting"):
        os.mkdir("reporting")

    start_time = time.time()
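    # resume the epoch counter from a restored checkpoint (1 when training fresh)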
    for epoch in range(epoch, args.epochs):
        train(model, train_loader, optimizer, criterion, epoch, args)
        test_acc = get_test(model, test_loader)
        if max_test_acc < test_acc:
            print('Saving..')
            state = {
                'model': model.state_dict(),
                'acc': test_acc,
                'epoch': epoch,
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            filename = "best_model_"
            torch.save(state, './checkpoint/' + filename + 'ckpt.t7')
            max_test_acc = test_acc

        time_interval = time.time() - start_time
        time_split = time.gmtime(time_interval)
        print("Training time: ", time_interval, "Hour: ", time_split.tm_hour, "Minute: ", time_split.tm_min, "Second: ", time_split.tm_sec, end='')
        print(" Test acc:", max_test_acc, "time: ", time.time() - start_time)
        with open("./reporting/" + "best_model.txt", "w") as f:
            f.write("Epoch: " + str(epoch) + " " + "Best acc: " + str(max_test_acc) + "\n")
            f.write("Training time: " + str(time_interval) + "Hour: " + str(time_split.tm_hour) + "Minute: " + str(
                time_split.tm_min) + "Second: " + str(time_split.tm_sec))
            f.write("\n")
Example #2
def main():
    parser = argparse.ArgumentParser("parameters")

    parser.add_argument('--batch-size',
                        type=int,
                        default=16,
                        help='batch size, (default: 16)')
    parser.add_argument(
        '--dataset-mode',
        type=str,
        default="CIFAR100",
        help=
        "which dataset you use, (example: CIFAR10, CIFAR100), (default: CIFAR100)"
    )
    # note: argparse's type=bool treats any non-empty string as True,
    # so '--is-train False' still yields True; a store_true flag is safer
    parser.add_argument(
        '--is-train',
        type=bool,
        default=False,
        help="True if training, False if test. (default: False)")
    parser.add_argument('--model-mode',
                        type=str,
                        default="LARGE",
                        help="(example: LARGE, SMALL), (default: LARGE)")

    args = parser.parse_args()

    _, test_loader = load_data(args)

    if args.dataset_mode == "CIFAR100":
        num_classes = 100
    elif args.dataset_mode == "CIFAR10":
        num_classes = 10
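    else:
        # added guard (not in the original): num_classes is otherwise undefined
        raise ValueError("unknown dataset_mode: " + str(args.dataset_mode))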

    if os.path.exists("./checkpoint"):
        model = MobileNetV3(model_mode=args.model_mode,
                            num_classes=num_classes).to(device)
        filename = "best_model_"
        checkpoint = torch.load('./checkpoint/' + filename + 'ckpt.t7')
        model.load_state_dict(checkpoint['model'])
        end_epoch = checkpoint['epoch']
        best_acc = checkpoint['acc']
        print("[Saved Best Accuracy]: ", best_acc, '%', "[End epochs]: ",
              end_epoch)
        print("Number of model parameters: ", get_model_parameters(model))

        model.eval()
        correct = 0
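        # top-1 accuracy: count samples whose argmax prediction matches the label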
        for data, target in tqdm(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            y_pred = output.data.max(1)[1]
            correct += y_pred.eq(target.data).sum()
        print("[Test Accuracy] ",
              100. * float(correct) / len(test_loader.dataset), '%')

    else:
        # args.seed is not defined by this parser; check the checkpoint file itself
        assert os.path.exists('./checkpoint/best_model_ckpt.t7'), \
            "File not found. Please check again."
Example #3
def main():
    homepath = os.environ['HOME']
    datapath = os.path.join(homepath, 'data')
    mox.file.copy_parallel(args.data_url, datapath)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MobileNetV3().to(device)
    centerloss = CenterLoss(num_classes=75, feat_dim=1280, use_gpu=True)
    cross_entropy = nn.CrossEntropyLoss()
    optimizer_model = torch.optim.SGD(model.parameters(),
                                      lr=args.lr_model,
                                      weight_decay=5e-04,
                                      momentum=0.9)
    optimizer_centloss = torch.optim.SGD(centerloss.parameters(),
                                         lr=args.lr_centloss)
    train_iterator, test_iterator = dataprocess(
        train_label_path=args.train_label_txt,
        data_dirtory=datapath,
        test_label_path=args.test_label_txt,
        batch_size=args.batch_size)

    if args.step > 0:
        scheduler = lr_scheduler.StepLR(optimizer_model,
                                        step_size=args.step,
                                        gamma=args.gamma)

    if not (os.path.isdir(os.path.join(homepath, 'model'))):
        os.makedirs(os.path.join(homepath, 'model'))
    tmp_accuracy = 0

    for epoch in range(args.num_epoch):
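        # note: stepping the scheduler before train() follows the pre-1.1.0
        # PyTorch convention; newer releases expect scheduler.step() after
        # the optimizer steps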
        if args.step > 0:
            scheduler.step()
        train_loss, train_acc = train(model=model,
                                      device=device,
                                      train_iterator=train_iterator,
                                      optimizer_model=optimizer_model,
                                      optimizer_centloss=optimizer_centloss,
                                      criterion1=cross_entropy,
                                      criterion2=centerloss,
                                      weight_centloss=args.weight)
        test_loss, test_acc = eval(model=model,
                                   device=device,
                                   test_iterator=test_iterator,
                                   criterion1=cross_entropy,
                                   criterion2=centerloss,
                                   weight_centloss=args.weight_centloss)
        print('|Epoch:', epoch + 1, '|Train loss',
              train_loss.item(), '|Train acc:', train_acc.item(), '|Test loss',
              test_loss.item(), '|Test acc', test_acc.item())
        if test_acc > tmp_accuracy:
            MODEL_SAVE_PATH = os.path.join(homepath, 'model',
                                           'mymodel_{}.pth'.format(epoch))
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            tmp_accuracy = test_acc
    mox.file.copy(MODEL_SAVE_PATH,
                  os.path.join(args.train_url, 'model/mymodel.pth'))
Example #4
    def load_model(self):
        self.model = MobileNetV3(model_mode="SMALL",
                                 num_classes=self.num_classes,
                                 multiplier=1.0,
                                 dropout_rate=1)
        if self.device:
            self.model.cuda()
        checkpoint = torch.load("./checkpoint/best_model_SMALL_ckpt.t7")
        self.model.load_state_dict(checkpoint['model'], strict=False)
        self.model.eval()
Example #5
    def _prepare_model(self,
                       architecture: str,
                       head: str,
                       shallow_stride: int,
                       deep_stride: int,
                       width_mult: float = 1.0):
        mobile_net = MobileNetV3(architecture=architecture,
                                 width_mult=width_mult)

        shallow_hook_bool = False
        shallow_channels, deep_channels = None, None
        backbone = [mobile_net.features[0]]
        output_stride = max(backbone[0].conv_bn[0].stride)

        # append every inverted residual block to the backbone, until we reach
        # the last module with our desired output stride (deep_stride);
        # also apply a forward-hook to the last module with the desired
        # shallow stride, for a low-level feature skip connection
        for module in mobile_net.features.modules():
            if isinstance(module, InvertedResidual):
                output_stride *= module.stride

                if output_stride == shallow_stride * 2 \
                        and not shallow_hook_bool:
                    backbone[-1].register_forward_hook(
                        self._set_shallow_hook())
                    shallow_channels = backbone[-1].out_c
                    shallow_hook_bool = True
                if output_stride == deep_stride * 2 and shallow_hook_bool:
                    deep_channels = backbone[-1].out_c
                    break
                backbone.append(module)

        # if the very last module had the desired deep stride
        # extract its output channels
        if output_stride == deep_stride and shallow_hook_bool:
            deep_channels = backbone[-1].out_c

        assert shallow_channels, \
            "Shallow stride is too big, could not place hook!"
        assert deep_channels, \
            f"Deep stride is too big! Max stride possible: {output_stride}"

        if head == "lr_aspp":
            # reduce channels in last block by factor of 2 and set dilation=2
            backbone[-1] = self._reduce_last_block_by_factor(
                backbone[-1], 2, 2)
            head = LR_ASPP(shallow_channels, deep_channels // 2, self.out_c,
                           self.head_c)
        if head == "aspp":
            head = ASPP(shallow_channels, deep_channels, self.out_c,
                        self.head_c)

        return nn.Sequential(*backbone), head
Example #6
def main():
    args = get_args()
    # train_loader, test_loader = load_data(args)  # returns data iterators
    # TODO: load your own data
    batch_size = 32
    num_workers = 8
    annotation_path = './voice_data/train_data.txt'
    annotation_path1 = './voice_data/test_data.txt'
    with open(annotation_path) as f:
        lines = f.readlines()
    with open(annotation_path1) as f1:
        lines1 = f1.readlines()
    # TODO
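    # fixed seed makes the training-list shuffle reproducible; reseeding with
    # None restores nondeterministic behaviour afterwards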
    np.random.seed(10101)
    np.random.shuffle(lines)
    np.random.seed(None)
    train_dataset = VoiceDataset(lines, (224, 224))
    test_dataset = VoiceDataset(lines1, (224, 224))
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              pin_memory=True,
                              drop_last=True,
                              collate_fn=voice_dataset_collate)
    test_loader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_memory=True,
                             drop_last=True,
                             collate_fn=voice_dataset_collate)

    if args.dataset_mode == "CIFAR10":
        num_classes = 10
    elif args.dataset_mode == "CIFAR100":
        num_classes = 100
    elif args.dataset_mode == "IMAGENET":
        num_classes = 1000
    elif args.dataset_mode == "VOICE":
        num_classes = 3
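    else:
        # added guard (not in the original): num_classes is otherwise undefined
        raise ValueError("unknown dataset_mode: " + str(args.dataset_mode))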
    print('num_classes: ', num_classes)

    # TODO: load the model
    model = MobileNetV3(model_mode=args.model_mode,
                        num_classes=num_classes,
                        multiplier=args.multiplier,
                        dropout_rate=args.dropout).to(device)

    for para_tensor in model.state_dict():
        print(model.state_dict()[para_tensor].size())

    if torch.cuda.device_count() >= 1:
        print("num GPUs: ", torch.cuda.device_count())
        model = nn.DataParallel(model).to(device)

    # TODO: whether to fine-tune
    if args.load_pretrained or args.evaluate:
        filename = "best_model_" + str(args.model_mode)
        checkpoint = torch.load('./checkpoint/' + filename + '_ckpt.t7')
        # # TODO: strip the 'module.' prefix from the keys
        # if "state_dict" in model.keys():
        #     pretrained_dict = remove_prefix(model['state_dict'], 'module.')
        # else:
        #     pretrained_dict = remove_prefix(model, 'module.')
        # model.load_state_dict(pretrained_dict, strict=False)
        #
        # # TODO: restore the weights
        # checkpoint['model'] = add_prefix(model, 'module.')
        model.load_state_dict(checkpoint['model'], strict=False)
        epoch = checkpoint['epoch']
        acc1 = checkpoint['best_acc1']
        # acc5 = checkpoint['best_acc5']
        best_acc1 = acc1
        # print("Load Model Accuracy1: ", acc1, " acc5: ", acc5, "Load Model end epoch: ", epoch)
        print("Load Model Accuracy1: ", acc1, "Load Model end epoch: ", epoch)
    else:
        print("init model load ...")
        epoch = 1
        best_acc1 = 0
    # TODO: build the optimizer and loss function
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          weight_decay=1e-5,
                          momentum=0.9)
    # optimizer = optim.RMSprop(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss().to(device)

    # TODO: whether to run validation only
    if args.evaluate:
        acc1 = validate(test_loader, model, criterion, args)
        # acc1, acc5 = validate(test_loader, model, criterion, args)
        # print("Acc1: ", acc1, "Acc5: ", acc5)
        print("Acc1: ", acc1)
        return

    if not os.path.isdir("reporting"):
        os.mkdir("reporting")
    # TODO: training loop
    start_time = time.time()
    with open("./reporting/" + "best_model_" + args.model_mode + ".txt",
              "w") as f:
        for epoch in range(epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args)
            train(train_loader, model, criterion, optimizer, epoch, args)
            acc1 = validate(test_loader, model, criterion, args)
            # acc1, acc5 = validate(test_loader, model, criterion, args)

            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            if is_best:
                print('Saving..')
                state = {
                    'model': model.state_dict(),
                    'best_acc1': best_acc1,
                    'epoch': epoch,
                }
                if not os.path.isdir('checkpoint'):
                    os.mkdir('checkpoint')
                filename = "best_model_" + str(args.model_mode)
                torch.save(state, './checkpoint/' + filename + '_ckpt.t7')

            time_interval = time.time() - start_time
            time_split = time.gmtime(time_interval)
            print("Training time: ",
                  time_interval,
                  "Hour: ",
                  time_split.tm_hour,
                  "Minute: ",
                  time_split.tm_min,
                  "Second: ",
                  time_split.tm_sec,
                  end='')
            print(" Test best acc1:", best_acc1, " acc1: ", acc1)

            f.write("Epoch: " + str(epoch) + " " + " Best acc: " +
                    str(best_acc1) + " Test acc: " + str(acc1) + "\n")
            f.write("Training time: " + str(time_interval) + " Hour: " +
                    str(time_split.tm_hour) + " Minute: " +
                    str(time_split.tm_min) + " Second: " +
                    str(time_split.tm_sec))
            f.write("\n")
Example #7
def main():
    args = get_args()
    train_loader, test_loader = load_data(args)

    if args.dataset_mode == "CIFAR10":
        num_classes = 10
    elif args.dataset_mode == "CIFAR100":
        num_classes = 100
    elif args.dataset_mode == "IMAGENET":
        num_classes = 1000
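    else:
        # added guard (not in the original): num_classes is otherwise undefined
        raise ValueError("unknown dataset_mode: " + str(args.dataset_mode))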
    print('num_classes: ', num_classes)

    model = MobileNetV3(model_mode=args.model_mode, num_classes=num_classes, multiplier=args.multiplier, dropout_rate=args.dropout).to(device)
    if torch.cuda.device_count() >= 1:
        print("num GPUs: ", torch.cuda.device_count())
        model = nn.DataParallel(model).to(device)

    if args.load_pretrained or args.evaluate:
        filename = "best_model_" + str(args.model_mode) + str(args.prefix)
        # the full checkpoint is always needed below for the epoch and accuracy fields
        checkpoint = torch.load('./checkpoint/' + filename + '_ckpt.t7')
        try:
            # _dp_model.t7 holds model.module.state_dict() (no 'module.' prefix),
            # so load it into the underlying module when wrapped in DataParallel
            dp_model = torch.load('./checkpoint/' + filename + '_dp_model.t7')
            target = model.module if isinstance(model, nn.DataParallel) else model
            target.load_state_dict(dp_model)
        except FileNotFoundError:
            model.load_state_dict(checkpoint['model'])
        epoch = checkpoint['epoch']
        acc1 = checkpoint['best_acc1']
        acc5 = checkpoint['best_acc5']
        best_acc1 = acc1
        print("Load Model Accuracy1: ", acc1, " acc5: ", acc5, "Load Model end epoch: ", epoch)
    else:
        print("init model load ...")
        epoch = 1
        best_acc1 = 0
        
    if args.prune:
        print(f"Pruning {args.snip_percentage}% of weights with SNIP...")
        # get snip factor in form required for SNIP function
        snip_factor = (100 - args.snip_percentage)/100
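        # SNIP scores weights by connection sensitivity on sample batches and
        # keeps only the top snip_factor fraction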
        keep_masks = SNIP(model, snip_factor, train_loader, device)
        apply_prune_mask(model, keep_masks)

    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=1e-5, momentum=0.9)
    # optimizer = optim.RMSprop(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss().to(device)

    if args.evaluate:
        acc1, acc5 = validate(test_loader, model, criterion, args)
        print("Acc1: ", acc1, "Acc5: ", acc5)
        return

    if not os.path.isdir("reporting"):
        os.mkdir("reporting")

    start_time = time.time()
    with open("./reporting/" + "best_model_" + args.model_mode + args.prefix + ".txt", "w") as f:
        for epoch in range(epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args)
            train(train_loader, model, criterion, optimizer, epoch, args)
            acc1, acc5 = validate(test_loader, model, criterion, args)

            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            if is_best:
                print('Saving..')
                best_acc5 = acc5
                
                state = {
                    'model': model.state_dict(),
                    'best_acc1': best_acc1,
                    'best_acc5': best_acc5,
                    'epoch': epoch,
                }
                if not os.path.isdir('checkpoint'):
                    os.mkdir('checkpoint')
                filename = "best_model_" + str(args.model_mode)
                torch.save(state, './checkpoint/' + filename + args.prefix + '_ckpt.t7')
                dp_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
                torch.save(dp_state, './checkpoint/' + filename + args.prefix + '_dp_model.t7')

            time_interval = time.time() - start_time
            time_split = time.gmtime(time_interval)
            print("Training time: ", time_interval, "Hour: ", time_split.tm_hour, "Minute: ", time_split.tm_min, "Second: ", time_split.tm_sec, end='')
            print(" Test best acc1:", best_acc1, " acc1: ", acc1, " acc5: ", acc5)

            f.write("Epoch: " + str(epoch) + " " + " Best acc: " + str(best_acc1) + " Test acc: " + str(acc1) + "\n")
            f.write("Training time: " + str(time_interval) + " Hour: " + str(time_split.tm_hour) + " Minute: " + str(
                time_split.tm_min) + " Second: " + str(time_split.tm_sec))
            f.write("\n")
Example #8
def train_model(args):
    """
    args:
       args: global arguments
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("=====> checking if inform_data_file exists")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        dataCollect = CityscapesTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )  # collect mean, std, and class-weight information
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(args.inform_data_file))
        datas = pickle.load(open(args.inform_data_file, "rb"))

    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or wrong gpu id, please run without --cuda")

    #args.seed = random.randint(1, 10000)
    args.seed = 9830

    print("====> Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    cudnn.enabled = True

    model = MobileNetV3(model_mode="SMALL", num_classes=args.classes)

    network_type = "MobileNetV3"
    print("=====> current architeture:  MobileNetV3")

    print("=====> computing network parameters")
    total_parameters = netParams(model)
    print("the number of parameters: " + str(total_parameters))

    print("data['classWeights']: ", datas['classWeights'])
    print('=====> Dataset statistics')
    print('mean and std: ', datas['mean'], datas['std'])

    # define optimization criteria
    weight = torch.from_numpy(datas['classWeights'])
    criteria = CrossEntropyLoss2d(weight)

    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = torch.nn.DataParallel(
                model).cuda()  # multi-card data parallel
        else:
            print("single GPU for training")
            model = model.cuda()  # single-card training

    args.savedir = (args.savedir + args.dataset + '/' + network_type + 'bs' +
                    str(args.batch_size) + 'gpu' + str(args.gpu_nums) + "_" +
                    str(args.train_type) + '/')

    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    train_transform = transforms.Compose([transforms.ToTensor()])

    trainLoader = data.DataLoader(CityscapesDataSet(args.data_dir,
                                                    args.train_data_list,
                                                    crop_size=input_size,
                                                    scale=args.random_scale,
                                                    mirror=args.random_mirror,
                                                    mean=datas['mean']),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
    valLoader = data.DataLoader(CityscapesValDataSet(args.data_dir,
                                                     args.val_data_list,
                                                     f_scale=1,
                                                     mean=datas['mean']),
                                batch_size=1,
                                shuffle=True,
                                num_workers=args.num_workers,
                                pin_memory=True,
                                drop_last=True)

    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            #model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("=====> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
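    # cudnn.benchmark lets cuDNN pick the fastest conv algorithms for the
    # fixed input size used here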
    cudnn.benchmark = True

    logFileLoc = args.savedir + args.logFile
    # append to an existing log, otherwise create a new one
    resuming_log = os.path.isfile(logFileLoc)
    logger = open(logFileLoc, 'a' if resuming_log else 'w')
    if resuming_log:
        logger.write("\n")
    logger.write("Global configuration as follows:")
    for key, value in vars(args).items():
        logger.write("\n{:16} {}".format(key, value))
    logger.write("\nParameters: %s" % (str(total_parameters)))
    logger.write(
        "\n%s\t\t%s\t\t%s\t\t%s\t\t%s\t\t" %
        ('Epoch', 'Loss(Tr)', 'Loss(val)', 'mIOU (tr)', 'mIOU (val)'))
    logger.flush()

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr, (0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=5e-4)

    print('=====> beginning training')
    for epoch in range(start_epoch, args.max_epochs):
        #training
        lossTr, per_class_iu_tr, mIOU_tr, lr = train(args, trainLoader, model,
                                                     criteria, optimizer,
                                                     epoch)

        #validation
        if epoch % 50 == 0:
            mIOU_val, per_class_iu = val(args, valLoader, model, criteria)
            # record train information
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" %
                         (epoch, lossTr, mIOU_tr, mIOU_val, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print(
                "\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\t lr= %.6f"
                % (epoch, lossTr, mIOU_tr, mIOU_val, lr))
        else:
            # record train information
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.7f" %
                         (epoch, lossTr, mIOU_tr, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print(
                "\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t lr= %.6f"
                % (epoch, lossTr, mIOU_tr, lr))

        #save the model
        model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
        state = {"epoch": epoch + 1, "model": model.state_dict()}
        if epoch > args.max_epochs - 10:
            torch.save(state, model_file_name)
        elif not epoch % 20:
            torch.save(state, model_file_name)

    logger.close()
Example #9
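    # grow GPU memory allocation on demand instead of reserving it all upfront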
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)

    train_generator = DataGenerator(batch_size=args['batch_size'],
                                    root_dir='./new_dataset',
                                    csv_file='./new_dataset/face_mixed.csv',
                                    shuffle=True,
                                    transformer=True)
    val_generator = DataGenerator(batch_size=args['batch_size'],
                                  root_dir='./new_test_dataset',
                                  csv_file='./new_test_dataset/face_mixed.csv')

    # model = PFLDNetBackbone(input_shape=(112, 112, 3),
    #                         output_nodes=212, alpha=args['alpha'])
    model = MobileNetV3(shape=(112, 112, 3), n_class=212).build()

    if args['fine_tune']:
        model.load_weights(args['fine_tune_path'], by_name=True)

    # https://blog.csdn.net/laolu1573/article/details/83626555
    # we can simply set 'b2_s' in loss_weights to 0...
    model.compile(loss={
        'b1_s': wing_loss,
        'b2_s': smoothL1
    },
                  loss_weights={
                      'b1_s': 2,
                      'b2_s': 1
                  },
                  optimizer=Adam(lr=args['lr']))
Example #10
import tensorflow as tf
import numpy as np
import argparse
from math import ceil as r
# image classification model
from model import MobileNetV3

# face detection model
from face.anchor_generator import generate_anchors
from face.anchor_decode import decode_bbox
from face.nms import single_class_non_max_suppression
from face.pytorch_loader import load_pytorch_model, pytorch_inference

# set image classification model
model_path = 'model/four_class_newdata.h5'
net = MobileNetV3.build_mobilenet()
net.compile(loss='sparse_categorical_crossentropy', metrics=['accuracy'])
net.build((1, 64, 64, 3))
net.load_weights(model_path)

# set face detection model
model = load_pytorch_model('face/model360.pth')
feature_map_sizes = [[45, 45], [23, 23], [12, 12], [6, 6], [4, 4]]
anchor_sizes = [[0.04, 0.056], [0.08, 0.11], [0.16, 0.22], [0.32, 0.45],
                [0.64, 0.72]]
anchor_ratios = [[1, 0.62, 0.42]] * 5
anchors = generate_anchors(feature_map_sizes, anchor_sizes, anchor_ratios)
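# add a batch dimension so the anchors broadcast against batched model outputs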
anchors_exp = np.expand_dims(anchors, axis=0)
conf_thresh = 0.5
iou_thresh = 0.4
target_shape = (360, 360)
def test_func(args):
    """
    main function for testing
    param args: global arguments
    return: None
    """
    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "no GPU found or wrong gpu id, please run without --cuda")

        device = 'cuda'

    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('=====> checking if processed cached_data_file exists')
    if not os.path.isfile(args.inform_data_file):
        dataCollect = CityscapesTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )  # collect mean, std, and class-weight information
        data = dataCollect.collectDataAndSave()
        if data is None:
            print("error while pickling data, please check")
            exit(-1)
    else:
        data = pickle.load(open(args.inform_data_file, "rb"))
    M = args.M
    N = args.N

    model = MobileNetV3(model_mode="SMALL", num_classes=args.classes)

    network_type = "MobileNetV3"
    print("Arch:  MobileNetV3")

    if args.cuda:
        model = model.to(device)  # using GPU for inference
        cudnn.benchmark = True

    print('Dataset statistics')
    print('mean and std: ', data['mean'], data['std'])
    print('classWeights: ', data['classWeights'])

    # validation set
    testLoader = torch.utils.data.DataLoader(CityscapesTestDataSet(
        args.data_dir, args.test_data_list, mean=data['mean']),
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #model.load_state_dict(checkpoint['model'])
            model.load_state_dict(convert_state_dict(checkpoint['model']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    print("=====> beginning testing")
    print("test set length: ", len(testLoader))
    test(args, testLoader, model, device, data)
def ValidateSegmentation(args):
    """
    main function for validation
    param args: global arguments
    return: None
    """
    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "no GPU found or wrong gpu id, please run without --cuda")

    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('=====> checking if processed cached_data_file exists')
    if not os.path.isfile(args.inform_data_file):
        dataCollect = CityscapesTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file
        )  # collect mean, std, and class-weight information
        data = dataCollect.collectDataAndSave()
        if data is None:
            print("error while pickling data, please check")
            exit(-1)
    else:
        data = pickle.load(open(args.inform_data_file, "rb"))

    model = MobileNetV3(model_mode="SMALL", num_classes=args.classes)

    network_type = "MobileNetV3"
    print("Arch:  MobileNetV3")
    # define optimization criteria
    weight = torch.from_numpy(
        data['classWeights'])  # convert the numpy array to torch
    if args.cuda:
        weight = weight.cuda()
    criteria = CrossEntropyLoss2d(weight)  # class-weighted 2D cross-entropy

    if args.cuda:
        model = model.cuda()  # using GPU for inference
        criteria = criteria.cuda()
        cudnn.benchmark = True

    print('Dataset statistics')
    print('mean and std: ', data['mean'], data['std'])
    print('classWeights: ', data['classWeights'])

    if args.save_seg_dir:
        if not os.path.exists(args.save_seg_dir):
            os.makedirs(args.save_seg_dir)

    # validation set
    valLoader = torch.utils.data.DataLoader(CityscapesValDataSet(
        args.data_dir, args.val_data_list, f_scale=1, mean=data['mean']),
                                            batch_size=1,
                                            shuffle=False,
                                            num_workers=args.num_workers,
                                            pin_memory=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #model.load_state_dict(checkpoint['model'])
            model.load_state_dict(convert_state_dict(checkpoint['model']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    print("=====> beginning validation")
    print("validation set length: ", len(valLoader))
    val(args, valLoader, model, criteria)