Пример #1
0
def main():
    """Optimise an FFT-parameterised image so its CLIP encoding matches the inputs.

    All options come from ``get_args()``. Optional targets are:

    * a reference text (``a.in_txt``);
    * a "micro" detail text (``a.in_txt2``);
    * a text to steer away from (``a.in_txt0``);
    * a reference image (``a.in_img``).

    A frame is written to ``<out_dir>/<out_name>/NNNN.jpg`` every ``a.fstep``
    steps. At the end, ffmpeg assembles the frames into
    ``<out_dir>/<out_name>.mp4`` and the last frame is copied to
    ``<out_dir>/<out_name>-<steps>.jpg``. With ``a.save_pt`` the optimised
    parameters are also saved to ``<out_dir>/<out_name>.pt``.
    """
    import subprocess  # local import: the file's import block is outside this view

    a = get_args()

    # CLIP encoding of the previous step's image. train() reads and writes it
    # through `nonlocal` when a.expand > 0.
    prev_enc = 0

    def train(i):
        """Run optimisation step ``i``; save a frame every ``a.fstep`` steps.

        Relies on closure variables (params, image_f, optimizer, the text and
        image encodings, ...) that main() binds after this definition but
        before the first call.
        """
        # `nonlocal` rather than `global`: the state lives in main() above,
        # not at module level.
        nonlocal prev_enc
        loss = 0

        noise = a.noise * torch.randn(1, 1, *params[0].shape[2:4],
                                      1).cuda() if a.noise > 0 else None
        img_out = image_f(noise)

        if a.sharp != 0:
            # Mean absolute horizontal/vertical neighbour difference.
            # Subtracting it from the loss rewards high-frequency detail.
            lx = torch.mean(
                torch.abs(img_out[0, :, :, 1:] - img_out[0, :, :, :-1]))
            ly = torch.mean(
                torch.abs(img_out[0, :, 1:, :] - img_out[0, :, :-1, :]))
            loss -= a.sharp * (ly + lx)

        micro = 1 - a.macro if a.in_txt2 is None else False
        imgs_sliced = slice_imgs([img_out],
                                 a.samples,
                                 a.modsize,
                                 trform_f,
                                 a.align,
                                 micro=micro)
        out_enc = model_clip.encode_image(imgs_sliced[-1])
        if a.diverse != 0:
            # Encode a second render and add the similarity between the two
            # encodings (weighted by a.diverse) to the loss.
            imgs_sliced = slice_imgs([image_f(noise)],
                                     a.samples,
                                     a.modsize,
                                     trform_f,
                                     a.align,
                                     micro=micro)
            out_enc2 = model_clip.encode_image(imgs_sliced[-1])
            loss += a.diverse * torch.cosine_similarity(
                out_enc, out_enc2, dim=-1).mean()
            del out_enc2
            torch.cuda.empty_cache()
        # `sign` is -1 unless a.invert, so minimising the loss maximises the
        # similarity to the targets below.
        if a.in_img is not None and os.path.isfile(a.in_img):  # input image
            loss += sign * 0.5 * torch.cosine_similarity(
                img_enc, out_enc, dim=-1).mean()
        if a.in_txt is not None:  # input text
            loss += sign * torch.cosine_similarity(txt_enc, out_enc,
                                                   dim=-1).mean()
            if a.notext > 0:
                # Push away from an image of the rendered prompt text itself.
                loss -= sign * a.notext * torch.cosine_similarity(
                    txt_plot_enc, out_enc, dim=-1).mean()
        if a.in_txt0 is not None:  # subtract text
            loss += -sign * torch.cosine_similarity(txt_enc0, out_enc,
                                                    dim=-1).mean()
        if a.sync > 0 and a.in_img is not None and os.path.isfile(
                a.in_img):  # image composition
            # The LPIPS weight decays linearly from 1 towards 0 over the run.
            # `i` counts optimisation steps, so normalise by a.steps rather
            # than by the frame count (a.steps // a.fstep). The frame count
            # would drive the weight negative whenever a.fstep > 1.
            prog_sync = (a.steps - i) / a.steps
            loss += prog_sync * a.sync * sim_loss(F.interpolate(
                img_out, sim_size).float(),
                                                  img_in,
                                                  normalize=True).squeeze()
        if a.in_txt2 is not None:  # input text for micro details
            imgs_sliced = slice_imgs([img_out],
                                     a.samples,
                                     a.modsize,
                                     trform_f,
                                     a.align,
                                     micro=True)
            out_enc2 = model_clip.encode_image(imgs_sliced[-1])
            loss += sign * torch.cosine_similarity(txt_enc2, out_enc2,
                                                   dim=-1).mean()
            del out_enc2
            torch.cuda.empty_cache()
        if a.expand > 0:
            # Add the similarity to the previous step's encoding (weighted
            # by a.expand) to the loss.
            if i > 0:
                loss += a.expand * torch.cosine_similarity(
                    out_enc, prev_enc, dim=-1).mean()
            prev_enc = out_enc.detach()

        del img_out, imgs_sliced, out_enc
        torch.cuda.empty_cache()
        assert not isinstance(loss, int), ' Loss not defined, check the inputs'

        if a.prog is True:
            # Linear learning-rate ramp from lr0 to lr1 over the run.
            lr_cur = lr0 + (i / a.steps) * (lr1 - lr0)
            for g in optimizer.param_groups:
                g['lr'] = lr_cur

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % a.fstep == 0:
            with torch.no_grad():
                img = image_f(contrast=a.contrast).cpu().numpy()[0]
            if (a.sync > 0 and a.in_img is not None) or a.sharp != 0:
                img = img**1.3  # empirical tone mapping
            checkout(img,
                     os.path.join(tempdir, '%04d.jpg' % (i // a.fstep)),
                     verbose=a.verbose)
            pbar.upd()

    # Load CLIP models
    model_clip, _ = clip.load(a.model)
    if a.verbose is True: print(' using model', a.model)
    # Sample-count multipliers for the ResNet CLIP variants.
    xmem = {'RN50': 0.5, 'RN50x4': 0.16, 'RN101': 0.33}
    if 'RN' in a.model:
        a.samples = int(a.samples * xmem[a.model])

    if a.multilang is True:
        model_lang = SentenceTransformer(
            'clip-ViT-B-32-multilingual-v1').cuda()

    def enc_text(txt):
        """Return a detached text embedding (multilingual model or CLIP)."""
        if a.multilang is True:
            emb = model_lang.encode([txt],
                                    convert_to_tensor=True,
                                    show_progress_bar=False)
        else:
            emb = model_clip.encode_text(clip.tokenize(txt).cuda())
        return emb.detach().clone()

    # Each extra loss term encodes more slices per step, so shrink the
    # per-step sample count to compensate.
    if a.diverse != 0:
        a.samples = int(a.samples * 0.5)
    if a.sync > 0:
        a.samples = int(a.samples * 0.5)

    if a.transform is True:
        trform_f = transforms.transforms_custom
        a.samples = int(a.samples * 0.95)
    else:
        trform_f = transforms.normalize()

    out_name = []
    if a.in_txt is not None:
        if a.verbose is True: print(' ref text: ', basename(a.in_txt))
        if a.translate:
            translator = Translator()
            a.in_txt = translator.translate(a.in_txt, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt)
        txt_enc = enc_text(a.in_txt)
        out_name.append(txt_clean(a.in_txt))

        if a.notext > 0:
            txt_plot = torch.from_numpy(plot_text(a.in_txt, a.modsize) /
                                        255.).unsqueeze(0).permute(0, 3, 1,
                                                                   2).cuda()
            txt_plot_enc = model_clip.encode_image(txt_plot).detach().clone()

    if a.in_txt2 is not None:
        if a.verbose is True: print(' micro text:', basename(a.in_txt2))
        a.samples = int(a.samples * 0.75)
        if a.translate:
            translator = Translator()
            a.in_txt2 = translator.translate(a.in_txt2, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt2)
        txt_enc2 = enc_text(a.in_txt2)
        out_name.append(txt_clean(a.in_txt2))

    if a.in_txt0 is not None:
        if a.verbose is True: print(' subtract text:', basename(a.in_txt0))
        a.samples = int(a.samples * 0.75)
        if a.translate:
            translator = Translator()
            a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt0)
        txt_enc0 = enc_text(a.in_txt0)
        out_name.append('off-' + txt_clean(a.in_txt0))

    if a.multilang is True: del model_lang

    if a.in_img is not None and os.path.isfile(a.in_img):
        if a.verbose is True: print(' ref image:', basename(a.in_img))
        img_in = torch.from_numpy(
            img_read(a.in_img) / 255.).unsqueeze(0).permute(0, 3, 1, 2).cuda()
        img_in = img_in[:, :3, :, :]  # fix rgb channels
        in_sliced = slice_imgs([img_in],
                               a.samples,
                               a.modsize,
                               transforms.normalize(),
                               a.align,
                               micro=False)[0]
        img_enc = model_clip.encode_image(in_sliced).detach().clone()
        if a.sync > 0:
            # Keep a half-size copy of the reference image for the LPIPS term.
            sim_loss = lpips.LPIPS(net='vgg', verbose=False).cuda()
            sim_size = [s // 2 for s in a.size]
            img_in = F.interpolate(img_in, sim_size).float()
        else:
            del img_in
        del in_sliced
        torch.cuda.empty_cache()
        out_name.append(basename(a.in_img).replace(' ', '_'))

    params, image_f = fft_image([1, 3, *a.size],
                                resume=a.resume,
                                decay_power=a.decay)
    image_f = to_valid_rgb(image_f, colors=a.colors)

    if a.prog is True:
        lr1 = a.lrate * 2
        lr0 = lr1 * 0.01
    else:
        lr0 = a.lrate
    optimizer = torch.optim.Adam(params, lr0)
    sign = 1. if a.invert is True else -1.

    if a.verbose is True: print(' samples:', a.samples)
    out_name = '-'.join(out_name)
    out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
    tempdir = os.path.join(a.out_dir, out_name)
    os.makedirs(tempdir, exist_ok=True)

    pbar = ProgressBar(a.steps // a.fstep)
    for i in range(a.steps):
        train(i)

    # Assemble the frames into a video. os.path.join gives the right path
    # separator on every OS. An argument list (no shell) means paths
    # containing spaces or quotes need no escaping.
    frames_pattern = os.path.join(tempdir, '%04d.jpg')
    video_path = os.path.join(a.out_dir, out_name) + '.mp4'
    try:
        subprocess.run(['ffmpeg', '-v', 'warning', '-y', '-i', frames_pattern,
                        video_path])
    except OSError as err:
        # Best-effort, as before: e.g. ffmpeg not installed.
        print(' ffmpeg failed:', err)
    shutil.copy(
        img_list(tempdir)[-1],
        os.path.join(a.out_dir, '%s-%d.jpg' % (out_name, a.steps)))
    if a.save_pt is True:
        torch.save(params, '%s.pt' % os.path.join(a.out_dir, out_name))
Пример #2
0
    def process(txt, num):
        """Optimise one FFT-parameterised image towards the prompt ``txt``.

        ``num`` is the prompt index; output names use ``num + 1``, zero-padded
        to three digits. Optimisation resumes from ``init.pt``.

        Writes:

        * frames to ``<workdir>/<NNN-text>/NNNN.jpg``;
        * the final parameters to ``<workdir>/<NNN-text>.pt``;
        * the last frame to ``<workdir>/<NNN-text>-<steps>.jpg``;
        * an mp4 built from the frames with ffmpeg.

        With ``a.keep > 0`` it also updates ``params_ema`` via ``ema()`` and
        rewrites ``init.pt`` as
        ``(1 - a.keep) * params_start + a.keep * params_ema``.
        """
        import subprocess  # local import: the file's import block is outside this view
        # Module-level state shared across prompts.
        global prev_enc, params_start, params_ema

        sd = 0.01
        if a.keep > 0: sd = a.keep + (1-a.keep) * sd
        params, image_f = fft_image([1, 3, *a.size], resume='init.pt', sd=sd, decay_power=a.decay)
        image_f = to_valid_rgb(image_f, colors = a.colors)

        if a.prog is True:
            lr1 = a.lrate * 2
            lr0 = a.lrate * 0.1
        else:
            lr0 = a.lrate
        optimizer = torch.optim.Adam(params, lr0)

        if a.verbose is True: print(' ref text: ', txt)
        if a.translate:
            translator = Translator()
            txt = translator.translate(txt, dest='en').text
            if a.verbose is True: print(' translated to:', txt)
        if a.multilang is True:
            model_lang = SentenceTransformer('clip-ViT-B-32-multilingual-v1').cuda()
            txt_enc = model_lang.encode([txt], convert_to_tensor=True, show_progress_bar=False).detach().clone()
            del model_lang
        else:
            txt_enc = model_clip.encode_text(clip.tokenize(txt).cuda()).detach().clone()
        if a.notext > 0:
            # Encoding of an image of the prompt text itself, used below to
            # penalise literally drawing the words.
            txt_plot = torch.from_numpy(plot_text(txt, a.modsize)/255.).unsqueeze(0).permute(0,3,1,2).cuda()
            txt_plot_enc = model_clip.encode_image(txt_plot).detach().clone()
        else: txt_plot_enc = None

        out_name = '%03d-%s' % (num+1, txt_clean(txt))
        out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
        tempdir = os.path.join(workdir, out_name)
        os.makedirs(tempdir, exist_ok=True)

        pbar = ProgressBar(a.steps // a.fstep)
        for i in range(a.steps):
            loss = 0

            noise = a.noise * torch.randn(1, 1, *params[0].shape[2:4], 1).cuda() if a.noise > 0 else None
            img_out = image_f(noise)

            if a.sharp != 0:
                # Reward high-frequency detail (mean absolute neighbour difference).
                lx = torch.mean(torch.abs(img_out[0,:,:,1:] - img_out[0,:,:,:-1]))
                ly = torch.mean(torch.abs(img_out[0,:,1:,:] - img_out[0,:,:-1,:]))
                loss -= a.sharp * (ly+lx)

            imgs_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, micro=1.)
            out_enc = model_clip.encode_image(imgs_sliced[-1])
            loss -= torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()
            if a.notext > 0:
                loss += a.notext * torch.cosine_similarity(txt_plot_enc, out_enc, dim=-1).mean()
            if a.diverse != 0:
                # Add the similarity between two renders' encodings (weighted by a.diverse).
                imgs_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize, trform_f, a.align, micro=1.)
                out_enc2 = model_clip.encode_image(imgs_sliced[-1])
                loss += a.diverse * torch.cosine_similarity(out_enc, out_enc2, dim=-1).mean()
                del out_enc2; torch.cuda.empty_cache()
            if a.expand > 0:
                # Add the similarity to the previous step's encoding (weighted by a.expand).
                if i > 0:
                    loss += a.expand * torch.cosine_similarity(out_enc, prev_enc, dim=-1).mean()
                prev_enc = out_enc.detach().clone()
            if a.in_txt0 is not None: # subtract text
                loss += torch.cosine_similarity(txt_enc0, out_enc, dim=-1).mean()
            del img_out, imgs_sliced, out_enc; torch.cuda.empty_cache()

            if a.prog is True:
                # Linear learning-rate ramp from lr0 to lr1 over the run.
                lr_cur = lr0 + (i / a.steps) * (lr1 - lr0)
                for g in optimizer.param_groups:
                    g['lr'] = lr_cur

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % a.fstep == 0:
                with torch.no_grad():
                    img = image_f(contrast=a.contrast).cpu().numpy()[0]
                if a.sharp != 0:
                    img = img **1.3 # empirical tone mapping
                checkout(img, os.path.join(tempdir, '%04d.jpg' % (i // a.fstep)), verbose=a.verbose)
                pbar.upd()
                del img

        if a.keep > 0:
            params_ema = ema(params_ema, params[0].detach().clone(), num+1)
            torch.save((1-a.keep) * params_start + a.keep * params_ema, 'init.pt')

        torch.save(params[0], '%s.pt' % os.path.join(workdir, out_name))
        shutil.copy(img_list(tempdir)[-1], os.path.join(workdir, '%s-%d.jpg' % (out_name, a.steps)))
        # Portable frame pattern (the hard-coded backslash only worked on
        # Windows). Pass an argument list so no shell quoting is needed.
        frames_pattern = os.path.join(tempdir, '%04d.jpg')
        video_path = os.path.join(workdir, out_name) + '.mp4'
        try:
            subprocess.run(['ffmpeg', '-v', 'warning', '-y', '-i', frames_pattern, video_path])
        except OSError as err:
            # Best-effort, as before: e.g. ffmpeg not installed.
            print(' ffmpeg failed:', err)
Пример #3
0
def detect(save_img=True):
    """Run Darknet/YOLO detection with DeepSORT tracking on ``opt.source``.

    The output folder ``opt.output`` is deleted and recreated. Images and
    videos from a file source are saved there. Webcam/stream and 'fake'
    sources are only shown in a window, and pressing 'q' there stops
    processing.

    Camera control: a counter ``valid`` rises (capped at ``max_valid``) on
    frames with detections and falls on frames without. When it reaches
    ``max_valid``, ``cameraLocate(*FC.pix2ang(cx, cy))`` is called with a box
    centre. When it drops to 0 and the horizontal rotate speed is 0,
    ``cameraAutoRotate(10, 0)`` is called again.

    Args:
        save_img: initial value only; it is overwritten from the source type.
    """
    valid = 0
    max_valid = 30
    cameraAutoRotate(10, 0)

    img_size = (
        320, 192
    ) if ONNX_EXPORT else opt.img_size  # (320, 192) or (416, 256) or (608, 352) for (height, width)
    out, source, weights, half, view_img, save_txt = opt.output, opt.source, opt.weights, opt.half, opt.view_img, opt.save_txt
    webcam = source == '0' or source.startswith(('rtsp', 'http')) or source.endswith('.txt')

    # Initialize
    device = torch_utils.select_device(
        device='cpu' if ONNX_EXPORT else opt.device)
    if os.path.exists(out):
        shutil.rmtree(out)  # delete output folder
    os.makedirs(out)  # make new output folder

    # Initialize model
    model = Darknet(opt.cfg, img_size)

    # Load weights
    attempt_download(weights)
    if weights.endswith('.pt'):  # pytorch format
        model.load_state_dict(
            torch.load(weights, map_location=device)['model'])
    else:  # darknet format
        load_darknet_weights(model, weights)

    # Eval mode
    model.to(device).eval()

    # Half precision
    half = half and device.type != 'cpu'  # half precision only supported on CUDA
    if half:
        model.half()

    # Set Dataloader
    vid_path, vid_writer = None, None
    save_path = None  # stays None if the dataset yields nothing
    if source == 'fake':
        save_img = False
        view_img = True
        dataset = LoadFake(img_size=img_size, half=half)
    elif webcam:
        save_img = False
        view_img = True
        torch.backends.cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=img_size, half=half)
    else:
        save_img = True
        dataset = LoadImages(source, img_size=img_size, half=half)

    # Get names and colors
    names = load_classes(opt.names)
    colors = [[random.randint(0, 255) for _ in range(3)]
              for _ in range(len(names))]

    # Run inference
    quit_requested = False  # set when the user presses 'q' in the preview window
    t0 = time.time()
    for path, img, im0s, vid_cap in dataset:
        t = time.time()

        # Get detections
        img = torch.from_numpy(img).to(device)
        if img.ndimension() == 3:
            img = img.unsqueeze(0)
        pred = model(img)[0]

        if half:  # the model ran in fp16; NMS works on float32
            pred = pred.float()

        # Apply NMS
        pred = non_max_suppression(pred,
                                   opt.conf_thres,
                                   opt.iou_thres,
                                   classes=opt.classes,
                                   agnostic=opt.agnostic_nms)

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0 = path[i], '%g: ' % i, im0s[i]
            else:
                p, s, im0 = path, '', im0s

            save_path = str(Path(out) / Path(p).name)
            s += '%gx%g ' % img.shape[2:]  # print string
            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4],
                                          im0.shape).round()
                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += '%g %ss, ' % (n, names[int(c)])  # add to string

                # Collect every detection of this frame in the form bbox_rel
                # returns, plus its confidence.
                img_h, img_w, _ = im0.shape
                bbox_xywh = []
                confs = []
                for *xyxy, conf, cls in det:
                    bbox_left = min(xyxy[0].item(), xyxy[2].item())
                    bbox_top = min(xyxy[1].item(), xyxy[3].item())
                    bbox_w = abs(xyxy[0].item() - xyxy[2].item())
                    bbox_h = abs(xyxy[1].item() - xyxy[3].item())
                    x_c, y_c, bbox_w, bbox_h = bbox_rel(
                        img_w, img_h, bbox_left, bbox_top, bbox_w, bbox_h)
                    bbox_xywh.append([x_c, y_c, bbox_w, bbox_h])
                    confs.append([conf.item()])

                # Step the tracker exactly once per frame with all detections.
                # Inside the loop above it would advance N times per frame,
                # each time with a partial detection list.
                outputs = deepsort.update(torch.Tensor(bbox_xywh),
                                          torch.Tensor(confs), im0)
                if len(outputs) > 0:
                    bbox_xyxy = outputs[:, :4]
                    identities = outputs[:, -1]
                    draw_boxes(im0, bbox_xyxy, identities)

                valid = min(valid + 1, max_valid)
                if valid == max_valid:
                    # NOTE(review): `xyxy` is the last row of `det` after the
                    # loop above, not necessarily the most confident or the
                    # tracked target. Confirm this is the intended aim point.
                    cx = (int(xyxy[0].item()) + int(xyxy[2].item())) // 2
                    cy = (int(xyxy[1].item()) + int(xyxy[3].item())) // 2
                    cameraLocate(*FC.pix2ang(cx, cy))
                else:
                    cx = cy = -1
            else:
                valid = max(valid - 1, 0)
                cx = cy = -1
                if valid == 0 and getHorizontalRotateSpeed() == 0:
                    cameraAutoRotate(10, 0)

            # Print time (inference + NMS)
            # print('%sDone. (%.3fs)' % (s, time.time() - t))

            # Stream results
            if view_img:
                plot_text((0, 20),
                          im0,
                          text=f"valid {valid}/{max_valid}",
                          color=(255, 0, 0))
                cv2.imshow(p, im0)
                if cv2.waitKey(1) == ord('q'):  # q to quit
                    # Leave both loops normally; raising StopIteration from a
                    # plain function just escapes as an uncaught exception.
                    quit_requested = True
                    break

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'images':
                    cv2.imwrite(save_path, im0)
                else:
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer

                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        vid_writer = cv2.VideoWriter(
                            save_path, cv2.VideoWriter_fourcc(*opt.fourcc),
                            fps, (w, h))
                    vid_writer.write(im0)

        if quit_requested:
            break

    # Close the last video file; otherwise it may be left unfinalised.
    if isinstance(vid_writer, cv2.VideoWriter):
        vid_writer.release()

    if save_txt or save_img:
        print('Results saved to %s' % os.getcwd() + os.sep + out)
        if platform == 'darwin' and save_path is not None:  # MacOS
            os.system('open ' + out + ' ' + save_path)

    print('Done. (%.3fs)' % (time.time() - t0))