Example #1
    def create_components_vecs(self, imgs, boxes, obj_to_img, objs, obj_vecs,
                               features):
        O = objs.size(0)
        box_vecs = obj_vecs
        mask_vecs = obj_vecs
        layout_noise = torch.randn((1, self.mask_noise_dim), dtype=mask_vecs.dtype, device=mask_vecs.device) \
            .repeat((O, 1)) \
            .view(O, self.mask_noise_dim)
        mask_vecs = torch.cat([mask_vecs, layout_noise], dim=1)

        # create encoding
        crops = None
        if features is None:
            crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size)
            obj_repr = self.repr_net(self.image_encoder(crops))
        else:
            # Used only at inference time
            # obj_repr = self.repr_net(mask_vecs)
            crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size)
            obj_repr = self.repr_net(self.image_encoder(crops))
            for ind, feature in enumerate(features):
                if feature is not None:
                    obj_repr[ind, :] = feature
        # create one-hot vector for label map
        one_hot_size = (O, self.num_objs)
        one_hot_obj = torch.zeros(one_hot_size,
                                  dtype=obj_repr.dtype,
                                  device=obj_repr.device)
        one_hot_obj = one_hot_obj.scatter_(1, objs.view(-1, 1).long(), 1.0)
        layout_vecs = torch.cat([one_hot_obj, obj_repr], dim=1)

        wrong_objs_rep = self.fake_pool.query(objs, obj_repr)
        wrong_layout_vecs = torch.cat([one_hot_obj, wrong_objs_rep], dim=1)
        return box_vecs, mask_vecs, layout_vecs, wrong_layout_vecs, obj_repr, crops
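Example #1 crops every object out of its source image with crop_bbox_batch and runs the crops through image_encoder/repr_net to get per-object appearance vectors. Below is a minimal sketch of the cropping call as it is used throughout these examples; the import path, the normalized [x0, y0, x1, y1] box format, and the returned shape are assumptions inferred from this usage, not confirmed by this page:

import torch
from scene_generation.bilinear import crop_bbox_batch  # assumed module path

imgs = torch.randn(2, 3, 64, 64)                 # a batch of 2 RGB images
boxes = torch.tensor([[0.00, 0.00, 0.50, 0.50],  # one normalized [x0, y0, x1, y1] box per object
                      [0.25, 0.25, 1.00, 1.00],
                      [0.10, 0.10, 0.90, 0.90]])
obj_to_img = torch.tensor([0, 0, 1])             # object i belongs to image obj_to_img[i]
crops = crop_bbox_batch(imgs, boxes, obj_to_img, 32)
# expected result: one fixed-size crop per object, i.e. crops.shape == (3, 3, 32, 32)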
Example #2
def main(opt):
    name = 'features'
    checkpoint = torch.load(opt.checkpoint)
    rep_size = checkpoint['model_kwargs']['rep_size']
    vocab = checkpoint['model_kwargs']['vocab']
    num_objs = len(vocab['object_to_idx'])
    model = build_model(opt, checkpoint)
    train, val, test = VG.splits(transform=transforms.Compose([
        transforms.Resize(opt.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]))
    train_loader, test_loader = VGDataLoader.splits(train, test, batch_size=opt.batch_size,
                                                    num_workers=opt.loader_num_workers,
                                                    num_gpus=1)
    loader = test_loader
    print(train.ind_to_classes)

    save_path = os.path.dirname(opt.checkpoint)

    ########### Encode features ###########
    counter = 0
    max_counter = 1000000000
    print('begin')
    with torch.no_grad():
        features = {}
        for label in range(num_objs):
            features[label] = np.zeros((0, rep_size))
        for i, data in enumerate(loader):
            if counter >= max_counter:
                break
            # (all_imgs, all_objs, all_boxes, all_masks, all_triples,
            #            all_obj_to_img, all_triple_to_img, all_attributes)
            imgs = data[0].cuda()
            objs = data[1]
            objs = [j.item() for j in objs]
            boxes = data[2].cuda()
            obj_to_img = data[5].cuda()
            crops = crop_bbox_batch(imgs, boxes, obj_to_img, opt.object_size)
            feat = model.repr_net(model.image_encoder(crops)).cpu()
            for ind, label in enumerate(objs):
                features[label] = np.append(features[label], feat[ind].view(1, -1), axis=0)
            counter += len(objs)

            # print('%d / %d images' % (i + 1, dataset_size))
        save_name = os.path.join(save_path, name + '.npy')
        np.save(save_name, features)

    ############## Clustering ###########
    print('begin clustering')
    load_name = os.path.join(save_path, name + '.npy')
    features = np.load(load_name, allow_pickle=True).item()
    cluster(features, num_objs, 100, save_path)
    cluster(features, num_objs, 10, save_path)
    cluster(features, num_objs, 1, save_path)
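The features dictionary built here maps each object label to a NumPy array of shape (count, rep_size) and is saved as a single .npy file. Because np.save stores a dict as a 0-d object array, loading it back needs allow_pickle=True followed by .item(). A small round-trip sketch with made-up sizes (the file name is illustrative):

import numpy as np

features = {0: np.zeros((0, 64)), 1: np.random.randn(5, 64)}   # label -> (count, rep_size)
np.save('features.npy', features)                               # stored as a 0-d object array
restored = np.load('features.npy', allow_pickle=True).item()    # .item() unwraps the dict again
assert restored[1].shape == (5, 64)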
Example #3
def main(opt):
    name = 'features'
    checkpoint = torch.load(opt.checkpoint)
    rep_size = checkpoint['model_kwargs']['rep_size']
    vocab = checkpoint['model_kwargs']['vocab']
    num_objs = len(vocab['object_to_idx'])
    model = build_model(opt, checkpoint)
    loader = build_loader(opt, checkpoint)

    save_path = os.path.dirname(opt.checkpoint)

    ########### Encode features ###########
    counter = 0
    max_counter = 1000000000
    print('begin')
    with torch.no_grad():
        features = {}
        for label in range(num_objs):
            features[label] = np.zeros((0, rep_size))
        for i, data in enumerate(loader):
            if counter >= max_counter:
                break
            imgs = data[0].cuda()
            objs = data[1]
            objs = [j.item() for j in objs]
            boxes = data[2].cuda()
            obj_to_img = data[5].cuda()
            crops = crop_bbox_batch(imgs, boxes, obj_to_img, opt.object_size)
            feat = model.repr_net(model.image_encoder(crops)).cpu()
            for ind, label in enumerate(objs):
                features[label] = np.append(features[label],
                                            feat[ind].view(1, -1),
                                            axis=0)
            counter += len(objs)

            # print('%d / %d images' % (i + 1, dataset_size))
        save_name = os.path.join(save_path, name + '.npy')
        np.save(save_name, features)

    ############## Clustering ###########
    print('begin clustering')
    load_name = os.path.join(save_path, name + '.npy')
    features = np.load(load_name, allow_pickle=True).item()
    cluster(features, num_objs, 100, save_path)
    cluster(features, num_objs, 10, save_path)
    cluster(features, num_objs, 1, save_path)
Example #4
def train_model(model,
                test_dataloader,
                val_dataloader,
                criterion,
                optimizer,
                scheduler,
                use_gpu,
                num_epochs=10,
                input_shape=224):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())  # deep copy (needs `import copy`) so later updates don't overwrite the snapshot
    best_acc = 0.0
    device = 'cuda' if use_gpu else 'cpu'

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
                dataloader = test_dataloader
            else:
                model.train(False)  # Set model to evaluate mode
                dataloader = val_dataloader

            running_loss = 0.0
            running_corrects = 0
            objects_len = 0

            # Iterate over data.
            for data in dataloader:
                # get the inputs
                imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = data
                imgs = imgs.to(device)
                boxes = boxes.to(device)
                obj_to_img = obj_to_img.to(device)
                labels = objs.to(device)

                objects_len += obj_to_img.shape[0]

                with torch.no_grad():
                    crops = crop_bbox_batch(imgs, boxes, obj_to_img,
                                            input_shape)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                outputs = model(crops)
                if type(outputs) == tuple:
                    outputs, _ = outputs
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # statistics
                running_loss += loss.item()
                running_corrects += torch.sum(preds == labels)

            epoch_loss = running_loss / objects_len
            epoch_acc = running_corrects.item() / objects_len

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss,
                                                       epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
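A rough sketch of how train_model might be driven, assuming a classifier over the object vocabulary and loaders that yield the VG-style 8-tuple batches unpacked inside the loop. resnet18, num_objs=172, train_loader and val_loader are illustrative choices, not part of the original code; note also that train_model uses its first loader argument for the training phase:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision

num_objs = 172                                                # matches the 172 logits written out in test_model below
model = torchvision.models.resnet18(num_classes=num_objs)    # any classifier works; resnet18 is illustrative
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
# train_loader / val_loader are placeholders for VG-style loaders built as in Example #2
model = train_model(model, train_loader, val_loader, criterion, optimizer,
                    scheduler, use_gpu=torch.cuda.is_available(),
                    num_epochs=10, input_shape=224)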
Example #5
    def create_components_vecs(self,
                               imgs,
                               boxes,
                               obj_to_img,
                               objs,
                               obj_vecs,
                               features,
                               drop_box_idx=None,
                               drop_feat_idx=None,
                               jitter=(None, None, None),
                               jitter_range=((-0.05, 0.05), (-0.05, 0.05)),
                               src_image=None):
        O = objs.size(0)
        box_vecs = obj_vecs
        mask_vecs = obj_vecs
        layout_noise = torch.randn((1, self.mask_noise_dim), dtype=mask_vecs.dtype, device=mask_vecs.device) \
            .repeat((O, 1)) \
            .view(O, self.mask_noise_dim)
        mask_vecs = torch.cat([mask_vecs, layout_noise], dim=1)

        jitterFeat = False
        # create encoding
        if features is None:
            if jitterFeat:
                if obj_to_img is None:
                    obj_to_img = torch.zeros(O,
                                             dtype=objs.dtype,
                                             device=objs.device)
                    imgbox_idx = -1
                else:
                    imgbox_idx = torch.zeros(src_image.size(0),
                                             dtype=torch.int64)
                    for i in range(src_image.size(0)):
                        imgbox_idx[i] = (obj_to_img == i).nonzero()[-1]

                add_jitter_bbox, add_jitter_layout, add_jitter_feats = jitter  # unpack
                jitter_range_bbox, jitter_range_layout = jitter_range

                # Bounding boxes ----------------------------------------------------------
                box_ones = torch.ones([O, 1],
                                      dtype=boxes.dtype,
                                      device=boxes.device)

                if drop_box_idx is not None:
                    box_keep = drop_box_idx
                else:
                    # drop random box(es)
                    box_keep = F.dropout(box_ones, self.p, True,
                                         False) * (1 - self.p)

                # image obj cannot be dropped
                box_keep[imgbox_idx, :] = 1

                if add_jitter_bbox is not None:
                    boxes_gt = jitter_bbox(
                        boxes,
                        p=add_jitter_bbox,
                        noise_range=jitter_range_bbox,
                        eval_mode=True)  # uses default settings

                boxes_prior = boxes * box_keep

                # Object features ----------------------------------------------------------
                if drop_feat_idx is not None:
                    feats_keep = drop_feat_idx

                else:
                    feats_keep = F.dropout(box_ones, self.p, True,
                                           False) * (1 - self.p)
                # print(feats_keep)
                # image obj feats should be dropped
                feats_keep[
                    imgbox_idx, :] = 1  # should they be dropped or should they not?

                obj_crop, src_image, generated = get_cropped_objs(
                    src_image, boxes, obj_to_img, box_keep, feats_keep, True)
            crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size)
            #print(crops.shape, obj_crop.shape)
            obj_repr = self.repr_net(self.image_encoder(crops))
        else:
            # Used only at inference time
            obj_repr = self.repr_net(mask_vecs)
            for ind, feature in enumerate(features):
                if feature is not None:
                    obj_repr[ind, :] = feature
        # create one-hot vector for label map
        #obj_repr = obj_repr[:,:-1]
        one_hot_size = (O, self.num_objs)
        one_hot_obj = torch.zeros(one_hot_size,
                                  dtype=obj_repr.dtype,
                                  device=obj_repr.device)
        one_hot_obj = one_hot_obj.scatter_(1, objs.view(-1, 1).long(), 1.0)
        layout_vecs = torch.cat([one_hot_obj, obj_repr], dim=1)

        wrong_objs_rep = self.fake_pool.query(objs, obj_repr)
        wrong_layout_vecs = torch.cat([one_hot_obj, wrong_objs_rep], dim=1)
        return box_vecs, mask_vecs, layout_vecs, wrong_layout_vecs
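Both create_components_vecs variants build the label map with an in-place scatter_. A standalone illustration of that one-hot construction:

import torch

objs = torch.tensor([2, 0, 1])                                  # integer labels for O = 3 objects
one_hot = torch.zeros(3, 4).scatter_(1, objs.view(-1, 1), 1.0)  # 4 = number of classes
# one_hot[i, objs[i]] == 1 and every other entry is 0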
Example #6
def run_model(args, checkpoint, output_dir, loader=None):
    if args.save_graphs:
        from scene_generation.vis import draw_scene_graph
    dirname = os.path.dirname(args.checkpoint)
    features = None
    if not args.use_gt_textures:
        features_path = os.path.join(dirname, 'features_clustered_001.npy')
        print(features_path)
        if os.path.isfile(features_path):
            features = np.load(features_path, allow_pickle=True).item()
        else:
            raise ValueError('No features file')
    with torch.no_grad():
        vocab = checkpoint['model_kwargs']['vocab']
        model = build_model(args, checkpoint)
        if loader is None:
            loader = build_loader(args, checkpoint, vocab['is_panoptic'])
        accuracy_model = None
        if args.accuracy_model_path is not None and os.path.isfile(
                args.accuracy_model_path):
            accuracy_model = load_model(args.accuracy_model_path)

        img_dir = makedir(output_dir, 'images')
        graph_dir = makedir(output_dir, 'graphs', args.save_graphs)
        gt_img_dir = makedir(output_dir, 'images_gt', args.save_gt_imgs)
        layout_dir = makedir(output_dir, 'layouts', args.save_layout)

        img_idx = 0
        total_iou = 0
        total_boxes = 0
        r_05 = 0
        r_03 = 0
        corrects = 0
        real_objects_count = 0
        num_objs = model.num_objs
        colors = torch.randint(0, 256, [num_objs, 3]).float()
        for batch in loader:
            imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = [
                x.cuda() for x in batch
            ]

            imgs_gt = imagenet_deprocess_batch(imgs)

            if args.use_gt_masks:
                masks_gt = masks
            else:
                masks_gt = None
            if args.use_gt_textures:
                all_features = None
            else:
                all_features = []
                for obj_name in objs:
                    obj_feature = features[obj_name.item()]
                    random_index = randint(0, obj_feature.shape[0] - 1)
                    feat = torch.from_numpy(obj_feature[random_index, :]).type(
                        torch.float32).cuda()
                    all_features.append(feat)
            if not args.use_gt_attr:
                attributes = torch.zeros_like(attributes)

            # Run the model with predicted masks
            model_out = model(imgs,
                              objs,
                              triples,
                              obj_to_img,
                              boxes_gt=boxes,
                              masks_gt=masks_gt,
                              attributes=attributes,
                              test_mode=True,
                              use_gt_box=args.use_gt_boxes,
                              features=all_features)
            imgs_pred, boxes_pred, masks_pred, _, layout, _ = model_out

            if accuracy_model is not None:
                if args.use_gt_boxes:
                    crops = crop_bbox_batch(imgs_pred, boxes, obj_to_img, 224)
                else:
                    crops = crop_bbox_batch(imgs_pred, boxes_pred, obj_to_img,
                                            224)

                outputs = accuracy_model(crops)
                if type(outputs) == tuple:
                    outputs, _ = outputs
                _, preds = torch.max(outputs, 1)

                # statistics
                for pred, label in zip(preds, objs):
                    if label.item() != 0:
                        real_objects_count += 1
                        corrects += 1 if pred.item() == label.item() else 0

            # Remove the __image__ object
            boxes_pred_no_image = []
            boxes_gt_no_image = []
            for o_index in range(len(obj_to_img)):
                if o_index < len(obj_to_img) - 1 and obj_to_img[
                        o_index] == obj_to_img[o_index + 1]:
                    boxes_pred_no_image.append(boxes_pred[o_index])
                    boxes_gt_no_image.append(boxes[o_index])
            boxes_pred_no_image = torch.stack(boxes_pred_no_image)
            boxes_gt_no_image = torch.stack(boxes_gt_no_image)

            iou, bigger_05, bigger_03 = jaccard(boxes_pred_no_image,
                                                boxes_gt_no_image)
            total_iou += iou
            r_05 += bigger_05
            r_03 += bigger_03
            total_boxes += boxes_pred_no_image.size(0)
            imgs_pred = imagenet_deprocess_batch(imgs_pred)

            obj_data = [objs, boxes_pred, masks_pred]
            _, obj_data = split_graph_batch(triples, obj_data, obj_to_img,
                                            triple_to_img)
            objs, boxes_pred, masks_pred = obj_data

            obj_data_gt = [boxes.data]
            if masks is not None:
                obj_data_gt.append(masks.data)
            triples, obj_data_gt = split_graph_batch(triples, obj_data_gt,
                                                     obj_to_img, triple_to_img)
            layouts_3d = one_hot_to_rgb(layout, colors, num_objs)
            for i in range(imgs_pred.size(0)):
                img_filename = '%04d.png' % img_idx
                if args.save_gt_imgs:
                    img_gt = imgs_gt[i].numpy().transpose(1, 2, 0)
                    img_gt_path = os.path.join(gt_img_dir, img_filename)
                    imsave(img_gt_path, img_gt)
                if args.save_layout:
                    layout_3d = layouts_3d[i].numpy().transpose(1, 2, 0)
                    layout_path = os.path.join(layout_dir, img_filename)
                    imsave(layout_path, layout_3d)

                img_pred_np = imgs_pred[i].numpy().transpose(1, 2, 0)
                img_path = os.path.join(img_dir, img_filename)
                imsave(img_path, img_pred_np)

                if args.save_graphs:
                    graph_img = draw_scene_graph(objs[i], triples[i], vocab)
                    graph_path = os.path.join(graph_dir, img_filename)
                    imsave(graph_path, graph_img)

                img_idx += 1

            print('Saved %d images' % img_idx)
        avg_iou = total_iou / total_boxes
        print('avg_iou {}'.format(avg_iou.item()))
        print('r0.5 {}'.format(r_05 / total_boxes))
        print('r0.3 {}'.format(r_03 / total_boxes))
        if accuracy_model is not None:
            print('Accuracy {}'.format(corrects / real_objects_count))
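run_model reports box-placement quality through a jaccard helper that returns the summed IoU plus the counts of pairs above 0.5 and 0.3. A sketch of what such a pairwise IoU routine could look like for normalized [x0, y0, x1, y1] boxes; this is an illustration under that assumed box format, not the repository's implementation:

import torch

def box_iou(pred, gt):
    # pred, gt: (N, 4) tensors of [x0, y0, x1, y1] boxes, compared row by row
    x0 = torch.max(pred[:, 0], gt[:, 0])
    y0 = torch.max(pred[:, 1], gt[:, 1])
    x1 = torch.min(pred[:, 2], gt[:, 2])
    y1 = torch.min(pred[:, 3], gt[:, 3])
    inter = (x1 - x0).clamp(min=0) * (y1 - y0).clamp(min=0)
    area_p = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
    area_g = (gt[:, 2] - gt[:, 0]) * (gt[:, 3] - gt[:, 1])
    iou = inter / (area_p + area_g - inter + 1e-6)
    return iou.sum(), (iou > 0.5).sum().item(), (iou > 0.3).sum().item()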
Example #7
    def forward(self, imgs, objs, boxes, obj_to_img):
        crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size)
        real_scores, ac_loss = self.discriminator(crops, objs)
        return real_scores, ac_loss, crops
Example #8
def test_model(model, val_dataloader, input_shape=224, out_path=None):
    since = time.time()

    device = 'cuda'

    model.train(False)  # Set model to evaluate mode
    dataloader = val_dataloader

    running_corrects = 0
    objects_len = 0


    with open(out_path, mode="w", encoding="utf-8", newline="") as csvfile:
        writer = csv.writer(csvfile)
        columns = ['label'] + ['pred'] + ["logits"+str(i+1) for i in range(172)]
        writer.writerow(columns)
        
        # Iterate over data.
        for data in tqdm.tqdm(dataloader, total=len(dataloader)):
            # get the inputs
            imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = data
            imgs = imgs.to(device)
            boxes = boxes.to(device)
            obj_to_img = obj_to_img.to(device)
            labels = objs.to(device)

            objects_len += obj_to_img.shape[0]

            with torch.no_grad():
                crops = crop_bbox_batch(imgs, boxes, obj_to_img, input_shape)

            # forward
            outputs = model(crops)
            if type(outputs) == tuple:
                outputs, _ = outputs
            outputs = F.softmax(outputs, dim=-1)
           
            _, preds = torch.max(outputs, 1)

            # statistics
            running_corrects += torch.sum(preds.view(-1,1) == labels.view(-1,1))
            
            ## save logits
            labels = labels.detach().cpu().numpy().reshape(-1,1)
            outputs = outputs.detach().cpu().numpy()
            preds = preds.detach().cpu().numpy().reshape(-1,1)
            
            writer.writerows(np.concatenate((labels, preds, outputs), axis=1).tolist())

    epoch_acc = running_corrects.item() / objects_len

    print('{} Acc: {:.4f}'.format("val", epoch_acc))

    time_elapsed = time.time() - since
    print('Validation complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    return model
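test_model writes one CSV row per object containing the ground-truth label, the argmax prediction, and the 172 softmax scores, so the file can be re-analyzed offline. A small sketch of reading it back (the path is illustrative):

import csv
import numpy as np

with open('predictions.csv', newline='') as f:              # illustrative path passed as out_path above
    rows = list(csv.DictReader(f))
labels = np.array([int(float(r['label'])) for r in rows])   # labels and preds were written as floats
preds = np.array([int(float(r['pred'])) for r in rows])
print('accuracy:', (labels == preds).mean())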