Code Example #1
def main():
    args = get_arguments()

    gpus = [int(i) for i in args.gpu.split(',')]
    assert len(gpus) == 1
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    num_classes = dataset_settings[args.dataset]['num_classes']
    input_size = dataset_settings[args.dataset]['input_size']
    label = dataset_settings[args.dataset]['label']
    print("Evaluating total class number {} with {}".format(num_classes, label))

    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)

    # `checkpoints_path`, `ckpt_choice`, `ity`, and `eps` are defined elsewhere
    # in the original script.
    state_dict = torch.load(os.path.join(checkpoints_path, ckpt_choice[ity]))['state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        k = k[7:]  # remove `module.` prefix added by DataParallel
        ss = k.split('.')
        # Force batch-norm scale weights positive, presumably so the ABN
        # checkpoint loads cleanly into standard batch-norm layers.
        if ss[-2].startswith('bn') and ss[-1].endswith('weight'):
            v1 = torch.abs(v) + eps
        else:
            v1 = v
        new_state_dict[k] = v1
    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()
    torch.save(model.state_dict(), '/home/qiu/Projects/Self-Correction-Human-Parsing/deploy/'+ity+'_abn_checkpoint.pth')
Code Example #2
def main():
    args = get_arguments()

    gpus = [int(i) for i in args.gpu.split(',')]
    assert len(gpus) == 1
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    num_classes = dataset_settings[args.dataset]['num_classes']
    input_size = dataset_settings[args.dataset]['input_size']
    label = dataset_settings[args.dataset]['label']
    print("Evaluating total class number {} with {}".format(
        num_classes, label))

    model = networks.init_model('resnet101',
                                num_classes=num_classes,
                                pretrained=None)

    state_dict = torch.load(args.model_restore)['state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485],
                             std=[0.225, 0.224, 0.229])
    ])
    dataset = SimpleFolderDataset(root=args.input_dir,
                                  input_size=input_size,
                                  transform=transform)
    dataloader = DataLoader(dataset)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    palette = get_palette(num_classes)
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(dataloader)):
            image, meta = batch
            img_name = meta['name'][0]
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            output = model(image.cuda())
            upsample = torch.nn.Upsample(size=input_size,
                                         mode='bilinear',
                                         align_corners=True)
            upsample_output = upsample(output[0][-1][0].unsqueeze(0))
            upsample_output = upsample_output.squeeze()
            upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC

            logits_result = transform_logits(
                upsample_output.data.cpu().numpy(),
                c,
                s,
                w,
                h,
                input_size=input_size)
            parsing_result = np.argmax(logits_result, axis=2)
            parsing_result_path = os.path.join(args.output_dir,
                                               img_name[:-4] + '.png')
            output_img = Image.fromarray(
                np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(palette)
            output_img.save(parsing_result_path)
            if args.logits:
                logits_result_path = os.path.join(args.output_dir,
                                                  img_name[:-4] + '.npy')
                np.save(logits_result_path, logits_result)
    return
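The `module.`-prefix stripping above recurs in every example on this page, because SCHP checkpoints are saved from a DataParallel wrapper. A minimal reusable sketch (the helper name load_schp_weights is illustrative, not part of the SCHP codebase):

from collections import OrderedDict

import torch


def load_schp_weights(model, checkpoint_path):
    # Checkpoints store the weights under 'state_dict'; every key carries a
    # 'module.' prefix that the bare (non-DataParallel) model does not expect.
    state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']
    stripped = OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items())
    model.load_state_dict(stripped)
    return model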
Code Example #3
def main():
    """Create the model and start the evaluation process."""
    args = get_arguments()
    multi_scales = [float(i) for i in args.multi_scales.split(',')]
    gpus = [int(i) for i in args.gpu.split(',')]
    assert len(gpus) == 1
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True
    cudnn.enabled = True

    h, w = map(int, args.input_size.split(','))
    input_size = [h, w]

    model = networks.init_model(args.arch,
                                num_classes=args.num_classes,
                                pretrained=None)

    IMAGE_MEAN = model.mean
    IMAGE_STD = model.std
    INPUT_SPACE = model.input_space
    print('image mean: {}'.format(IMAGE_MEAN))
    print('image std: {}'.format(IMAGE_STD))
    print('input space: {}'.format(INPUT_SPACE))
    if INPUT_SPACE == 'BGR':
        print('BGR Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD),
        ])
    elif INPUT_SPACE == 'RGB':
        print('RGB Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            BGR2RGB_transform(),
            transforms.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD),
        ])

    # Data loader
    lip_test_dataset = LIPDataValSet(args.data_dir,
                                     'val',
                                     crop_size=input_size,
                                     transform=transform,
                                     flip=args.flip)
    num_samples = len(lip_test_dataset)
    print('Total testing samples: {}'.format(num_samples))
    testloader = data.DataLoader(lip_test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 pin_memory=True)

    # Load model weight
    state_dict = torch.load(args.model_restore)['state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()

    sp_results_dir = os.path.join(args.log_dir, 'sp_results')
    if not os.path.exists(sp_results_dir):
        os.makedirs(sp_results_dir)

    palette = get_palette(20)
    parsing_preds = []
    scales = np.zeros((num_samples, 2), dtype=np.float32)
    centers = np.zeros((num_samples, 2), dtype=np.int32)
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(testloader)):
            image, meta = batch
            if (len(image.shape) > 4):
                image = image.squeeze()
            im_name = meta['name'][0]
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]
            scales[idx, :] = s
            centers[idx, :] = c
            parsing, logits = multi_scale_testing(model,
                                                  image.cuda(),
                                                  crop_size=input_size,
                                                  flip=args.flip,
                                                  multi_scales=multi_scales)
            if args.save_results:
                parsing_result = transform_parsing(parsing, c, s, w, h,
                                                   input_size)
                parsing_result_path = os.path.join(sp_results_dir,
                                                   im_name + '.png')
                output_im = PILImage.fromarray(
                    np.asarray(parsing_result, dtype=np.uint8))
                output_im.putpalette(palette)
                output_im.save(parsing_result_path)

            parsing_preds.append(parsing)
    assert len(parsing_preds) == num_samples
    mIoU = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes,
                            args.data_dir, input_size)
    print(mIoU)
    return
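multi_scale_testing comes from the repository's evaluation utilities and is not shown in the snippet. Conceptually it runs the crop through the model at several scales, optionally mirrors it, and averages the resulting logits; a minimal sketch under those assumptions (the function name multi_scale_flip_logits is illustrative):

import torch
import torch.nn.functional as F


def multi_scale_flip_logits(model, image, crop_size, multi_scales, flip=True):
    # image: a normalized 1 x 3 x H x W tensor; returns logits averaged over
    # scales (and, optionally, a horizontally flipped pass per scale).
    fused = 0
    for scale in multi_scales:
        size = [int(s * scale) for s in crop_size]
        x = F.interpolate(image, size=size, mode='bilinear', align_corners=True)
        logits = model(x)[0][-1]  # fused parsing logits, as in the examples above
        logits = F.interpolate(logits, size=crop_size, mode='bilinear',
                               align_corners=True)
        if flip:
            flipped = model(torch.flip(x, dims=[3]))[0][-1]
            flipped = F.interpolate(flipped, size=crop_size, mode='bilinear',
                                    align_corners=True)
            # Note: a faithful flip pass would also swap left/right part
            # channels, which this sketch omits.
            logits = (logits + torch.flip(flipped, dims=[3])) / 2
        fused = fused + logits
    return fused / len(multi_scales)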
Code Example #4
import torch
import torchvision.transforms as transforms
import networks
from utils.transforms import transform_logits
from datasets.simple_extractor_dataset import SimpleFolderDataset

from functions import *

radius_fraction = 0.06
alpha = 0.4

images_folder_path = './images/'
output_path = './output/'

num_classes = 7
model = networks.init_model('resnet101',
                            num_classes=num_classes,
                            pretrained=None)
state_dict = torch.load(
    './weights/exp-schp-201908270938-pascal-person-part.pth')['state_dict']
from collections import OrderedDict

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
model.cuda()
model.eval()


def get_palette(num_cls):
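The snippet above is cut off at get_palette. For reference, the conventional PASCAL-VOC-style color map generator compatible with PIL's putpalette looks like this (a sketch, not necessarily this project's exact body):

def get_palette(num_cls):
    # Standard VOC bit-shuffle palette: spreads the low bits of each class
    # index across the high bits of the R, G, and B channels.
    palette = [0] * (num_cls * 3)
    for j in range(num_cls):
        lab = j
        i = 0
        while lab:
            palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i += 1
            lab >>= 3
    return palette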
Code Example #5
File: demo_video.py  Project: TannedCung/SCHP
def main():
    args = get_arguments()

    # gpus = [int(i) for i in args.gpu.split(',')]
    # assert len(gpus) == 1
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    num_classes = dataset_settings[args.dataset]['num_classes']
    input_size = dataset_settings[args.dataset]['input_size']
    label = dataset_settings[args.dataset]['label']
    print("Evaluating total class number {} with {}".format(
        num_classes, label))

    model = networks.init_model('resnet101',
                                num_classes=num_classes,
                                pretrained=None)

    state_dict = torch.load(args.model_restore)  #['state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    # change_dict.dictModify is a project-local helper; judging from the
    # commented-out loop below, it performs the same `module.`-prefix renaming.
    secretary = change_dict.dictModify(new_dict=new_state_dict,
                                       old_dict=state_dict)
    new_state_dict = secretary.arange()
    # for k, v in state_dict.items():
    #     name = k[7:]  # remove `module.`
    #     new_state_dict[name] = v

    model.load_state_dict(new_state_dict)
    # print(model)
    #  # model.cuda()
    # model = torch.load("log/entire_model.pth", map_location=torch.device('cpu'))
    model.eval()
    input = torch.randn(1, 3, 473, 473)
    ONNX_FILE_PATH = "pretrain_model/local.onnx"
    ONNX_SIM_FILE_PATH = "pretrain_model/sim_local.onnx"

    # torch.onnx.export(model, input, ONNX_FILE_PATH, input_names=['input'], output_names=['output'], opset_version=11) #, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
    # onnx_model = onnx.load(ONNX_FILE_PATH)
    # sim_model, check = onnxsim.simplify(onnx_model)
    # onnx.save(sim_model, ONNX_SIM_FILE_PATH)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485],
                             std=[0.225, 0.224, 0.229])
    ])
    transformer = SimpleVideo(transforms=transform, input_size=[473, 473])

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    color_man = gray2color.makeColor(num_classes)
    VIDEO_PATH = "input.mp4"
    cap = cv2.VideoCapture(VIDEO_PATH)
    frame_width = int(cap.get(3))
    frame_height = int(cap.get(4))
    out = cv2.VideoWriter('outpy.avi',
                          cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 24,
                          (frame_width, frame_height))
    with torch.no_grad():
        start = time.time()
        count = 0
        ret = True
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                frame, meta = transformer.get_item(frame)
                c = meta['center']
                s = meta['scale']
                w = meta['width']
                h = meta['height']
                # out_put = model(frame)
                output = model(frame)
                # output = model(image.cuda())
                upsample = torch.nn.Upsample(size=input_size,
                                             mode='bilinear',
                                             align_corners=True)
                upsample_output = upsample(output[0][-1][0].unsqueeze(0))
                upsample_output = upsample_output.squeeze()
                upsample_output = upsample_output.permute(1, 2,
                                                          0)  # CHW -> HWC

                logits_result = transform_logits(
                    upsample_output.data.cpu().numpy(),
                    c,
                    s,
                    w,
                    h,
                    input_size=input_size)
                parsing_result = np.argmax(logits_result, axis=2)
                output_img = Image.fromarray(
                    np.asarray(parsing_result, dtype=np.uint8))
                out_img = np.array(output_img)
                output_img = color_man.G2C(out_img)

                out.write(np.array(output_img))
                cv2.imshow("Tanned", np.array(output_img))
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

                if args.logits:
                    # The original reused `img_name` from the image script,
                    # which is undefined here; index by frame count instead.
                    logits_result_path = os.path.join(
                        args.output_dir, 'frame_{:06d}.npy'.format(count))
                    np.save(logits_result_path, logits_result)
                count += 1
            else:
                break
        end = time.time()
        cap.release()
        out.release()
        print('Processed {} frames in {:.5} seconds '
              '({:.5} s per frame on average)'.format(
                  count, end - start, (end - start) / max(count, 1)))
    return
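The ONNX export in the example above is left commented out. A minimal sketch of that export-and-simplify step, assuming the onnx and onnx-simplifier packages (the helper name export_onnx is illustrative):

import onnx
import onnxsim
import torch


def export_onnx(model, onnx_path, sim_path, input_size=(473, 473)):
    model.eval()
    dummy = torch.randn(1, 3, *input_size)
    # opset 11 matches the commented-out export call in the example above.
    torch.onnx.export(model, dummy, onnx_path,
                      input_names=['input'], output_names=['output'],
                      opset_version=11)
    sim_model, check = onnxsim.simplify(onnx.load(onnx_path))
    assert check, 'simplified ONNX model failed validation'
    onnx.save(sim_model, sim_path)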
Code Example #6
def main():
    args = get_arguments()

    gpus = [int(i) for i in args.gpu.split(',')]
    assert len(gpus) == 1
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    num_classes = dataset_settings[args.dataset]['num_classes']
    input_size = dataset_settings[args.dataset]['input_size']
    label = dataset_settings[args.dataset]['label']
    print("Evaluating total class number {} with {}".format(num_classes, label))

    model = networks.init_model('resnet101_1', num_classes=num_classes, pretrained=None)

    state_dict = torch.load(args.model_restore)['state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    # model.load_state_dict(state_dict)
    model.cuda()
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485],
                             std=[0.225, 0.224, 0.229])
    ])

    img_path = '/home/qiu/Projects/Self-Correction-Human-Parsing/run/garmin-forerunner-245-what-to-expect.jpg'
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    h, w, _ = img.shape
    # Get person center and scale (_box2cs and get_affine_transform are
    # project utilities defined elsewhere).
    person_center, s = _box2cs([0, 0, w - 1, h - 1])
    r = 0
    trans = get_affine_transform(person_center, s, r, input_size)
    input = cv2.warpAffine(
        img,
        trans,
        (int(input_size[1]), int(input_size[0])),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(0, 0, 0))
    cv2.imwrite('input.png', input)
    x = transform(input)
    x = x.unsqueeze(0)

    dummy_input = torch.rand(1, 3, 512, 512).cuda()
    # x = torch.rand(1, 3, 512, 512)
    # model = model.eval()

    unscripted_output = model(x.cuda())  # Get the unscripted model's prediction...
    # unscripted_top5 = F.softmax(unscripted_output, dim=1).topk(5).indices
    # print('Python model top 5 results:\n  {}'.format(unscripted_top5))
    # script_model=torch.jit.script(model)
    traced_model = torch.jit.trace(model, dummy_input)

    scripted_output = traced_model(x.cuda())  # ...and do the same for the scripted version
    output = unscripted_output
    upsample = torch.nn.Upsample(size=input_size,
                                 mode='bilinear',
                                 align_corners=True)
    # upsample_output = upsample(output[0][-1][0].unsqueeze(0))
    upsample_output = upsample(output[0].unsqueeze(0))
    # upsample_output = output[0].unsqueeze(0)
    upsample_output = upsample_output.squeeze()
    upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC
    logits_result = upsample_output.data.cpu().numpy()
    parsing_result = np.argmax(logits_result, axis=2)
    # Cast before handing to OpenCV, which cannot write or display int64 arrays.
    vis = (parsing_result * 255).astype(np.uint8)
    cv2.imwrite('./ps.png', vis)
    cv2.imshow('ps', vis)
    cv2.waitKey(-1)
    # scripted_top5 = F.softmax(scripted_output, dim=1).topk(5).indices

    # print('TorchScript model top 5 results:\n  {}'.format(scripted_top5))
    traced_model.save(ity + '_abn.pt')  # `ity` is defined elsewhere in the original script
    print(torch.where(torch.eq(unscripted_output, scripted_output) != 1))
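The final torch.where line prints every index where the traced and eager outputs differ exactly. A tolerance-based check is usually more informative; a sketch, assuming the model returns a flat list of tensors (atol is illustrative):

for eager_t, traced_t in zip(unscripted_output, scripted_output):
    print(torch.allclose(eager_t, traced_t, atol=1e-5))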
Code Example #7
def main():
    args = get_arguments()

    # gpus = [int(i) for i in args.gpu.split(',')]
    # assert len(gpus) == 1
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    num_classes = dataset_settings[args.dataset]['num_classes']
    input_size = dataset_settings[args.dataset]['input_size']
    label = dataset_settings[args.dataset]['label']
    # Model in torch init
    model = networks.init_model('resnet101',
                                num_classes=num_classes,
                                pretrained=None)

    state_dict = torch.load(args.model_restore)  #['state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    # change_dict.dictModify is a project-local helper; as in Example #5, it
    # evidently performs the `module.`-prefix renaming on the checkpoint keys.
    secretary = change_dict.dictModify(new_dict=new_state_dict,
                                       old_dict=state_dict)
    new_state_dict = secretary.arange()
    model.load_state_dict(new_state_dict)

    model.eval()

    print("Evaluating total class number {} with {}".format(
        num_classes, label))

    sess = onnxruntime.InferenceSession("pretrain_model/sim_local.onnx")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485],
                             std=[0.225, 0.224, 0.229])
    ])
    transformer = SimpleVideo(transforms=transform, input_size=[473, 473])

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    color_man = gray2color.makeColor(num_classes)
    VIDEO_PATH = "input.mp4"
    cap = cv2.VideoCapture(VIDEO_PATH)
    frame_width = int(cap.get(3))
    frame_height = int(cap.get(4))
    out = cv2.VideoWriter('outpy.avi',
                          cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 24,
                          (frame_width * 3, frame_height))
    with torch.no_grad():
        start = time.time()
        count = 0
        ret = True
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                # w, h, _ = frame.shape
                # c, s, w, h = cal_chwsr(0, 0, w, h)
                # frame = torch.tensor(frame)
                print("frame: {}".format(frame.shape))
                frame, meta = transformer.get_item(frame)
                c = meta['center']
                s = meta['scale']
                w = meta['width']
                h = meta['height']

                # out_put = model(frame)
                # input_name = sess.get_inputs()[0].name
                # print("input name:", input_name)
                # input_shape = sess.get_inputs()[0].shape
                # print("input shape:", input_shape)
                # input_type = sess.get_inputs()[0].type
                # print("input type:", input_type)
                input_name = sess.get_inputs()[0].name
                # output_name = sess.get_outputs()[0].name
                # Output node names, presumably taken from the simplified ONNX graph.
                output_name = ['output', '1250', '1245']

                pre_output = sess.run(
                    output_name,
                    input_feed={sess.get_inputs()[0].name: np.array(frame)})
                output = [[pre_output[0], pre_output[1]], [pre_output[2]]]
                output_t = model(frame)
                # Post-process for output from onnx model
                upsample = torch.nn.Upsample(size=input_size,
                                             mode='bilinear',
                                             align_corners=True)
                upsample_output = upsample(
                    torch.tensor(output[0][-1][0]).unsqueeze(0))
                upsample_output = upsample_output.squeeze()
                upsample_output = upsample_output.permute(1, 2,
                                                          0)  # CHW -> HWC

                logits_result = transform_logits(
                    upsample_output.data.cpu().numpy(),
                    c,
                    s,
                    w,
                    h,
                    input_size=input_size)
                parsing_result = np.argmax(logits_result, axis=2)
                output_img = Image.fromarray(
                    np.asarray(parsing_result, dtype=np.uint8))
                out_img_o = np.array(output_img)
                output_img_o = color_man.G2C(out_img_o)
                # Post-process for torch model
                upsample_output = upsample(
                    torch.tensor(output_t[0][-1][0]).unsqueeze(0))
                upsample_output = upsample_output.squeeze()
                upsample_output = upsample_output.permute(1, 2,
                                                          0)  # CHW -> HWC

                logits_result = transform_logits(
                    upsample_output.data.cpu().numpy(),
                    c,
                    s,
                    w,
                    h,
                    input_size=input_size)
                parsing_result = np.argmax(logits_result, axis=2)
                output_img = Image.fromarray(
                    np.asarray(parsing_result, dtype=np.uint8))
                out_t_min = np.array(output_img)
                out_img_t = np.array(output_img)
                output_img_t = color_man.G2C(out_img_t)

                # final = cv2.hconcat(output_img_t, out_img_t-out_img_o)
                final = cv2.hconcat([
                    output_img_t,
                    cv2.cvtColor(out_img_t - out_img_o, cv2.COLOR_GRAY2RGB) *
                    100, output_img_o
                ])

                out.write(np.array(final))
                cv2.imshow("Tanned", np.array(final))
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

                if args.logits:
                    # As in Example #5, `img_name` is undefined in the video
                    # loop; index by frame count instead.
                    logits_result_path = os.path.join(
                        args.output_dir, 'frame_{:06d}.npy'.format(count))
                    np.save(logits_result_path, logits_result)
                count += 1
            else:
                break
        end = time.time()
        cap.release()
        out.release()
        print('Processed {} frames in {:.5} seconds '
              '({:.5} s per frame on average)'.format(
                  count, end - start, (end - start) / max(count, 1)))
    return
Code Example #8
def main():
    args = get_arguments()

    gpus = [int(i) for i in args.gpu.split(',')]
    assert len(gpus) == 1
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    num_classes = dataset_settings[args.dataset]['num_classes']
    input_size = dataset_settings[args.dataset]['input_size']
    label = dataset_settings[args.dataset]['label']
    print("Evaluating total class number {} with {}".format(
        num_classes, label))

    model = networks.init_model('resnet101',
                                num_classes=num_classes,
                                pretrained=None)

    state_dict = torch.load(args.model_restore)['state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485],
                             std=[0.225, 0.224, 0.229])
    ])
    dataset = SimpleFolderDataset(root=args.input_dir,
                                  input_size=input_size,
                                  transform=transform)
    dataloader = DataLoader(dataset)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    palette = get_palette(num_classes)
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(dataloader)):
            image, meta = batch
            img_name = meta['name'][0]
            img_path = meta['img_path'][0]
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            output = model(image.cuda())
            upsample = torch.nn.Upsample(size=input_size,
                                         mode='bilinear',
                                         align_corners=True)
            upsample_output = upsample(output[0][-1][0].unsqueeze(0))
            upsample_output = upsample_output.squeeze()
            upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC

            logits_result = transform_logits(
                upsample_output.data.cpu().numpy(),
                c,
                s,
                w,
                h,
                input_size=input_size)
            parsing_result = np.argmax(logits_result, axis=2)

            output_img = Image.fromarray(
                np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(palette)
            png_path = os.path.join(args.output_dir, img_name[:-4] + '.png')
            output_img.save(png_path)
            # 'lip': {
            #     'input_size': [473, 473],
            #     'num_classes': 20,
            #     'label': ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat',
            #               'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm',
            #               'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe']
            # },

            # Binary mask: keep classes 5 and up except 13 ('Face'), i.e. drop
            # background, hat, hair, glove, sunglasses, and the face region
            # (see the commented LIP label list above).
            parsing_result = (parsing_result >= 5) & (parsing_result != 13)
            parsing_result = parsing_result.astype(int) * 255

            org_img = Image.open(img_path)
            f = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            org_img.putalpha(f)
            org_img = np.array(org_img)

            # https://stackoverflow.com/a/55973647/1513627
            # Alpha -> Green
            org_img[org_img[..., -1] == 0] = [0, 255, 0, 0]
            jpg_path = os.path.join(args.output_dir, img_name[:-4] + '.jpg')
            Image.fromarray(org_img).convert('RGB').save(jpg_path)

            if args.logits:
                logits_result_path = os.path.join(args.output_dir,
                                                  img_name[:-4] + '.npy')
                np.save(logits_result_path, logits_result)
    return
Code Example #9
def main():
    args = get_arguments()
    print(args)

    start_epoch = 0
    cycle_n = 0

    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    with open(os.path.join(args.log_dir, 'args.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file)

    gpus = [int(i) for i in args.gpu.split(',')]
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    input_size = list(map(int, args.input_size.split(',')))

    cudnn.enabled = True
    cudnn.benchmark = True

    # Model Initialization
    AugmentCE2P = networks.init_model(args.arch,
                                      num_classes=args.num_classes,
                                      pretrained=args.imagenet_pretrain)
    model = DataParallelModel(AugmentCE2P)
    model.cuda()

    IMAGE_MEAN = AugmentCE2P.mean
    IMAGE_STD = AugmentCE2P.std
    INPUT_SPACE = AugmentCE2P.input_space
    print('image mean: {}'.format(IMAGE_MEAN))
    print('image std: {}'.format(IMAGE_STD))
    print('input space: {}'.format(INPUT_SPACE))

    restore_from = args.model_restore
    if os.path.exists(restore_from):
        print('Resume training from {}'.format(restore_from))
        checkpoint = torch.load(restore_from)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']

    SCHP_AugmentCE2P = networks.init_model(args.arch,
                                           num_classes=args.num_classes,
                                           pretrained=args.imagenet_pretrain)
    schp_model = DataParallelModel(SCHP_AugmentCE2P)
    schp_model.cuda()

    if os.path.exists(args.schp_restore):
        print('Resuming schp checkpoint from {}'.format(args.schp_restore))
        schp_checkpoint = torch.load(args.schp_restore)
        schp_model_state_dict = schp_checkpoint['state_dict']
        cycle_n = schp_checkpoint['cycle_n']
        schp_model.load_state_dict(schp_model_state_dict)

    # Loss Function
    criterion = CriterionAll(lambda_1=args.lambda_s,
                             lambda_2=args.lambda_e,
                             lambda_3=args.lambda_c,
                             num_classes=args.num_classes)
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    # Data Loader
    if INPUT_SPACE == 'BGR':
        print('BGR Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD),
        ])

    elif INPUT_SPACE == 'RGB':
        print('RGB Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            BGR2RGB_transform(),
            transforms.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD),
        ])

    train_dataset = LIPDataSet(args.data_dir,
                               args.split_name,
                               crop_size=input_size,
                               transform=transform)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size * len(gpus),
                                   num_workers=16,
                                   shuffle=True,
                                   pin_memory=True,
                                   drop_last=True)
    print('Total training samples: {}'.format(len(train_dataset)))

    # Optimizer Initialization
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    lr_scheduler = SGDRScheduler(optimizer,
                                 total_epoch=args.epochs,
                                 eta_min=args.learning_rate / 100,
                                 warmup_epoch=10,
                                 start_cyclical=args.schp_start,
                                 cyclical_base_lr=args.learning_rate / 2,
                                 cyclical_epoch=args.cycle_epochs)

    total_iters = args.epochs * len(train_loader)
    start = timeit.default_timer()
    for epoch in range(start_epoch, args.epochs):
        lr_scheduler.step(epoch=epoch)
        lr = lr_scheduler.get_lr()[0]

        model.train()
        for i_iter, batch in enumerate(train_loader):
            i_iter += len(train_loader) * epoch

            images, labels, _ = batch
            labels = labels.cuda(non_blocking=True)

            edges = generate_edge_tensor(labels)
            labels = labels.type(torch.cuda.LongTensor)
            edges = edges.type(torch.cuda.LongTensor)

            preds = model(images)

            # Online Self Correction Cycle with Label Refinement
            if cycle_n >= 1:
                with torch.no_grad():
                    soft_preds = schp_model(images)
                    soft_parsing = []
                    soft_edge = []
                    for soft_pred in soft_preds:
                        soft_parsing.append(soft_pred[0][-1])
                        soft_edge.append(soft_pred[1][-1])
                    soft_preds = torch.cat(soft_parsing, dim=0)
                    soft_edges = torch.cat(soft_edge, dim=0)
            else:
                soft_preds = None
                soft_edges = None

            loss = criterion(preds, [labels, edges, soft_preds, soft_edges],
                             cycle_n)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_iter % 100 == 0:
                print('iter = {} of {} completed, lr = {}, loss = {}'.format(
                    i_iter, total_iters, lr,
                    loss.data.cpu().numpy()))
        if (epoch + 1) % (args.eval_epochs) == 0:
            schp.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                },
                False,
                args.log_dir,
                filename='checkpoint_{}.pth.tar'.format(epoch + 1))

        # Self Correction Cycle with Model Aggregation
        if (epoch + 1) >= args.schp_start and (
                epoch + 1 - args.schp_start) % args.cycle_epochs == 0:
            print('Self-correction cycle number {}'.format(cycle_n))
            schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
            cycle_n += 1
            schp.bn_re_estimate(train_loader, schp_model)
            schp.save_schp_checkpoint(
                {
                    'state_dict': schp_model.state_dict(),
                    'cycle_n': cycle_n,
                },
                False,
                args.log_dir,
                filename='schp_{}_checkpoint.pth.tar'.format(cycle_n))

        torch.cuda.empty_cache()
        end = timeit.default_timer()
        print('epoch = {} of {} completed, {:.2f} s per epoch on average'.format(
            epoch, args.epochs, (end - start) / (epoch - start_epoch + 1)))

    end = timeit.default_timer()
    print('Training Finished in {} seconds'.format(end - start))
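schp.moving_average and schp.bn_re_estimate live in the repository's utilities. With alpha = 1 / (cycle_n + 1) the aggregation amounts to a running mean over the cycle models; a plausible sketch of that update (an assumption about the helper, not its verbatim source):

def moving_average(avg_model, new_model, alpha):
    # avg <- (1 - alpha) * avg + alpha * new. With alpha = 1 / (n + 1) this is
    # the incremental running mean over the first n + 1 cycle models.
    for p_avg, p_new in zip(avg_model.parameters(), new_model.parameters()):
        p_avg.data.mul_(1.0 - alpha).add_(p_new.data, alpha=alpha)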