Example #1
def main(args):
    # Image preprocessing
    transform = transforms.Compose([transforms.ToTensor()])

    # Load vocabulary wrapper

    # Build the models
    #CUDA = torch.cuda.is_available()

    num_classes = 80
    yolov3 = Darknet(args.cfg_file)
    yolov3.load_weights(args.weights_file)
    yolov3.net_info["height"] = args.reso
    inp_dim = int(yolov3.net_info["height"])
    assert inp_dim % 32 == 0
    assert inp_dim > 32
    print("yolo-v3 network successfully loaded")

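    # Number of classes for each of the 13 clothing-attribute heads
    # (inferred from how the encoder outputs are consumed below).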
    attribute_size = [15, 7, 3, 5, 8, 4, 15, 7, 3, 5, 3, 3, 4]

    encoder = EncoderClothing(args.embed_size, device, args.roi_size,
                              attribute_size)

    # Prepare an image
    images = "test"

    try:
        list_dir = os.listdir(images)
        #   list_dir.sort(key=lambda x: int(x[:-4]))
        imlist = [
            osp.join(osp.realpath('.'), images, img) for img in list_dir
            if os.path.splitext(img)[1] in ('.jpg', '.JPG', '.png')
        ]
    except NotADirectoryError:
        imlist = []
        imlist.append(osp.join(osp.realpath('.'), images))
        print('Not a directory error')
    except FileNotFoundError:
        print("No file or directory with the name {}".format(images))
        exit()

    yolov3.to(device)
    encoder.to(device)

    yolov3.eval()
    encoder.eval()

    encoder.load_state_dict(
        torch.load(args.encoder_path, map_location=device))

    for inx, image in enumerate(imlist):

        #print(image)
        image, orig_img, im_dim = prep_image(image, inp_dim)
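        # Duplicate the (w, h) pair to (w, h, w, h) so it matches the
        # (x1, y1, x2, y2) box layout.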
        im_dim = torch.FloatTensor(im_dim).repeat(1, 2)

        image_tensor = image.to(device)
        im_dim = im_dim.to(device)

        # Generate a caption from the image
        detections = yolov3(image_tensor, device,
                            True)  # prediction mode for yolo-v3
        detections = write_results(detections,
                                   args.confidence,
                                   num_classes,
                                   device,
                                   nms=True,
                                   nms_conf=args.nms_thresh)
        # original image dimension --> im_dim
        #view_image(detections)

        os.system('clear')
        if not isinstance(detections, int):
            if detections.shape[0]:
                bboxs = detections[:, 1:5].clone()
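                # Undo the letterbox resize: shift and rescale the detections
                # from the inp_dim x inp_dim network input back to the
                # original image coordinates.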
                im_dim = im_dim.repeat(detections.shape[0], 1)
                scaling_factor = torch.min(inp_dim / im_dim, 1)[0].view(-1, 1)

                detections[:, [1, 3]] -= (
                    inp_dim - scaling_factor * im_dim[:, 0].view(-1, 1)) / 2
                detections[:, [2, 4]] -= (
                    inp_dim - scaling_factor * im_dim[:, 1].view(-1, 1)) / 2

                detections[:, 1:5] /= scaling_factor

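                # Clamp boxes to the image and compute each box's area
                # relative to the whole frame.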
                small_object_ratio = torch.FloatTensor(detections.shape[0])

                for i in range(detections.shape[0]):
                    detections[i,
                               [1, 3]] = torch.clamp(detections[i, [1, 3]],
                                                     0.0, im_dim[i, 0])
                    detections[i,
                               [2, 4]] = torch.clamp(detections[i, [2, 4]],
                                                     0.0, im_dim[i, 1])

                    object_area = (detections[i, 3] - detections[i, 1]) * (
                        detections[i, 4] - detections[i, 2])
                    orig_img_area = im_dim[i, 0] * im_dim[i, 1]
                    small_object_ratio[i] = object_area / orig_img_area

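                # Drop detections that cover less than 2% of the image.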
                detections = detections[small_object_ratio > 0.02]
                im_dim = im_dim[small_object_ratio > 0.02]

                if detections.size(0) > 0:
                    feature = yolov3.get_feature()
                    feature = feature.repeat(detections.size(0), 1, 1, 1)

                    #orig_img_dim = im_dim[:, 1:]
                    #orig_img_dim = orig_img_dim.repeat(1, 2)

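                    # Map the boxes into feature-map coordinates before
                    # RoIAlign; the stride between the network input and
                    # yolov3.get_feature() is assumed to be 16 here.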
                    scaling_val = 16

                    bboxs /= scaling_val
                    bboxs = bboxs.round()
                    bboxs_index = torch.arange(bboxs.size(0), dtype=torch.int)
                    bboxs_index = bboxs_index.to(device)
                    bboxs = bboxs.to(device)

                    roi_align = RoIAlign(args.roi_size,
                                         args.roi_size,
                                         transform_fpcoor=True).to(device)
                    roi_features = roi_align(feature, bboxs, bboxs_index)
                    #    print(roi_features)
                    #    print(roi_features.size())

                    #roi_features = roi_features.reshape(roi_features.size(0), -1)

                    #roi_align_feature = encoder(roi_features)

                    outputs = encoder(roi_features)
                    #attribute_size = [15, 7, 3, 5, 7, 4, 15, 7, 3, 5, 4, 3, 4]
                    #losses = [criteria[i](outputs[i], targets[i]) for i in range(len(attribute_size))]

                    for i in range(detections.shape[0]):

                        sampled_caption = []
                        #attr_fc = outputs[]
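                        # Take the argmax of every attribute head and map it
                        # to the corresponding word.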
                        for j in range(len(outputs)):
                            #temp = outputs[j][i].data
                            max_index = torch.max(outputs[j][i].data, 0)[1]
                            word = attribute_pool[j][max_index]
                            sampled_caption.append(word)

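                        # Swap the words at indices 10 and 11, presumably so
                        # the caption reads in the intended order.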
                        c11 = sampled_caption[11]
                        sampled_caption[11] = sampled_caption[10]
                        sampled_caption[10] = c11

                        sentence = ' '.join(sampled_caption)

                        # again sampling for testing
                        #print ('---------------------------')
                        print(str(i + 1) + ': ' + sentence)
                        write(detections[i], orig_img, sentence, i + 1,
                              coco_classes, colors)
                        #list(map(lambda x: write(x, orig_img, captions), detections[i].unsqueeze(0)))

        cv2.imshow("frame", orig_img)
        key = cv2.waitKey(0)
        os.system('clear')
        if key & 0xFF == ord('q'):
            break
Example #2
def main(args):
    # Image preprocessing
    transform = transforms.Compose([transforms.ToTensor()])

    num_classes = 80
    yolov3 = Darknet(args.cfg_file)
    yolov3.load_weights(args.weights_file)
    yolov3.net_info["height"] = args.reso
    inp_dim = int(yolov3.net_info["height"])
    assert inp_dim % 32 == 0
    assert inp_dim > 32
    print("yolo-v3 network successfully loaded")

    attribute_size = [15, 7, 3, 5, 8, 4, 15, 7, 3, 5, 3, 3, 4]

    encoder = EncoderClothing(args.embed_size, device, args.roi_size,
                              attribute_size)

    yolov3.to(device)
    encoder.to(device)

    yolov3.eval()
    encoder.eval()

    encoder.load_state_dict(
        torch.load(args.encoder_path, map_location=device))

    #cap = cv2.VideoCapture('demo2.mp4')

    cap = cv2.VideoCapture(0)
    assert cap.isOpened(), 'Cannot capture source'

    frames = 0
    start = time.time()

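    # Accumulators for per-frame attribute statistics (not referenced again
    # in this example).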
    counter = Counter()
    color_stream = list()
    pattern_stream = list()
    gender_stream = list()
    season_stream = list()
    class_stream = list()
    sleeves_stream = list()

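    # Grab one frame up front to size and position the display windows.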
    ret, frame = cap.read()
    if ret:

        image, orig_img, dim = prep_image2(frame, inp_dim)
        im_dim = torch.FloatTensor(dim).repeat(1, 2)

        image_tensor = image.to(device)
        detections = yolov3(image_tensor, device, True)

        os.system('clear')
        cv2.imshow("frame", orig_img)
        cv2.moveWindow("frame", 50, 50)
        text_img = np.zeros((200, 1750, 3))
        cv2.imshow("text", text_img)
        cv2.moveWindow("text", 50, dim[1] + 110)

    while cap.isOpened():

        ret, frame = cap.read()
        #### ret, frame = ros_message_cam_image()
        if ret:

            image, orig_img, dim = prep_image2(frame, inp_dim)
            im_dim = torch.FloatTensor(dim).repeat(1, 2)

            image_tensor = image.to(device)
            im_dim = im_dim.to(device)

            # Generate a caption from the image
            detections = yolov3(image_tensor, device,
                                True)  # prediction mode for yolo-v3
            detections = write_results(detections,
                                       args.confidence,
                                       num_classes,
                                       device,
                                       nms=True,
                                       nms_conf=args.nms_thresh)

            #### detections = ros_message_rois()
            #### ros_rois --> [0,0, x1, y1, x2, y2]

            # original image dimension --> im_dim

            #view_image(detections)
            text_img = np.zeros((200, 1750, 3))

            if not isinstance(detections, int):
                if detections.shape[0]:
                    bboxs = detections[:, 1:5].clone()

                    im_dim = im_dim.repeat(detections.shape[0], 1)
                    scaling_factor = torch.min(inp_dim / im_dim,
                                               1)[0].view(-1, 1)

                    detections[:, [1, 3]] -= (inp_dim - scaling_factor *
                                              im_dim[:, 0].view(-1, 1)) / 2
                    detections[:, [2, 4]] -= (inp_dim - scaling_factor *
                                              im_dim[:, 1].view(-1, 1)) / 2

                    detections[:, 1:5] /= scaling_factor

                    small_object_ratio = torch.FloatTensor(detections.shape[0])

                    for i in range(detections.shape[0]):
                        detections[i, [1, 3]] = torch.clamp(
                            detections[i, [1, 3]], 0.0, im_dim[i, 0])
                        detections[i, [2, 4]] = torch.clamp(
                            detections[i, [2, 4]], 0.0, im_dim[i, 1])

                        object_area = (detections[i, 3] - detections[i, 1]) * (
                            detections[i, 4] - detections[i, 2])
                        orig_img_area = im_dim[i, 0] * im_dim[i, 1]
                        small_object_ratio[i] = object_area / orig_img_area

                    detections = detections[small_object_ratio > 0.05]
                    im_dim = im_dim[small_object_ratio > 0.05]

                    if detections.size(0) > 0:
                        feature = yolov3.get_feature()
                        feature = feature.repeat(detections.size(0), 1, 1, 1)

                        orig_img_dim = im_dim[:, 1:]
                        orig_img_dim = orig_img_dim.repeat(1, 2)

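                        # As in Example #1: map the boxes into feature-map
                        # coordinates (stride assumed to be 16).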
                        scaling_val = 16

                        bboxs /= scaling_val
                        bboxs = bboxs.round()
                        bboxs_index = torch.arange(bboxs.size(0),
                                                   dtype=torch.int)
                        bboxs_index = bboxs_index.to(device)
                        bboxs = bboxs.to(device)

                        roi_align = RoIAlign(args.roi_size,
                                             args.roi_size,
                                             transform_fpcoor=True).to(device)
                        roi_features = roi_align(feature, bboxs, bboxs_index)

                        outputs = encoder(roi_features)

                        for i in range(detections.shape[0]):

                            sampled_caption = []
                            #attr_fc = outputs[]
                            for j in range(len(outputs)):
                                max_index = torch.max(outputs[j][i].data, 0)[1]
                                word = attribute_pool[j][max_index]
                                sampled_caption.append(word)

                            c11 = sampled_caption[11]
                            sampled_caption[11] = sampled_caption[10]
                            sampled_caption[10] = c11

                            sentence = ' '.join(sampled_caption)

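                            # Clear the previous console line, then print the
                            # latest caption in place.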
                            sys.stdout.write(' ' * 100 + '\r')

                            sys.stdout.write(sentence + '             ' + '\r')
                            sys.stdout.flush()
                            write(detections[i], orig_img, sentence, i + 1,
                                  coco_classes, colors)

                            cv2.putText(text_img, sentence, (0, i * 40 + 35),
                                        cv2.FONT_HERSHEY_PLAIN, 2,
                                        [255, 255, 255], 1)

            cv2.imshow("frame", orig_img)
            cv2.imshow("text", text_img)

            key = cv2.waitKey(1)
            if key & 0xFF == ord('q'):
                break
            if key & 0xFF == ord('w'):
                wait(0)
            if key & 0xFF == ord('s'):
                continue
            frames += 1
            #print("FPS of the video is {:5.2f}".format( frames / (time.time() - start)))

        else:
            break
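
Both examples read their settings from an argparse-style `args` object (`cfg_file`, `weights_file`, `reso`, `embed_size`, `roi_size`, `encoder_path`, `confidence`, `nms_thresh`) and rely on imports that are not shown (`torch`, `torchvision.transforms`, `cv2`, `numpy`, `os`, `sys`, `time`, plus project modules such as `Darknet`, `EncoderClothing`, `RoIAlign`, `write_results`, `prep_image`/`prep_image2` and `write`). A minimal parser sketch follows; the flag names mirror the attributes used above, while the default values are illustrative placeholders rather than values taken from the original project.

import argparse


def build_arg_parser():
    # Hypothetical parser: flag names mirror the args.* attributes used in the
    # examples above; the defaults are placeholders, not the project's values.
    parser = argparse.ArgumentParser(
        description='YOLOv3 detection + clothing-attribute captioning demo')
    parser.add_argument('--cfg_file', type=str, default='cfg/yolov3.cfg')
    parser.add_argument('--weights_file', type=str, default='yolov3.weights')
    parser.add_argument('--encoder_path', type=str, default='encoder.ckpt')
    parser.add_argument('--reso', type=int, default=416,
                        help='network input size; must be a multiple of 32 and > 32')
    parser.add_argument('--embed_size', type=int, default=256)
    parser.add_argument('--roi_size', type=int, default=13)
    parser.add_argument('--confidence', type=float, default=0.5)
    parser.add_argument('--nms_thresh', type=float, default=0.4)
    return parser


if __name__ == '__main__':
    main(build_arg_parser().parse_args())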