Example #1
0
def main():
    """
    Main function wrapper for demo script.

    Seeds the Python, NumPy and PyTorch RNGs from ``args["SEED"]``, loads the
    weights in ``args["TRAINED_WEIGHTS_FILE"]`` into a ``MyNet`` model and,
    for every file found (recursively) under ``args["DEMO_DIRECTORY"]``,
    preprocesses it, runs the model on it and prints the decoded prediction.
    If no weights file is configured, only a message is printed.
    """

    random.seed(args["SEED"])
    np.random.seed(args["SEED"])
    torch.manual_seed(args["SEED"])
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    if args["TRAINED_WEIGHTS_FILE"] is not None:

        print("Trained Weights File: %s" % (args["TRAINED_WEIGHTS_FILE"]))
        print("Demo Directory: %s" % (args["DEMO_DIRECTORY"]))

        model = MyNet()
        # NOTE(review): the path is built by plain string concatenation, so
        # the weights file name presumably carries its own leading separator.
        model.load_state_dict(
            torch.load(
                args["CODE_DIRECTORY"] + args["TRAINED_WEIGHTS_FILE"],
                map_location=device,
            ))
        model.to(device)
        # Inference mode never changes inside the loop, so set it once here.
        model.eval()

        print("Running Demo ....")

        for root, _dirs, files in os.walk(args["DEMO_DIRECTORY"]):
            for file in files:

                sampleFile = os.path.join(root, file)

                preprocess_sample(sampleFile)

                inp, _ = prepare_input(sampleFile)
                # Add a leading batch dimension of size 1.
                inputBatch = torch.unsqueeze(inp, dim=0)
                inputBatch = inputBatch.float().to(device)

                # No gradients are needed for inference.
                with torch.no_grad():
                    outputBatch = model(inputBatch)

                predictionBatch = decode(outputBatch)
                pred = predictionBatch[0][:]

                print("File: %s" % (file))
                print("Prediction: %s" % (pred))
                print("\n")

        print("Demo Completed.")

    else:
        print("Path to trained weights file not specified.")

    return
Example #2
0
def main():
    """
    Main function wrapper for testing script.

    Seeds the Python, NumPy and PyTorch RNGs from ``args["SEED"]``, loads the
    weights in ``args["TRAINED_WEIGHTS_FILE"]`` into a ``MyNet`` model and
    prints the loss and metric returned by ``evaluate`` on the "test" split of
    ``MyDataset``. If no weights file is configured, only a message is printed.
    """

    random.seed(args["SEED"])
    np.random.seed(args["SEED"])
    torch.manual_seed(args["SEED"])
    if torch.cuda.is_available():
        device = torch.device("cuda")
        # pin_memory only pays off when batches are copied to a CUDA device.
        kwargs = {"num_workers": args["NUM_WORKERS"], "pin_memory": True}
    else:
        device = torch.device("cpu")
        kwargs = {}

    if args["TRAINED_WEIGHTS_FILE"] is not None:

        testData = MyDataset("test", datadir=args["DATA_DIRECTORY"])
        # Evaluation visits every sample exactly once, so shuffling adds
        # nothing and would only make batch composition depend on RNG state.
        testLoader = DataLoader(testData,
                                batch_size=args["BATCH_SIZE"],
                                shuffle=False,
                                **kwargs)

        print("Trained Weights File: %s" % (args["TRAINED_WEIGHTS_FILE"]))

        model = MyNet()
        # NOTE(review): the path is built by plain string concatenation, so
        # the weights file name presumably carries its own leading separator.
        model.load_state_dict(
            torch.load(
                args["CODE_DIRECTORY"] + args["TRAINED_WEIGHTS_FILE"],
                map_location=device,
            ))
        model.to(device)

        criterion = MyLoss()
        regularizer = L2Regularizer(lambd=args["LAMBDA"])

        print("Testing the trained model ....")

        testLoss, testMetric = evaluate(model, testLoader, criterion,
                                        regularizer, device)

        print("| Test Loss: %.6f || Test Metric: %.3f |" %
              (testLoss, testMetric))
        print("Testing Done.")

    else:
        print("Path to the trained weights file not specified.")

    return
Example #3
0
    def load_model(self):
        """
        Load the network weights from ``self.o_net_path`` and run prediction
        on every entry of ``self.image_dir``.

        Each image is passed through ``self.transforms``, given a batch
        dimension, fed through the network and the output is multiplied by
        192. The result is printed and handed to ``self.save_pred`` together
        with the ``name`` / ``landmark_and_format`` fields returned by
        ``self.parse_image_name``. Failures inside ``save_pred`` are reported
        and skipped. Inference always runs on the CPU.

        If ``self.o_net_path`` is None, a message is printed and nothing else
        happens.

        :return: None
        """
        # 1. Load the model.
        if self.o_net_path is None:
            # Without weights there is no network to run. The old code fell
            # through here and crashed with UnboundLocalError on ``net``.
            print('Path to the trained weights file not specified.')
            return
        print('=======> loading')
        net = MyNet(use_cuda=False)
        # map_location='cpu' lets weights saved from a GPU run load on a
        # CPU-only machine. The input images below are never moved to a GPU,
        # so the network has to stay on the CPU.
        net.load_state_dict(torch.load(self.o_net_path, map_location='cpu'))
        net.eval()

        # 2. Prepare the data and predict.
        img_list = os.listdir(self.image_dir)
        for item in img_list:
            # The context manager closes the underlying file handle.
            with Image.open(os.path.join(self.image_dir, item)) as _img:
                parse_result = self.parse_image_name(item)
                landmark_and_format = parse_result['landmark_and_format']
                name = parse_result['name']
                img = self.transforms(_img)
                img = img.unsqueeze(0)

                # Inference only: no autograd graph is needed.
                with torch.no_grad():
                    pred = net(img) * 192

                print('the pred landmark is :', pred)
                print("=" * 20)

                try:
                    self.save_pred(_img, name, landmark_and_format,
                                   pred.detach().numpy())
                except Exception as err:
                    # Best effort: report the failing file and keep going.
                    print('Error:', item, err)
Example #4
0
def main():
    """
    Main function wrapper for training script.

    Seeds the Python, NumPy and PyTorch RNGs from ``args["SEED"]`` and splits
    the "train" split of ``MyDataset`` into training and validation subsets
    using ``args["VALIDATION_SPLIT"]``. It then:

    * recreates ``<CODE_DIRECTORY>/checkpoints``, asking for confirmation
      before deleting an existing one;
    * optionally loads ``args["PRETRAINED_WEIGHTS_FILE"]``;
    * trains for ``args["NUM_EPOCHS"]`` epochs with Adam and an exponential
      learning-rate decay;
    * every ``args["SAVE_FREQUENCY"]`` epochs, saves the weights plus
      loss/metric curve plots.
    """

    # Non-interactive backend: plots are only written to files.
    matplotlib.use("Agg")
    random.seed(args["SEED"])
    np.random.seed(args["SEED"])
    torch.manual_seed(args["SEED"])
    if torch.cuda.is_available():
        device = torch.device("cuda")
        kwargs = {"num_workers": args["NUM_WORKERS"], "pin_memory": True}
    else:
        device = torch.device("cpu")
        kwargs = {}

    trainData = MyDataset("train", datadir=args["DATA_DIRECTORY"])
    valSize = int(args["VALIDATION_SPLIT"] * len(trainData))
    trainSize = len(trainData) - valSize
    trainData, valData = random_split(trainData, [trainSize, valSize])
    trainLoader = DataLoader(
        trainData, batch_size=args["BATCH_SIZE"], shuffle=True, **kwargs
    )
    valLoader = DataLoader(
        valData, batch_size=args["BATCH_SIZE"], shuffle=True, **kwargs
    )

    model = MyNet()
    model.to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=args["LEARNING_RATE"],
        betas=(args["MOMENTUM1"], args["MOMENTUM2"]),
    )
    scheduler = optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=args["LR_DECAY"]
    )
    criterion = MyLoss()
    regularizer = L2Regularizer(lambd=args["LAMBDA"])

    checkpointDir = args["CODE_DIRECTORY"] + "/checkpoints"
    if os.path.exists(checkpointDir):
        # Ask before wiping results from a previous run.
        while True:
            char = input(
                "Continue and remove the 'checkpoints' directory? y/n: "
            )
            if char == "y":
                break
            elif char == "n":
                sys.exit()
            else:
                print("Invalid input")
        shutil.rmtree(checkpointDir)

    os.mkdir(checkpointDir)
    os.mkdir(checkpointDir + "/plots")
    os.mkdir(checkpointDir + "/weights")

    if args["PRETRAINED_WEIGHTS_FILE"] is not None:
        print(
            "Pretrained Weights File: %s" % (args["PRETRAINED_WEIGHTS_FILE"])
        )
        print("Loading the pretrained weights ....")
        model.load_state_dict(
            torch.load(
                args["CODE_DIRECTORY"] + args["PRETRAINED_WEIGHTS_FILE"],
                map_location=device,
            )
        )
        model.to(device)
        print("Loading Done.")

    trainingLossCurve = list()
    validationLossCurve = list()
    trainingMetricCurve = list()
    validationMetricCurve = list()

    numTotalParams, numTrainableParams = num_params(model)
    print("Number of total parameters in the model = %d" % (numTotalParams))
    print(
        "Number of trainable parameters in the model = %d"
        % (numTrainableParams)
    )

    print("Training the model ....")

    for epoch in range(1, args["NUM_EPOCHS"] + 1):

        trainingLoss, trainingMetric = train(
            model, trainLoader, optimizer, criterion, regularizer, device
        )
        trainingLossCurve.append(trainingLoss)
        trainingMetricCurve.append(trainingMetric)

        validationLoss, validationMetric = evaluate(
            model, valLoader, criterion, regularizer, device
        )
        validationLossCurve.append(validationLoss)
        validationMetricCurve.append(validationMetric)

        print(
            (
                "| Epoch: %03d |"
                "| Tr.Loss: %.6f  Val.Loss: %.6f |"
                "| Tr.Metric: %.3f  Val.Metric: %.3f |"
            )
            % (
                epoch,
                trainingLoss, validationLoss,
                trainingMetric, validationMetric,
            )
        )

        scheduler.step()

        if epoch % args["SAVE_FREQUENCY"] == 0:

            savePath = (
                args["CODE_DIRECTORY"]
                + "/checkpoints/weights/epoch_{:04d}-metric_{:.3f}.pt"
            ).format(epoch, validationMetric)
            torch.save(model.state_dict(), savePath)

            _save_curve_plot(
                trainingLossCurve,
                validationLossCurve,
                "Loss Curves",
                "Loss value",
                (
                    args["CODE_DIRECTORY"]
                    + "/checkpoints/plots/epoch_{:04d}_loss.png"
                ).format(epoch),
            )
            _save_curve_plot(
                trainingMetricCurve,
                validationMetricCurve,
                "Metric Curves",
                "Metric",
                (
                    args["CODE_DIRECTORY"]
                    + "/checkpoints/plots/epoch_{:04d}_metric.png"
                ).format(epoch),
            )

    print("Training Done.")

    return


def _save_curve_plot(trainCurve, valCurve, title, ylabel, savePath):
    """
    Plot per-epoch training (blue) and validation (red) curves against epoch
    numbers starting at 1, save the figure to ``savePath`` and close it.
    """
    plt.figure()
    plt.title(title)
    plt.xlabel("Epoch No.")
    plt.ylabel(ylabel)
    plt.plot(
        list(range(1, len(trainCurve) + 1)),
        trainCurve,
        "blue",
        label="Train",
    )
    plt.plot(
        list(range(1, len(valCurve) + 1)),
        valCurve,
        "red",
        label="Validation",
    )
    plt.legend()
    plt.savefig(savePath)
    # Close explicitly so figures do not accumulate across epochs.
    plt.close()
Example #5
0
def train(**kwargs):
    """
    Train ``MyNet`` with SGD, using the settings in the global ``cfg``.

    Keyword arguments are passed to ``cfg._parse`` to override config values.
    The model is optionally initialised from ``cfg.load_model_path``, wrapped
    in DataParallel when ``cfg.multi_gpu`` is set, and moved to CUDA when
    ``cfg.use_gpu`` is set.

    Every ``cfg.save_freq`` epochs the weights are written to
    ``./checkpoints/last.pth``. The learning rate is multiplied by
    ``cfg.lr_decay`` whenever the epoch's loss-meter value (presumably the
    mean batch loss) exceeds the previous epoch's.
    """
    # 1. configure model
    cfg._parse(kwargs)
    model = MyNet()
    if cfg.load_model_path:
        model.load_state_dict(torch.load(cfg.load_model_path))

    if cfg.multi_gpu:
        model = parallel.DataParallel(model)

    if cfg.use_gpu:
        model.cuda()

    # 2. prepare data
    train_data = SN(root=cfg.train_data_root, crop_size=cfg.crop_size)
    train_loader = DataLoader(train_data, batch_size=cfg.batch_size, shuffle=True)

    # 3. criterion (already imported) and optimizer
    lr = cfg.lr
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=cfg.momentum)

    # 4. meters
    loss_meter = meter.AverageValueMeter()
    previous_loss = 1e10

    # train
    for epoch in range(cfg.max_epoch):
        print('epoch %s: ===========================' % epoch)
        loss_meter.reset()

        # Wrap the loader itself (not enumerate) so tqdm knows its length.
        for ii, (data, label_group) in enumerate(tqdm(train_loader)):
            if cfg.use_gpu:
                data = data.cuda()
                label_group = [label.cuda() for label in label_group]
            # Tensors track gradients directly since PyTorch 0.4 (the file
            # already relies on ``.item()``), so no Variable wrapping is needed.
            data = data.float()

            optimizer.zero_grad()
            score = model(data)
            loss = criterion(score, label_group, batch_size=cfg.batch_size, neg_pos_ratio=cfg.neg_pos_ratio)
            loss.backward()
            optimizer.step()

            # meters update and print
            loss_meter.add(loss.item())
            if (ii + 1) % cfg.print_freq == 0:
                print(loss_meter.value()[0])

        if (epoch + 1) % cfg.save_freq == 0:
            # Only a DataParallel wrapper has ``.module``; unwrap it when
            # present so single-device runs do not raise AttributeError.
            if isinstance(model, parallel.DataParallel):
                net = model.module
            else:
                net = model
            torch.save(net.state_dict(), './checkpoints/last.pth')

        # update learning rate
        if loss_meter.value()[0] > previous_loss:
            lr = lr * cfg.lr_decay
            # Change lr in place on the existing optimizer instead of building
            # a new one, so momentum and other optimizer state are kept.
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        previous_loss = loss_meter.value()[0]