示例#1
0
    def execute(self, point_cloud, label):
        """Forward pass of a PointNet-style part-segmentation head.

        Args:
            point_cloud: input cloud of shape (B, D, N) — batch,
                channels, points (D is presumably 3 — TODO confirm).
            label: per-shape category feature concatenated onto the
                pooled global feature; the view below assumes 16 label
                channels — verify against the dataset.

        Returns:
            Per-point scores from ``convs4`` with points on the last axis.
        """
        B, D, N = point_cloud.size()
        # Input transform: learn a DxD alignment matrix and apply it.
        trans = self.stn(point_cloud)
        point_cloud = point_cloud.transpose(0, 2, 1)  # (B, N, D) for bmm
        point_cloud = nn.bmm(point_cloud, trans)

        point_cloud = point_cloud.transpose(0, 2, 1)  # back to (B, D, N)

        # Shared per-point MLP (1D convs along the point axis).
        out1 = self.relu(self.bn1(self.conv1(point_cloud)))
        out2 = self.relu(self.bn2(self.conv2(out1)))
        out3 = self.relu(self.bn3(self.conv3(out2)))

        # Feature transform applied to the intermediate features.
        trans_feat = self.fstn(out3)
        x = out3.transpose(0, 2, 1)
        net_transformed = nn.bmm(x, trans_feat)
        net_transformed = net_transformed.transpose(0, 2, 1)

        out4 = self.relu(self.bn4(self.conv4(net_transformed)))
        out5 = self.bn5(self.conv5(out4))
        # jt.argmax returns (indices, values); [1] keeps the max values,
        # i.e. channel-wise global max pooling over the point axis.
        out_max = jt.argmax(out5, 2, keepdims=True)[1]
        out_max = out_max.view(-1, 2048)

        # Append the shape-category feature to the global feature.
        out_max = concat((out_max, label), 1)
        # Broadcast the global feature to every point and stack it with
        # all intermediate per-point features.
        expand = out_max.view(-1, 2048 + 16, 1).repeat(1, 1, N)
        concat_feature = concat([expand, out1, out2, out3, out4, out5], 1)
        net = self.relu(self.bns1(self.convs1(concat_feature)))
        net = self.relu(self.bns2(self.convs2(net)))
        net = self.relu(self.bns3(self.convs3(net)))
        net = self.convs4(net)
        return net
示例#2
0
def test():
    """Evaluate `model` once and return a list of accuracies, one per
    mask, in the order ('train_mask', 'val_mask', 'test_mask').
    """
    model.eval()
    logits, accs = model(), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        y_ = data.y[mask]
        # Keep the logit rows the boolean mask selects; `mask[i]` is
        # truthy exactly when `mask[i] == True`, so the comparison and
        # the manual append loop are unnecessary.
        logits_ = jt.stack(
            [logits[i] for i in range(mask.shape[0]) if mask[i]])
        pred, _ = jt.argmax(logits_, dim=1)
        acc = pred.equal(y_).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs
def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
    """
    Strategy: suppose that the models that we will create will have prefixes appended
    to each of its keys, for example due to an extra level of nesting that the original
    pre-trained weights from ImageNet won't contain. For example, model.state_dict()
    might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
    res2.conv1.weight. We thus want to match both parameters together.
    For that, we look for each model weight, look among all loaded keys if there is one
    that is a suffix of the current weight name, and use it if that's the case.
    If multiple matches exist, take the one with longest size
    of the corresponding name. For example, for the same model as before, the pretrained
    weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
    we want to match backbone[0].body.conv1.weight to conv1.weight, and
    backbone[0].body.res2.conv1.weight to res2.conv1.weight.

    Mutates `model_state_dict` in place; returns None.
    """
    current_keys = sorted(model_state_dict.keys())
    loaded_keys = sorted(loaded_state_dict.keys())

    # used for logging
    max_size = max((len(key) for key in current_keys), default=1)
    max_size_loaded = max((len(key) for key in loaded_keys), default=1)
    log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
    logger = logging.getLogger(__name__)

    for key in current_keys:
        # Pick the longest loaded key that is a suffix of `key`.
        # Ties cannot occur: equal-length suffixes of one string are
        # identical. A plain search replaces the former n-by-m jittor
        # match matrix, which also crashed on an empty `loaded_keys`.
        best_len, key_old = 0, None
        for candidate in loaded_keys:
            if key.endswith(candidate) and len(candidate) > best_len:
                best_len, key_old = len(candidate), candidate
        if key_old is None:
            # no loaded key matches this model key
            continue
        model_state_dict[key] = loaded_state_dict[key_old]
        logger.info(
            log_str_template.format(
                key,
                max_size,
                key_old,
                max_size_loaded,
                tuple(loaded_state_dict[key_old].shape),
            ))
示例#4
0
def farthest_point_sample(xyz, npoint):
    """
    Iteratively pick `npoint` points per batch, each maximizing the
    distance to the already-selected set (farthest point sampling).

    Input:
        xyz: pointcloud data, [B, N, C]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    #import ipdb; ipdb.set_trace()
    #device = xyz.device
    B, N, C = xyz.shape
    # NOTE(review): centroids is float-typed but stores point indices;
    # presumably callers cast before gathering — confirm.
    centroids = jt.zeros((B, npoint))
    # Running minimum squared distance from each point to the chosen set.
    distance = jt.ones((B, N)) * 1e10

    # Start from a random point in each batch element.
    farthest = np.random.randint(0, N, B, dtype='l')
    batch_indices = np.arange(B, dtype='l')
    farthest = jt.array(farthest)
    batch_indices = jt.array(batch_indices)
    # jt.sync_all(True)
    # print (xyz.shape, farthest.shape, batch_indices.shape, centroids.shape, distance.shape)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :]
        # view(B, 1, 3) hard-codes C == 3 — TODO confirm inputs are xyz-only.
        centroid = centroid.view(B, 1, 3)

        # Squared distance from every point to the newest centroid.
        dist = jt.sum((xyz - centroid.repeat(1, N, 1))**2, 2)
        mask = dist < distance
        # distance = mask.ternary(distance, dist)
        # print (mask.size())

        # Guard: masked assignment with an all-false mask misbehaves
        # (see trailing comment), so skip it when nothing improved.
        if mask.sum().data[0] > 0:
            distance[mask] = dist[mask]  # bug if mask.sum() == 0

        # Next centroid: the point farthest from the selected set.
        # jt.argmax returns (indices, values); [0] takes the indices.
        farthest = jt.argmax(distance, 1)[0]
        # print (farthest)
        # print (farthest.shape)
    # B, N, C = xyz.size()
    # sample_list = random.sample(range(0, N), npoint)
    # centroids = jt.zeros((1, npoint))
    # centroids[0,:] = jt.array(sample_list)
    # centroids = centroids.view(1, -1).repeat(B, 1)
    # x_center = x[:,sample_list, :]
    return centroids
示例#5
0
def test(net, test_data, state):
    """Run one evaluation pass over `test_data` and record the results.

    Fills state['test_loss'], state['test_accuracy'] and
    state['test_fps'] in place; returns nothing.
    """
    net.eval()
    total_loss = 0.0
    n_correct = 0.0
    started = time.time()
    for inputs, labels in test_data:
        inputs, labels = jt.array(inputs), jt.array(labels)

        # forward pass and per-batch loss
        logits = net(inputs)
        batch_loss = jt.nn.cross_entropy_loss(logits, labels)

        # accumulate correct predictions (argmax returns (indices, values))
        predictions = jt.argmax(logits, dim=1)[0]
        n_correct += float(jt.sum(predictions == labels).data[0])

        # accumulate loss for averaging
        total_loss += float(batch_loss.data[0])
    elapsed = time.time() - started
    n_samples = len(test_data) * test_data.batch_size

    state['test_loss'] = total_loss / len(test_data)
    state['test_accuracy'] = n_correct / n_samples
    state['test_fps'] = n_samples / elapsed
示例#6
0
 def log_prob(self, x):
     """Log-probability of one-hot `x`, delegated via its argmax index."""
     index = jt.argmax(x, dim=-1)[0]
     return Categorical.log_prob(self, index)
示例#7
0
def train(model):
    """Train `model` on ShapeNetPart for 200 epochs, evaluating after
    each epoch and printing loss/accuracy/IoU summaries.

    Side effects: prints progress only; `best_test_iou` is initialized
    but never updated, and no checkpoint is saved — NOTE(review).
    """
    batch_size = 16
    train_loader = ShapeNetPart(partition='trainval',
                                num_points=2048,
                                class_choice=None,
                                batch_size=batch_size,
                                shuffle=True)
    test_loader = ShapeNetPart(partition='test',
                               num_points=2048,
                               class_choice=None,
                               batch_size=batch_size,
                               shuffle=False)

    # 50 part classes total; labels are assumed to already start at 0.
    seg_num_all = 50
    seg_start_index = 0

    print(str(model))
    base_lr = 0.01
    optimizer = nn.SGD(model.parameters(),
                       lr=base_lr,
                       momentum=0.9,
                       weight_decay=1e-4)
    lr_scheduler = LRScheduler(optimizer, base_lr)

    # criterion = nn.cross_entropy_loss() # here

    best_test_iou = 0
    for epoch in range(200):
        ####################
        # Train
        ####################
        lr_scheduler.step(len(train_loader) * batch_size)
        train_loss = 0.0
        count = 0.0
        model.train()
        train_true_cls = []
        train_pred_cls = []
        train_true_seg = []
        train_pred_seg = []
        train_label_seg = []

        # debug = 0
        for data, label, seg in train_loader:
            # with jt.profile_scope() as report:

            seg = seg - seg_start_index
            # Build a one-hot (B, 16) shape-category vector for the model.
            label_one_hot = np.zeros((label.shape[0], 16))
            # print (label.size())
            for idx in range(label.shape[0]):
                label_one_hot[idx, label.numpy()[idx, 0]] = 1
            label_one_hot = jt.array(label_one_hot.astype(np.float32))
            data = data.permute(0, 2,
                                1)  # for pointnet it should not be committed
            # NOTE(review): rebinds the outer `batch_size` to the actual
            # (possibly smaller, final-batch) size for the running sums.
            batch_size = data.size()[0]
            # print ('input data shape')
            # print (data.shape, label_one_hot.shape)
            # for pointnet b c n for pointnet2 b n c

            seg_pred = model(data, label_one_hot)
            seg_pred = seg_pred.permute(0, 2, 1)
            # print (seg_pred.size())
            # print (seg_pred.size(), seg.size())
            loss = nn.cross_entropy_loss(seg_pred.view(-1, seg_num_all),
                                         seg.view(-1))
            # print (loss.data)
            optimizer.step(loss)

            # jt.argmax returns (indices, values); [0] takes the indices.
            pred = jt.argmax(seg_pred, dim=2)[0]  # (batch_size, num_points)
            # print ('pred size =', pred.size(), seg.size())
            count += batch_size
            train_loss += loss.numpy() * batch_size
            seg_np = seg.numpy()  # (batch_size, num_points)
            pred_np = pred.numpy()  # (batch_size, num_points)
            # print (type(label))

            label = label.numpy()  # added

            train_true_cls.append(
                seg_np.reshape(-1))  # (batch_size * num_points)
            train_pred_cls.append(
                pred_np.reshape(-1))  # (batch_size * num_points)
            train_true_seg.append(seg_np)
            train_pred_seg.append(pred_np)
            temp_label = label.reshape(-1, 1)

            train_label_seg.append(temp_label)

            # print(report)
        train_true_cls = np.concatenate(train_true_cls)
        train_pred_cls = np.concatenate(train_pred_cls)
        # print (train_true_cls.shape ,train_pred_cls.shape)
        train_acc = metrics.accuracy_score(train_true_cls, train_pred_cls)

        # NOTE(review): passes `.data` (numpy memoryviews) here but raw
        # arrays in the test loop below — looks like an unintended
        # inconsistency; confirm sklearn accepts both forms.
        avg_per_class_acc = metrics.balanced_accuracy_score(
            train_true_cls.data, train_pred_cls.data)
        # print ('train acc =',train_acc, 'avg_per_class_acc', avg_per_class_acc)
        train_true_seg = np.concatenate(train_true_seg, axis=0)
        # print (len(train_pred_seg), train_pred_seg[0].shape)
        train_pred_seg = np.concatenate(train_pred_seg, axis=0)
        # print (len(train_label_seg), train_label_seg[0].size())
        # print (train_label_seg[0])
        train_label_seg = np.concatenate(train_label_seg, axis=0)
        # print (train_pred_seg.shape, train_true_seg.shape, train_label_seg.shape)
        train_ious = calculate_shape_IoU(train_pred_seg, train_true_seg,
                                         train_label_seg, None)
        outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f, train iou: %.6f' % (
            epoch, train_loss * 1.0 / count, train_acc, avg_per_class_acc,
            np.mean(train_ious))
        # io.cprint(outstr)
        print(outstr)
        ####################
        # Test
        ####################
        test_loss = 0.0
        count = 0.0
        model.eval()
        test_true_cls = []
        test_pred_cls = []
        test_true_seg = []
        test_pred_seg = []
        test_label_seg = []
        for data, label, seg in test_loader:
            seg = seg - seg_start_index
            label_one_hot = np.zeros((label.shape[0], 16))
            for idx in range(label.shape[0]):
                label_one_hot[idx, label.numpy()[idx, 0]] = 1
            label_one_hot = jt.array(label_one_hot.astype(np.float32))
            data = data.permute(0, 2, 1)  # for pointnet should not be commit
            batch_size = data.size()[0]
            seg_pred = model(data, label_one_hot)
            seg_pred = seg_pred.permute(0, 2, 1)
            loss = nn.cross_entropy_loss(seg_pred.view(-1, seg_num_all),
                                         seg.view(-1, 1).squeeze(-1))
            pred = jt.argmax(seg_pred, dim=2)[0]
            count += batch_size
            test_loss += loss.numpy() * batch_size
            seg_np = seg.numpy()
            pred_np = pred.numpy()
            label = label.numpy()  # added

            test_true_cls.append(seg_np.reshape(-1))
            test_pred_cls.append(pred_np.reshape(-1))
            test_true_seg.append(seg_np)
            test_pred_seg.append(pred_np)
            test_label_seg.append(label.reshape(-1, 1))
        test_true_cls = np.concatenate(test_true_cls)
        test_pred_cls = np.concatenate(test_pred_cls)
        test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls)
        avg_per_class_acc = metrics.balanced_accuracy_score(
            test_true_cls, test_pred_cls)
        test_true_seg = np.concatenate(test_true_seg, axis=0)
        test_pred_seg = np.concatenate(test_pred_seg, axis=0)
        test_label_seg = np.concatenate(test_label_seg)
        test_ious = calculate_shape_IoU(test_pred_seg, test_true_seg,
                                        test_label_seg, None)
        outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f, test iou: %.6f' % (
            epoch, test_loss * 1.0 / count, test_acc, avg_per_class_acc,
            np.mean(test_ious))
        print(outstr)
def sample_images(batches_done):
    """Saves a generated sample from the validation set"""
    batches = next(iter(val_dataloader))

    # Unpack one validation batch: full portrait, per-part crops,
    # hair/background masks and part-placement metadata. Positional
    # order is fixed by the dataloader — confirm against its definition.
    real_A = batches[0]
    real_A_eyel = batches[1]
    real_A_eyer = batches[2]
    real_A_nose = batches[3]
    real_A_mouth = batches[4]
    real_A_hair = batches[5]
    real_A_bg = batches[6]
    mask = batches[7]
    mask2 = batches[8]
    center = batches[9]
    cmaskel = batches[10]
    cmasker = batches[11]
    cmaskno = batches[12]
    cmaskmo = batches[13]

    maskh = mask * mask2
    maskb = inverse_mask(mask2)

    # Coarse whole-face generation.
    fake_B0 = G_global(real_A)

    # EYES, NOSE, MOUTH
    # Each part: local generator, autoencoder refinement, then a masked
    # blend of the two via add_with_mask.
    fake_B_eyel1 = G_l_eyel(real_A_eyel)
    fake_B_eyel2 = ae_eyel(fake_B_eyel1)
    fake_B_eyel = add_with_mask(fake_B_eyel2, fake_B_eyel1, cmaskel)
    fake_B_eyer1 = G_l_eyer(real_A_eyer)
    fake_B_eyer2 = ae_eyer(fake_B_eyer1)
    fake_B_eyer = add_with_mask(fake_B_eyer2, fake_B_eyer1, cmasker)
    fake_B_nose1 = G_l_nose(real_A_nose)
    fake_B_nose2 = ae_nose(fake_B_nose1)
    fake_B_nose = add_with_mask(fake_B_nose2, fake_B_nose1, cmaskno)
    fake_B_mouth1 = G_l_mouth(real_A_mouth)
    # Classifier CLm picks, per sample, which mouth autoencoder output
    # to use (jt.argmax returns (indices, values); [0] = indices).
    outputs1 = CLm(real_A_mouth)
    pred = jt.argmax(outputs1, dim=1)[0]
    fake_B_mouth2w = ae_mowhite(fake_B_mouth1)
    fake_B_mouth2b = ae_moblack(fake_B_mouth1)
    fake_B_mouth2s = jt.contrib.concat((fake_B_mouth2w, fake_B_mouth2b), 1)
    bs, c, h, w = fake_B_mouth2s.shape
    # Flatten (bs, c) to bs*c rows and select channel `pred` per sample.
    index = pred + jt.arange(bs) * c
    fake_B_mouth2 = fake_B_mouth2s.reshape([-1, h,
                                            w])[index].reshape([bs, 1, h, w])
    fake_B_mouth = add_with_mask(fake_B_mouth2, fake_B_mouth1, cmaskmo)
    # HAIR & BG
    outputs2 = CLh(real_A_hair)
    onehot2 = getonehot(outputs2, 3, bs)
    fake_B_hair = G_l_hair(real_A_hair, onehot2)
    fake_B_bg = G_l_bg(real_A_bg)
    # PARTCOMBINE
    fake_B1 = partCombiner2_bg(center,
                               fake_B_eyel,
                               fake_B_eyer,
                               fake_B_nose,
                               fake_B_mouth,
                               fake_B_hair,
                               fake_B_bg,
                               maskh,
                               maskb,
                               comb_op=1,
                               load_h=opt.img_height,
                               load_w=opt.img_width)
    # FUSION NET
    fake_B = G_combine(jt.contrib.concat((fake_B0, fake_B1), 1))

    # Stack input and channel-tiled output along height and save a grid.
    img_sample = np.concatenate(
        [real_A.data, fake_B.repeat(1, 3, 1, 1).data], -2)
    save_image(img_sample,
               "images/%s/%s.jpg" % (opt.dataset_name, batches_done),
               nrow=5)
        fake = jt.zeros([real_A.shape[0], 1]).stop_grad()

        fake_B0 = G_global(real_A)
        # EYES, NOSE, MOUTH
        fake_B_eyel1 = G_l_eyel(real_A_eyel)
        fake_B_eyel2 = ae_eyel(fake_B_eyel1)
        fake_B_eyel = add_with_mask(fake_B_eyel2, fake_B_eyel1, cmaskel)
        fake_B_eyer1 = G_l_eyer(real_A_eyer)
        fake_B_eyer2 = ae_eyer(fake_B_eyer1)
        fake_B_eyer = add_with_mask(fake_B_eyer2, fake_B_eyer1, cmasker)
        fake_B_nose1 = G_l_nose(real_A_nose)
        fake_B_nose2 = ae_nose(fake_B_nose1)
        fake_B_nose = add_with_mask(fake_B_nose2, fake_B_nose1, cmaskno)
        fake_B_mouth1 = G_l_mouth(real_A_mouth)
        outputs1 = CLm(real_A_mouth)
        pred = jt.argmax(outputs1, dim=1)[0]
        fake_B_mouth2w = ae_mowhite(fake_B_mouth1)
        fake_B_mouth2b = ae_moblack(fake_B_mouth1)
        fake_B_mouth2s = jt.contrib.concat((fake_B_mouth2w, fake_B_mouth2b), 1)
        bs, c, h, w = fake_B_mouth2s.shape
        index = pred + jt.arange(bs) * c
        fake_B_mouth2 = fake_B_mouth2s.reshape([-1, h, w])[index].reshape(
            [bs, 1, h, w])
        fake_B_mouth = add_with_mask(fake_B_mouth2, fake_B_mouth1, cmaskmo)
        # HAIR & BG
        outputs2 = CLh(real_A_hair)
        onehot2 = getonehot(outputs2, 3, bs)
        fake_B_hair = G_l_hair(real_A_hair, onehot2)
        fake_B_bg = G_l_bg(real_A_bg)
        # PARTCOMBINE
        fake_B1 = partCombiner2_bg(center,
示例#10
0
def getonehot(outputs, classes, batch_size):
    """Build a (batch_size, classes) one-hot matrix from the row-wise
    argmax of `outputs`."""
    # jt.argmax returns (indices, values); keep the winning class index.
    winners = jt.argmax(outputs, 1)[0]
    cols = jt.unsqueeze(winners, 1)
    onehot = jt.zeros([batch_size, classes])
    onehot.scatter_(1, cols, jt.array(1.))
    return onehot
示例#11
0
 def log_prob(self, x):
     """Return per-sample log-probabilities for one-hot inputs `x`,
     promoting a single 1-D sample to a batch of one."""
     if len(x.shape) == 1:
         x = x.unsqueeze(0)
     class_idx = jt.argmax(x, dim=-1)[0]
     expanded = self.logits.broadcast(x.shape)
     return expanded.gather(1, class_idx.unsqueeze(-1)).reshape(-1)