示例#1
0
def init_protonet(opt, pretrained_file="", pretrained=False):
    '''
    Initialize the ProtoNet and place it on the selected device.

    Args:
        opt: options object; its ``cuda`` attribute requests GPU use, which is
            honoured only when ``torch.cuda.is_available()`` is true.
        pretrained_file: path to a saved ``state_dict``; read only when
            ``pretrained`` is true.
        pretrained: when true, load weights from ``pretrained_file``.

    Returns:
        The ProtoNet instance on ``'cuda:0'`` or ``'cpu'``.
    '''
    device = 'cuda:0' if torch.cuda.is_available() and opt.cuda else 'cpu'
    model = ProtoNet().to(device)
    if pretrained:
        # map_location lets a checkpoint saved from a GPU run be restored on a
        # CPU-only machine (and vice versa); without it torch.load tries to
        # recreate each tensor on the device it was saved from.
        state_dict = torch.load(pretrained_file, map_location=device)
        model.load_state_dict(state_dict)
        print("Loaded pre-trained model")
    return model
    def predict_batch(self, img_paths_list, top_k, threshold=15.0):
        '''
        Classify a batch of images against the stored class prototypes.

        Args:
            img_paths_list: image paths; each one is read with ``load_img``.
            top_k: maximum number of candidate classes returned per image.
            threshold: only classes whose distance to the image embedding is
                strictly below this value are returned (rejection threshold).
                The default matches the previously hard-coded value.

        Returns:
            A list with one dict per input path, in input order, with keys
            ``'name'``, ``'prob'`` and ``'sku'``. Each value is a list of at
            most ``top_k`` entries ordered by increasing distance. NOTE: despite
            its key, ``'prob'`` holds the raw ``euclidean_dist`` values (smaller
            means closer), not probabilities. An empty input gives ``[]``.
        '''
        if not img_paths_list:
            # torch.stack raises on an empty list; there is nothing to classify.
            return []

        # load inference samples
        X = torch.stack(
            [torch.tensor(load_img(path)) for path in img_paths_list])

        # load model (re-read from self.model_path on every call)
        model = ProtoNet().cpu()
        model.load_state_dict(torch.load(self.model_path, map_location='cpu'))
        model.eval()

        # Inference only: no_grad avoids recording an autograd graph for the
        # forward pass and the distance computation.
        with torch.no_grad():
            model_output = model(X)  # [batch_size,128]
            dists = euclidean_dist(
                model_output,
                self.prototypes.to('cpu'))  # [batch_size,num_classes]
        dists = dists.cpu().numpy()

        sorted_dists = np.sort(dists, axis=1)
        sorted_idxs = np.argsort(dists, axis=1)
        # reject candidate classes that are too far from the image embedding
        mask = sorted_dists < threshold

        result = []  # list of dict for each image
        for row_dists, row_idxs, row_mask in zip(sorted_dists, sorted_idxs, mask):
            labels = self.labels[row_idxs][row_mask][:top_k].tolist()
            skus = [self.idx2sku[idx] for idx in labels]
            result.append({
                'name': [self.sku2name[sku] for sku in skus],
                'prob': row_dists[row_mask][:top_k].tolist(),
                'sku': skus
            })

        return result
    def retrain(self, img_paths_list, class_name, sku, threshold=0.0):
        '''
        Register a new class whose prototype is the mean embedding of its images.

        Every path is expanded through ``image_enforce``; each yielded image is
        converted with ``transforms.ToTensor()``. All resulting tensors are
        embedded with the model loaded from ``self.model_path``, and their mean
        becomes the new class prototype.

        Args:
            img_paths_list: image paths for the new class.
            class_name: name stored in ``self.sku2name[sku]``.
            sku: identifier stored in ``self.idx2sku`` under the new label id.
            threshold: the class is rejected when the new prototype's smallest
                distance to an existing prototype is strictly below this value.
                The default 0.0 preserves the previous hard-coded value.

        Returns:
            ``{'msg': 'success'}`` after the class has been added. Otherwise a
            dict with ``'msg': 'fail'`` plus ``'similar_object_name'`` and
            ``'similar_object_sku'`` of the closest existing class. On failure
            no state (including ``self.labelID``) is modified.
        '''
        infer_imgs = []
        for p in img_paths_list:
            infer_imgs += [transforms.ToTensor()(im) for im in image_enforce(p)]
        X = torch.stack(infer_imgs)

        # load model
        model = ProtoNet().cpu()
        model.load_state_dict(torch.load(self.model_path, map_location='cpu'))
        model.eval()

        # The prototype is stored on self, so compute it without autograd.
        # Otherwise the stored tensor would carry (and keep alive) the graph
        # of this forward pass.
        with torch.no_grad():
            batch_prototype = model(X).mean(0).unsqueeze(0)
            # presumably [1,num_classes]; euclidean_dist is defined elsewhere
            dists = euclidean_dist(batch_prototype, self.prototypes.to('cpu'))

        # whether the new class fails to map to a distinguishing embedding
        flat_dists = dists.reshape(-1)
        # torch.argmin instead of np.argmin: the latter cannot operate on a
        # torch tensor that requires grad.
        index = int(torch.argmin(flat_dists))
        if flat_dists[index].item() < threshold:
            sim_lblid = self.labels[index]
            return {
                'msg': 'fail',
                'similar_object_name': self.sku2name[self.idx2sku[sim_lblid]],
                'similar_object_sku': self.idx2sku[sim_lblid]
            }

        # add new class info; a label id is only consumed on success
        self.labelID += 1
        self.prototypes = torch.cat([self.prototypes, batch_prototype], 0)
        self.labels = np.concatenate((self.labels, [self.labelID]), axis=0)
        self.idx2sku[self.labelID] = sku
        self.sku2name[sku] = class_name

        return {'msg': 'success'}
示例#4
0

# PARAMS
opts = get_basic_parser(get_parser()).parse_args()
opts.method = 'proto'
setup(opts)

# CREATE MODEL
net = ProtoNet().to(opts.device)

# RESUME (fixme with appropriate epoch and iter)
# Only the network weights are restored here; the optimizer created below
# starts from scratch.
if os.path.exists(opts.model_file):
    print_log(
        'loading previous best checkpoint [{}] ...'.format(opts.model_file),
        opts.log_file)
    # map_location restores the tensors onto opts.device even when the
    # checkpoint was saved from a different device (e.g. a GPU checkpoint
    # resumed on a CPU-only machine).
    net.load_state_dict(torch.load(opts.model_file, map_location=opts.device))

# Weights are loaded before DataParallel wrapping, so the checkpoint keys are
# expected without DataParallel's 'module.' prefix.
if opts.multi_gpu:
    print_log('Wrapping network into multi-gpu mode ...', opts.log_file)
    net = torch.nn.DataParallel(net)

# PREPARE DATA
train_db, val_db, test_db, _ = data_loader(opts)

# MISC
# TODO: the original repo doesn't use weight decay
optimizer = optim.Adam(net.parameters(),
                       lr=opts.lr,
                       weight_decay=opts.weight_decay)
# scheduler = MultiStepLR(optimizer, milestones=opts.scheduler, gamma=opts.lr_scheduler_gamma)
scheduler = StepLR(optimizer,
示例#5
0
    N_query = args.N_query
    model_path = args.load
    test_csv = args.test_csv
    test_data_dir = args.test_data_dir
    testcase_csv = args.testcase_csv
    output_csv = args.output_csv

    test_dataset = MiniDataset(test_csv, test_data_dir)
    test_loader = DataLoader(test_dataset,
                             batch_size=N_way * (N_query + N_shot),
                             num_workers=3,
                             pin_memory=False,
                             worker_init_fn=worker_init_fn,
                             sampler=GeneratorSampler(testcase_csv))

    model = ProtoNet()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.load_state_dict((torch.load(model_path, map_location=device)))

    prediction_results = predict(model, test_loader, N_way, N_shot)

    row_names = ["query{}".format(i) for i in range(1, 76)]
    row_names = ["episode_id"] + row_names

    with open(output_csv, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(row_names)

        for i, pred in enumerate(prediction_results):
            row = [i] + list(pred.numpy())
            writer.writerow(row)
示例#6
0
import pickle


def get_embd(x, model):
    '''
    Return the embedding ``model(x)`` for the input ``x``.

    The model is called as-is. This function does not move ``x`` to a device,
    does not enter ``torch.no_grad()``, and does not change the model's
    train/eval mode; the caller controls all of those.
    '''
    return model(x)


# Load the trained weights once, at import time. The model is built without a
# .to(...) call, so its parameters stay on the CPU; map_location='cpu' lets a
# checkpoint saved from a GPU run load on a CPU-only machine.
model = ProtoNet()
model_path = '/home/pallav_soni/pro/output/best_model.pth'
model.load_state_dict(torch.load(model_path, map_location='cpu'))
# The visible code only uses this model for inference (get_embd), so switch to
# eval mode so that any dropout/batch-norm layers ProtoNet may contain (not
# visible here) use inference behaviour.
model.eval()


def load_img(path):
    '''
    Load an image file as a float32 tensor of shape ``(1, 3, 28, 28)``.

    The image is converted to RGB, resized to 28x28, cast to float32 and
    mapped through ``1.0 - pixel``. It is then reshaped to ``(3, 28, 28)`` and
    given a leading batch dimension of size 1.
    '''
    # The context manager closes the underlying file; convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(path) as src:
        img = src.convert('RGB').resize((28, 28))
    shape = 3, img.size[0], img.size[1]
    # np.asarray instead of np.array(..., copy=False): since NumPy 2.0,
    # copy=False raises ValueError whenever a copy is needed, and the
    # uint8 -> float32 cast always needs one.
    arr = np.asarray(img, dtype=np.float32)
    # NOTE(review): pixel values are in [0, 255], so 1.0 - x gives values in
    # [-254, 1]; the inversion only normalises 0/1 images. Kept unchanged so
    # inputs match whatever preprocessing the model was trained with. Confirm.
    x = 1.0 - torch.from_numpy(arr)
    # NOTE(review): arr is (H, W, 3). transpose(0, 1) gives (W, H, 3), and
    # .view(3, W, H) reinterprets that memory instead of moving the channel
    # axis (which would be permute(2, 0, 1)), so channels and pixels get
    # interleaved. Kept as-is for consistency with training. Verify.
    x = x.transpose(0, 1).contiguous().view(shape)
    return torch.unsqueeze(x, 0)


#img = load_img('/home/pallav_soni/dumm.jpeg')
#img = torch.unsqueeze(img, 0)