Example #1
def forward(batch, data_features, network, conf, \
        is_val=False, step=None, epoch=None, batch_ind=0, num_batch=1, start_time=0, \
        log_console=False, log_tb=False, tb_writer=None, lr=None):
    # prepare input
    # greedily pack multiple samples' parts into one batch: keep adding shapes
    # while the accumulated part count is below 32, skipping any shape that
    # would push the total past 40
    batch_index = 1  # sample 0 seeds the packed batch; the loop below starts at 1
    if len(batch) == 0:
        return None

    cur_batch_size = len(batch[data_features.index('total_parts_cnt')])
    total_part_cnt = batch[data_features.index('total_parts_cnt')][0]

    if total_part_cnt == 1:
        print('skipping: a single-part batch does not work with batch norm')
        return None
    input_total_part_cnt = batch[data_features.index('total_parts_cnt')][0]                             # 1
    input_img = batch[data_features.index('img')][0]                                                    # 3 x H x W
    input_img = input_img.repeat(input_total_part_cnt, 1, 1, 1)                           # part_cnt x 3 x H x W
    input_pts = batch[data_features.index('pts')][0].squeeze(0)[:input_total_part_cnt]                             # part_cnt x N x 3
    input_ins_one_hot = batch[data_features.index('ins_one_hot')][0].squeeze(0)[:input_total_part_cnt]             # part_cnt x max_similar_parts
    input_similar_part_cnt = batch[data_features.index('similar_parts_cnt')][0].squeeze(0)[:input_total_part_cnt]  # part_cnt x 1    
    input_box_size = batch[data_features.index('box_size')][0].squeeze(0)[:input_total_part_cnt]

    # prepare gt: 
    gt_mask = (batch[data_features.index('mask')][0].squeeze(0)[:input_total_part_cnt].to(conf.device),)  
    input_total_part_cnt = [batch[data_features.index('total_parts_cnt')][0]]  # from here on, a per-shape list of part counts
    while total_part_cnt < 32 and batch_index < cur_batch_size:
        cur_input_cnt = batch[data_features.index('total_parts_cnt')][batch_index]
        total_part_cnt += cur_input_cnt
        if total_part_cnt > 40:
            total_part_cnt -= cur_input_cnt
            batch_index += 1
            continue
        cur_batch_img = batch[data_features.index('img')][batch_index].repeat(cur_input_cnt, 1, 1, 1)
        input_img = torch.cat((input_img, cur_batch_img), dim=0)
        cur_box_size = batch[data_features.index('box_size')][batch_index].squeeze(0)[:cur_input_cnt]
        input_box_size = torch.cat((input_box_size, cur_box_size), dim=0)
        input_pts = torch.cat((input_pts, batch[data_features.index('pts')][batch_index].squeeze(0)[:cur_input_cnt]), dim=0)                          # (parts so far) x N x 3
        input_ins_one_hot = torch.cat((input_ins_one_hot, batch[data_features.index('ins_one_hot')][batch_index].squeeze(0)[:cur_input_cnt]), dim=0)  # (parts so far) x max_similar_parts
        input_total_part_cnt.append(batch[data_features.index('total_parts_cnt')][batch_index])
        input_similar_part_cnt = torch.cat((input_similar_part_cnt, batch[data_features.index('similar_parts_cnt')][batch_index].squeeze(0)[:cur_input_cnt]), dim=0)  # (parts so far) x 1
        # prepare gt
        gt_mask = gt_mask + (batch[data_features.index('mask')][batch_index].squeeze(0)[:cur_input_cnt].to(conf.device), )
        batch_index += 1
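
    # The packed mini-batch now holds several shapes' parts concatenated along
    # dim 0; input_total_part_cnt tracks each shape's part count so the
    # per-shape chunks can be split back out after the network forward pass.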

    input_img = input_img.to(conf.device)
    input_pts = input_pts.to(conf.device)
    input_similar_part_cnt = input_similar_part_cnt.to(conf.device)
    input_ins_one_hot = input_ins_one_hot.to(conf.device)
    input_box_size = input_box_size.to(conf.device)
    batch_size = input_img.shape[0]
    num_point = input_pts.shape[1]

    # forward through the network
    pred_masks = network(input_img - 0.5, input_pts, input_ins_one_hot, input_total_part_cnt)
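    # pred_masks is a per-shape list of predicted part masks; the slicing
    # pred_masks[i][:-1] below suggests the network emits one extra mask slot,
    # dropped before matching.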
    # perform matching and calculate masks 
    mask_loss_per_data = []
    t = 0
    matched_pred_mask_all = torch.zeros(batch_size, 224, 224)
    matched_gt_mask_all = torch.zeros(batch_size, 224, 224)
    for i in range(len(input_total_part_cnt)):
        total_cnt = input_total_part_cnt[i]
        matched_gt_ids, matched_pred_ids = network.linear_assignment(
            gt_mask[i], pred_masks[i][:-1, :, :], input_similar_part_cnt[t:t+total_cnt])

        # select the matched data
        matched_pred_mask = pred_masks[i][matched_pred_ids]
        matched_gt_mask = gt_mask[i][matched_gt_ids]

        matched_gt_mask_all[t:t+total_cnt, :, :] = matched_gt_mask
        matched_pred_mask_all[t:t+total_cnt, :, :] = matched_pred_mask

        # for computing mask soft iou loss
        matched_mask_loss = network.get_mask_loss(matched_pred_mask, matched_gt_mask)

        mask_loss_per_data.append(matched_mask_loss.mean())
        t += total_cnt
    mask_loss_per_data = torch.stack(mask_loss_per_data)
    
    # for each type of loss, compute avg loss per batch
    mask_loss = mask_loss_per_data.mean()

    # compute total loss
    total_loss = mask_loss * conf.loss_weight_mask

    # display information
    data_split = 'train'
    if is_val:
        data_split = 'val'

    with torch.no_grad():
        # log to console
        if log_console:
            utils.printout(conf.flog, \
                f'''{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} '''
                f'''{epoch:>5.0f}/{conf.epochs:<5.0f} '''
                f'''{data_split:^10s} '''
                f'''{batch_ind:>5.0f}/{num_batch:<5.0f} '''
                f'''{100. * (1+batch_ind+num_batch*epoch) / (num_batch*conf.epochs):>9.1f}%      '''
                f'''{lr:>5.2E} '''
                f'''{mask_loss.item():>10.5f}'''
                f'''{total_loss.item():>10.5f}''')
            conf.flog.flush()

        # log to tensorboard
        if log_tb and tb_writer is not None:
            tb_writer.add_scalar('mask_loss', mask_loss.item(), step)
            tb_writer.add_scalar('total_loss', total_loss.item(), step)
            tb_writer.add_scalar('lr', lr, step)

        # gen visu
        if is_val and (not conf.no_visu) and epoch % conf.num_epoch_every_visu == 0:
            visu_dir = os.path.join(conf.exp_dir, 'val_visu')
            out_dir = os.path.join(visu_dir, 'epoch-%04d' % epoch)
            input_img_dir = os.path.join(out_dir, 'input_img')
            input_pts_dir = os.path.join(out_dir, 'input_pts')
            gt_mask_dir = os.path.join(out_dir, 'gt_mask')
            pred_mask_dir = os.path.join(out_dir, 'pred_mask')
            info_dir = os.path.join(out_dir, 'info')

            if batch_ind == 0:
                # create folders
                os.mkdir(out_dir)
                os.mkdir(input_img_dir)
                os.mkdir(input_pts_dir)
                os.mkdir(gt_mask_dir)
                os.mkdir(pred_mask_dir)
                os.mkdir(info_dir)

            if batch_ind < conf.num_batch_every_visu:
                utils.printout(conf.flog, 'Visualizing ...')

                t = 0
                for i in range(batch_size):
                    fn = 'data-%03d.png' % (batch_ind * batch_size + i)

                    cur_input_img = (input_img[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    Image.fromarray(cur_input_img).save(os.path.join(input_img_dir, fn))
                    cur_input_pts = input_pts[i].cpu().numpy()
                    render_utils.render_pts(os.path.join(BASE_DIR, input_pts_dir, fn), cur_input_pts, blender_fn='object_centered.blend')
                    cur_gt_mask = (matched_gt_mask_all[i].cpu().numpy() > 0.5).astype(np.uint8) * 255
                    Image.fromarray(cur_gt_mask).save(os.path.join(gt_mask_dir, fn))
                    cur_pred_mask = (matched_pred_mask_all[i].cpu().numpy() > 0.5).astype(np.uint8) * 255
                    Image.fromarray(cur_pred_mask).save(os.path.join(pred_mask_dir, fn))
                
            if batch_ind == conf.num_batch_every_visu - 1:
                # visu html
                utils.printout(conf.flog, 'Generating html visualization ...')
                sublist = 'input_img,input_pts,gt_mask,pred_mask,info'
                cmd = 'cd %s && python %s . 10 htmls %s %s > /dev/null' % (out_dir, os.path.join(BASE_DIR, '../utils/gen_html_hierachy_local.py'), sublist, sublist)
                call(cmd, shell=True)
                utils.printout(conf.flog, 'DONE')

    return total_loss
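
The example above leans on two model helpers that are not shown: network.get_mask_loss (a soft-IoU loss over per-part masks) and network.linear_assignment (Hungarian matching between predicted and ground-truth parts, needed because part order is arbitrary). A minimal sketch of what they might look like, assuming equal part counts, masks as soft probabilities in [0, 1], and ignoring the similar-parts constraint the real helper also takes:

import torch
from scipy.optimize import linear_sum_assignment

def get_mask_loss(pred_mask, gt_mask, eps=1e-6):
    # pred_mask, gt_mask: (P, H, W) soft masks in [0, 1]
    inter = (pred_mask * gt_mask).sum(dim=(1, 2))
    union = (pred_mask + gt_mask - pred_mask * gt_mask).sum(dim=(1, 2))
    return 1.0 - inter / (union + eps)  # per-part soft-IoU loss, shape (P,)

def linear_assignment(gt_mask, pred_mask):
    # cost[i, j] = loss of matching gt part i to predicted part j
    P = gt_mask.shape[0]
    cost = torch.stack([
        get_mask_loss(pred_mask, gt_mask[i:i + 1].expand_as(pred_mask))
        for i in range(P)
    ])
    gt_ids, pred_ids = linear_sum_assignment(cost.detach().cpu().numpy())
    return torch.as_tensor(gt_ids), torch.as_tensor(pred_ids)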
Example #2
                    # NOTE: excerpt only; the assignment below begins
                    # mid-statement in the source (its left-hand side is cut off)
                                               2] = (pred_shape_mask[j] * 255)
                    if len(gt_inds) != 0:
                        gt_mask_to_vis[gt_inds[:, 0], gt_inds[:, 1], :] = (
                            np.array(color) * 255)
                        cur_gt_shape_mask_to_vis[gt_inds[:, 0], gt_inds[:, 1], :] = (
                            np.array(color) * 255)

                    #if len(pred_inds) != 0:
                    #pred_mask_to_vis[pred_inds[:,0], pred_inds[:,1], :] = (np.array(color) * 255)
                    #cur_pred_shape_mask_to_vis[pred_inds[:,0], pred_inds[:,1], :] = (np.array(color) * 255)

                    cur_input_pts = cur_shape_input_pts[j]
                    render_utils.render_pts(os.path.join(
                        BASE_DIR, child_input_pts_dir, child_fn),
                                            cur_input_pts,
                                            blender_fn='object_centered.blend')
                    cur_gt_mask = cur_gt_shape_mask_to_vis.astype(np.uint8)
                    Image.fromarray(cur_gt_mask).save(
                        os.path.join(child_gt_mask_dir, child_fn + '.png'))
                    cur_pred_mask = cur_pred_shape_mask_to_vis.astype(np.uint8)
                    Image.fromarray(cur_pred_mask).save(
                        os.path.join(child_pred_mask_dir, child_fn + '.png'))
                    cur_pred_pts = pred_shape_to_vis[j]
                    render_utils.render_pts(os.path.join(
                        BASE_DIR, child_pred_pose_dir, child_fn),
                                            cur_pred_pts,
                                            blender_fn='camera_centered.blend')
                    cur_gt_pts = gt_shape_to_vis[j]
                    render_utils.render_pts(os.path.join(
                        BASE_DIR, child_gt_pose_dir, child_fn),
                                            cur_gt_pts,
                                            blender_fn='camera_centered.blend')

Example #3
def forward(batch, data_features, network, conf, \
        is_val=False, step=None, epoch=None, batch_ind=0, num_batch=1, start_time=0, \
        log_console=False, log_tb=False, tb_writer=None, lr=None):
    # prepare input
    input_pcs = torch.cat(batch[data_features.index('pcs')],
                          dim=0).to(conf.device)  # B x 3N x 3
    input_pxids = torch.cat(batch[data_features.index('pc_pxids')],
                            dim=0).to(conf.device)  # B x 3N x 2
    input_movables = torch.cat(batch[data_features.index('pc_movables')],
                               dim=0).to(conf.device)  # B x 3N
    batch_size = input_pcs.shape[0]

    input_pcid1 = torch.arange(batch_size).unsqueeze(1).repeat(
        1, conf.num_point_per_shape).long().reshape(-1)  # BN
    input_pcid2 = furthest_point_sample(
        input_pcs, conf.num_point_per_shape).long().reshape(-1)  # BN
    input_pcs = input_pcs[input_pcid1,
                          input_pcid2, :].reshape(batch_size,
                                                  conf.num_point_per_shape, -1)
    input_pxids = input_pxids[input_pcid1,
                              input_pcid2, :].reshape(batch_size,
                                                      conf.num_point_per_shape,
                                                      -1)
    input_movables = input_movables[input_pcid1, input_pcid2].reshape(
        batch_size, conf.num_point_per_shape)
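    # The two flat index vectors implement a batched gather: row b of the
    # result is input_pcs[b, fps_indices[b]], reshaped back to
    # (B, num_point_per_shape, ...).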

    input_dirs1 = torch.cat(
        batch[data_features.index('gripper_direction_camera')],
        dim=0).to(conf.device)  # B x 3
    input_dirs2 = torch.cat(
        batch[data_features.index('gripper_forward_direction_camera')],
        dim=0).to(conf.device)  # B x 3

    # forward through the network
    pred_result_logits, pred_whole_feats = network(
        input_pcs, input_dirs1, input_dirs2)  # B x 2, B x F x N

    # prepare gt
    gt_result = torch.Tensor(batch[data_features.index('result')]).long().to(
        conf.device)  # B
    gripper_img_target = torch.cat(
        batch[data_features.index('gripper_img_target')],
        dim=0).to(conf.device)  # B x 3 x H x W

    # for each type of loss, compute losses per data
    result_loss_per_data = network.critic.get_ce_loss(pred_result_logits,
                                                      gt_result)

    # for each type of loss, compute avg loss per batch
    result_loss = result_loss_per_data.mean()

    # compute total loss
    total_loss = result_loss

    # display information
    data_split = 'train'
    if is_val:
        data_split = 'val'

    with torch.no_grad():
        # log to console
        if log_console:
            utils.printout(conf.flog, \
                f'''{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} '''
                f'''{epoch:>5.0f}/{conf.epochs:<5.0f} '''
                f'''{data_split:^10s} '''
                f'''{batch_ind:>5.0f}/{num_batch:<5.0f} '''
                f'''{100. * (1+batch_ind+num_batch*epoch) / (num_batch*conf.epochs):>9.1f}%      '''
                f'''{lr:>5.2E} '''
                f'''{total_loss.item():>10.5f}''')
            conf.flog.flush()

        # log to tensorboard
        if log_tb and tb_writer is not None:
            tb_writer.add_scalar('total_loss', total_loss.item(), step)
            tb_writer.add_scalar('lr', lr, step)

        # gen visu
        if is_val and (
                not conf.no_visu) and epoch % conf.num_epoch_every_visu == 0:
            visu_dir = os.path.join(conf.exp_dir, 'val_visu')
            out_dir = os.path.join(visu_dir, 'epoch-%04d' % epoch)
            input_pc_dir = os.path.join(out_dir, 'input_pc')
            gripper_img_target_dir = os.path.join(out_dir,
                                                  'gripper_img_target')
            info_dir = os.path.join(out_dir, 'info')

            if batch_ind == 0:
                # create folders
                os.mkdir(out_dir)
                os.mkdir(input_pc_dir)
                os.mkdir(gripper_img_target_dir)
                os.mkdir(info_dir)

            if batch_ind < conf.num_batch_every_visu:
                utils.printout(conf.flog, 'Visualizing ...')
                for i in range(batch_size):
                    fn = 'data-%03d.png' % (batch_ind * batch_size + i)
                    render_utils.render_pts(os.path.join(
                        BASE_DIR, input_pc_dir, fn),
                                            input_pcs[i].cpu().numpy(),
                                            highlight_id=0)
                    cur_gripper_img_target = (
                        gripper_img_target[i].permute(1, 2, 0).cpu().numpy() *
                        255).astype(np.uint8)
                    Image.fromarray(cur_gripper_img_target).save(
                        os.path.join(gripper_img_target_dir, fn))
                    with open(
                            os.path.join(info_dir, fn.replace('.png', '.txt')),
                            'w') as fout:
                        fout.write('cur_dir: %s\n' %
                                   batch[data_features.index('cur_dir')][i])
                        fout.write('pred: %s\n' % utils.print_true_false(
                            (pred_result_logits[i] > 0).cpu().numpy()))
                        fout.write(
                            'gt: %s\n' %
                            utils.print_true_false(gt_result[i].cpu().numpy()))
                        fout.write('result_loss: %f\n' %
                                   result_loss_per_data[i].item())

            if batch_ind == conf.num_batch_every_visu - 1:
                # visu html
                utils.printout(conf.flog, 'Generating html visualization ...')
                sublist = 'input_pc,gripper_img_target,info'
                cmd = 'cd %s && python %s . 10 htmls %s %s > /dev/null' % (
                    out_dir,
                    os.path.join(BASE_DIR, 'gen_html_hierachy_local.py'),
                    sublist, sublist)
                call(cmd, shell=True)
                utils.printout(conf.flog, 'DONE')

    return total_loss, pred_whole_feats.detach(), input_pcs.detach(), \
        input_pxids.detach(), input_movables.detach()
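
For reference, furthest_point_sample above is presumably the PointNet++-style CUDA op: given a point cloud (B, N, 3) it returns (B, K) indices of an approximately maximally spread subset. A minimal pure-PyTorch sketch of the same contract (an assumption, not the op's actual implementation):

import torch

def furthest_point_sample(pts, k):
    # pts: (B, N, 3) -> (B, k) long indices of a farthest-point subset
    B, N, _ = pts.shape
    idx = torch.zeros(B, k, dtype=torch.long, device=pts.device)
    dist = torch.full((B, N), float('inf'), device=pts.device)
    farthest = torch.zeros(B, dtype=torch.long, device=pts.device)  # seed: point 0
    batch = torch.arange(B, device=pts.device)
    for i in range(k):
        idx[:, i] = farthest
        cur = pts[batch, farthest].unsqueeze(1)         # (B, 1, 3)
        dist = torch.minimum(dist, ((pts - cur) ** 2).sum(-1))
        farthest = dist.argmax(-1)                      # farthest from chosen set
    return idx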
Example #4
def forward(batch, data_features, network, conf, \
        is_val=False, step=None, epoch=None, batch_ind=0, num_batch=1, start_time=0, \
        log_console=False, log_tb=False, tb_writer=None, lr=None):
    # prepare input
    batch_index = 1
    if len(batch) == 0:
        return None
    cur_batch_size = len(batch[data_features.index('total_parts_cnt')])
    total_part_cnt = batch[data_features.index('total_parts_cnt')][0]
    input_total_part_cnt = batch[data_features.index('total_parts_cnt')][
        0]  # 1
    input_img = batch[data_features.index('img')][0]  # 3 x H x W
    input_img = input_img.repeat(input_total_part_cnt, 1, 1,
                                 1)  # part_cnt 3 x H x W
    input_pts = batch[data_features.index('pts')][0].squeeze(
        0)[:input_total_part_cnt]  # part_cnt x N x 3
    input_ins_one_hot = batch[data_features.index('ins_one_hot')][0].squeeze(
        0)[:input_total_part_cnt]  # part_cnt x max_similar_parts
    input_similar_part_cnt = batch[data_features.index('similar_parts_cnt')][
        0].squeeze(0)[:input_total_part_cnt]  # part_cnt x 1
    input_shape_id = [batch[data_features.index('shape_id')][0]] * input_total_part_cnt
    input_view_id = [batch[data_features.index('view_id')][0]] * input_total_part_cnt
    # prepare gt:
    gt_cam_dof = batch[data_features.index('parts_cam_dof')][0].squeeze(
        0)[:input_total_part_cnt]
    gt_mask = [
        batch[data_features.index('mask')][0].squeeze(0)
        [:input_total_part_cnt].to(conf.device)
    ]
    input_total_part_cnt = [batch[data_features.index('total_parts_cnt')][0]]
    input_similar_parts_edge_indices = [
        batch[data_features.index('similar_parts_edge_indices')][0].to(
            conf.device)
    ]
    while total_part_cnt < 70 and batch_index < cur_batch_size:
        cur_input_cnt = batch[data_features.index(
            'total_parts_cnt')][batch_index]
        total_part_cnt += cur_input_cnt
        if total_part_cnt > 90:
            total_part_cnt -= cur_input_cnt
            batch_index += 1
            continue
        cur_batch_img = batch[data_features.index('img')][batch_index].repeat(
            cur_input_cnt, 1, 1, 1)
        input_img = torch.cat((input_img, cur_batch_img), dim=0)
        input_pts = torch.cat((input_pts, batch[data_features.index('pts')]
                               [batch_index].squeeze(0)[:cur_input_cnt]),
                              dim=0)
        input_ins_one_hot = torch.cat(
            (input_ins_one_hot, batch[data_features.index('ins_one_hot')]
             [batch_index].squeeze(0)[:cur_input_cnt]),
            dim=0)  # B x max_parts x max_similar_parts
        input_total_part_cnt.append(
            batch[data_features.index('total_parts_cnt')][batch_index])  # 1
        input_similar_part_cnt = torch.cat(
            (input_similar_part_cnt, batch[data_features.index(
                'similar_parts_cnt')][batch_index].squeeze(0)[:cur_input_cnt]),
            dim=0)  # B x max_parts x 2
        input_shape_id += [batch[data_features.index('shape_id')][batch_index]] * cur_input_cnt
        input_view_id += [batch[data_features.index('view_id')][batch_index]] * cur_input_cnt
        gt_cam_dof = torch.cat((gt_cam_dof, batch[data_features.index(
            'parts_cam_dof')][batch_index].squeeze(0)[:cur_input_cnt]),
                               dim=0)
        # prepare gt
        gt_mask.append(batch[data_features.index('mask')][batch_index].squeeze(
            0)[:cur_input_cnt].to(conf.device))
        input_similar_parts_edge_indices.append(batch[data_features.index(
            'similar_parts_edge_indices')][batch_index].to(conf.device))
        batch_index += 1

    input_img = input_img.to(conf.device)
    input_pts = input_pts.to(conf.device)
    input_similar_part_cnt = input_similar_part_cnt.to(conf.device)
    input_ins_one_hot = input_ins_one_hot.to(conf.device)
    gt_cam_dof = gt_cam_dof.to(conf.device)

    # prepare gt
    gt_center = gt_cam_dof[:, :3]  # B x 3
    gt_quat = gt_cam_dof[:, 3:]  # B x 4
    batch_size = input_img.shape[0]
    num_point = input_pts.shape[1]

    # forward through the network
    pred_masks, pred_center, pred_quat, pred_center2, pred_quat2 = network(
        input_img - 0.5, input_pts, input_ins_one_hot, input_total_part_cnt,
        input_similar_parts_edge_indices)
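    # The network emits two pose hypotheses (center/quat and center2/quat2);
    # both are supervised against the same ground truth below.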

    mask_loss_per_data = []
    t = 0
    matched_pred_mask_all = []
    matched_gt_mask_all = []
    matched_mask_loss_per_data_all = []
    matched_pred_center_all = []
    matched_gt_center_all = []
    matched_pred_center2_all = []
    matched_pred_quat_all = []
    matched_gt_quat_all = []
    matched_pred_quat2_all = []
    matched_ins_onehot_all = []

    for i in range(len(input_total_part_cnt)):
        total_cnt = input_total_part_cnt[i]
        matched_gt_ids, matched_pred_ids = network.linear_assignment(gt_mask[i], pred_masks[i], \
            input_similar_part_cnt[t:t+total_cnt], input_pts[t:t+total_cnt], gt_center[t:t+total_cnt], \
            gt_quat[t:t+total_cnt], pred_center[t:t+total_cnt], pred_quat[t:t+total_cnt])

        # select the matched data
        matched_pred_mask = pred_masks[i][matched_pred_ids]
        matched_gt_mask = gt_mask[i][matched_gt_ids]

        matched_pred_center = pred_center[t:t + total_cnt][matched_pred_ids]
        matched_pred_center2 = pred_center2[t:t + total_cnt][matched_pred_ids]
        matched_gt_center = gt_center[t:t + total_cnt][matched_gt_ids]

        matched_pred_quat = pred_quat[t:t + total_cnt][matched_pred_ids]
        matched_pred_quat2 = pred_quat2[t:t + total_cnt][matched_pred_ids]
        matched_gt_quat = gt_quat[t:t + total_cnt][matched_gt_ids]

        matched_ins_onehot = input_ins_one_hot[t:t + total_cnt][matched_pred_ids]
        matched_ins_onehot_all.append(matched_ins_onehot)

        matched_gt_mask_all.append(matched_gt_mask)
        matched_pred_mask_all.append(matched_pred_mask)

        matched_pred_center_all.append(matched_pred_center)
        matched_pred_center2_all.append(matched_pred_center2)
        matched_gt_center_all.append(matched_gt_center)

        matched_pred_quat_all.append(matched_pred_quat)
        matched_pred_quat2_all.append(matched_pred_quat2)
        matched_gt_quat_all.append(matched_gt_quat)

        # for computing mask soft iou loss
        matched_mask_loss_per_data = network.get_mask_loss(
            matched_pred_mask, matched_gt_mask)
        matched_mask_loss_per_data_all.append(matched_mask_loss_per_data)

        mask_loss_per_data.append(matched_mask_loss_per_data.mean())

        t += total_cnt

    matched_ins_onehot_all = torch.cat(matched_ins_onehot_all, dim=0)
    matched_pred_mask_all = torch.cat(matched_pred_mask_all, dim=0)
    matched_gt_mask_all = torch.cat(matched_gt_mask_all, dim=0)
    matched_mask_loss_per_data_all = torch.cat(matched_mask_loss_per_data_all,
                                               dim=0)
    matched_pred_quat_all = torch.cat(matched_pred_quat_all, dim=0)
    matched_pred_quat2_all = torch.cat(matched_pred_quat2_all, dim=0)
    matched_gt_quat_all = torch.cat(matched_gt_quat_all, dim=0)
    matched_pred_center_all = torch.cat(matched_pred_center_all, dim=0)
    matched_pred_center2_all = torch.cat(matched_pred_center2_all, dim=0)
    matched_gt_center_all = torch.cat(matched_gt_center_all, dim=0)

    center_loss_per_data = network.get_center_loss(matched_pred_center_all, matched_gt_center_all) + \
            network.get_center_loss(matched_pred_center2_all, matched_gt_center_all)

    quat_loss_per_data = network.get_quat_loss(input_pts, matched_pred_quat_all, matched_gt_quat_all) + \
            network.get_quat_loss(input_pts, matched_pred_quat2_all, matched_gt_quat_all)

    l2_rot_loss_per_data = network.get_l2_rotation_loss(input_pts, matched_pred_quat_all, matched_gt_quat_all) + \
            network.get_l2_rotation_loss(input_pts, matched_pred_quat2_all, matched_gt_quat_all)

    whole_shape_cd_per_data = network.get_shape_chamfer_loss(input_pts, matched_pred_quat_all, matched_gt_quat_all, matched_pred_center_all, matched_gt_center_all, input_total_part_cnt) + \
            network.get_shape_chamfer_loss(input_pts, matched_pred_quat2_all, matched_gt_quat_all, matched_pred_center2_all, matched_gt_center_all, input_total_part_cnt)

    # for each type of loss, compute avg loss per batch
    mask_loss_per_data = torch.stack(mask_loss_per_data)
    center_loss = center_loss_per_data.mean()
    quat_loss = quat_loss_per_data.mean()
    mask_loss = mask_loss_per_data.mean()
    l2_rot_loss = l2_rot_loss_per_data.mean()
    shape_chamfer_loss = whole_shape_cd_per_data.mean()

    # compute total loss
    total_loss = \
            center_loss * conf.loss_weight_center + \
            quat_loss * conf.loss_weight_quat + \
            l2_rot_loss * conf.loss_weight_l2_rot + \
            shape_chamfer_loss * conf.loss_weight_shape_chamfer

    # display information
    data_split = 'train'
    if is_val:
        data_split = 'val'

    with torch.no_grad():
        # log to console
        if log_console:
            utils.printout(conf.flog, \
                f'''{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} '''
                f'''{epoch:>5.0f}/{conf.epochs:<5.0f} '''
                f'''{data_split:^10s} '''
                f'''{batch_ind:>5.0f}/{num_batch:<5.0f} '''
                f'''{100. * (1+batch_ind+num_batch*epoch) / (num_batch*conf.epochs):>9.1f}%      '''
                f'''{lr:>5.2E} '''
                f'''{mask_loss.item():>10.5f}'''
                f'''{center_loss.item():>10.5f}'''
                f'''{quat_loss.item():>10.5f}'''
                f'''{l2_rot_loss.item():>10.5f}'''
                f'''{shape_chamfer_loss.item():>10.5f}'''
                f'''{total_loss.item():>10.5f}''')
            conf.flog.flush()

        # log to tensorboard
        if log_tb and tb_writer is not None:
            tb_writer.add_scalar('mask_loss', mask_loss.item(), step)
            tb_writer.add_scalar('center_loss', center_loss.item(), step)
            tb_writer.add_scalar('quat_loss', quat_loss.item(), step)
            tb_writer.add_scalar('l2_rot_loss', l2_rot_loss.item(), step)
            tb_writer.add_scalar('shape_chamfer_loss',
                                 shape_chamfer_loss.item(), step)
            tb_writer.add_scalar('total_loss', total_loss.item(), step)
            tb_writer.add_scalar('lr', lr, step)

        # gen visu
        if is_val and (
                not conf.no_visu) and epoch % conf.num_epoch_every_visu == 0:
            visu_dir = os.path.join(conf.exp_dir, 'val_visu')
            out_dir = os.path.join(visu_dir, 'epoch-%04d' % epoch)
            input_img_dir = os.path.join(out_dir, 'input_img')
            input_pts_dir = os.path.join(out_dir, 'input_pts')
            gt_mask_dir = os.path.join(out_dir, 'gt_mask')
            pred_mask_dir = os.path.join(out_dir, 'pred_mask')
            gt_dof_dir = os.path.join(out_dir, 'gt_dof')
            pred_dof_dir = os.path.join(out_dir, 'pred_dof')
            pred_dof2_dir = os.path.join(out_dir, 'pred_dof2')
            info_dir = os.path.join(out_dir, 'info')

            if batch_ind == 0:
                # create folders
                os.mkdir(out_dir)
                os.mkdir(input_img_dir)
                os.mkdir(input_pts_dir)
                os.mkdir(gt_mask_dir)
                os.mkdir(pred_mask_dir)
                os.mkdir(gt_dof_dir)
                os.mkdir(pred_dof_dir)
                os.mkdir(pred_dof2_dir)
                os.mkdir(info_dir)

            if batch_ind < conf.num_batch_every_visu:
                utils.printout(conf.flog, 'Visualizing ...')

                # compute pred_pts and gt_pts
                pred_pts = qrot(
                    matched_pred_quat_all.unsqueeze(1).repeat(1, num_point, 1),
                    input_pts) + matched_pred_center_all.unsqueeze(1).repeat(
                        1, num_point, 1)
                pred2_pts = qrot(
                    matched_pred_quat2_all.unsqueeze(1).repeat(
                        1, num_point, 1),
                    input_pts) + matched_pred_center2_all.unsqueeze(1).repeat(
                        1, num_point, 1)
                gt_pts = qrot(
                    matched_gt_quat_all.unsqueeze(1).repeat(1, num_point, 1),
                    input_pts) + matched_gt_center_all.unsqueeze(1).repeat(
                        1, num_point, 1)

                t = 0
                for i in range(batch_size):
                    fn = 'data-%03d.png' % (batch_ind * batch_size + i)

                    cur_input_img = (
                        input_img[i].permute(1, 2, 0).cpu().numpy() *
                        255).astype(np.uint8)
                    Image.fromarray(cur_input_img).save(
                        os.path.join(input_img_dir, fn))
                    cur_input_pts = input_pts[i].cpu().numpy()
                    render_utils.render_pts(os.path.join(
                        BASE_DIR, input_pts_dir, fn),
                                            cur_input_pts,
                                            blender_fn='object_centered.blend')
                    cur_gt_mask = (matched_gt_mask_all[i].cpu().numpy() >
                                   0.5).astype(np.uint8) * 255
                    Image.fromarray(cur_gt_mask).save(
                        os.path.join(gt_mask_dir, fn))
                    cur_pred_mask = (matched_pred_mask_all[i].cpu().numpy() >
                                     0.5).astype(np.uint8) * 255
                    Image.fromarray(cur_pred_mask).save(
                        os.path.join(pred_mask_dir, fn))
                    cur_pred_pts = pred_pts[i].cpu().numpy()
                    render_utils.render_pts(os.path.join(
                        BASE_DIR, pred_dof_dir, fn),
                                            cur_pred_pts,
                                            blender_fn='camera_centered.blend')
                    cur_pred_pts = pred2_pts[i].cpu().numpy()
                    render_utils.render_pts(os.path.join(
                        BASE_DIR, pred_dof2_dir, fn),
                                            cur_pred_pts,
                                            blender_fn='camera_centered.blend')
                    cur_gt_pts = gt_pts[i].cpu().numpy()
                    render_utils.render_pts(os.path.join(
                        BASE_DIR, gt_dof_dir, fn),
                                            cur_gt_pts,
                                            blender_fn='camera_centered.blend')

                    with open(
                            os.path.join(info_dir, fn.replace('.png', '.txt')),
                            'w') as fout:
                        fout.write('shape_id: %s, view_id: %s\n' % (\
                                input_shape_id[i],\
                                input_view_id[i]))
                        fout.write('ins onehot %s\n' %
                                   matched_ins_onehot_all[i])
                        fout.write('mask_loss: %f\n' %
                                   matched_mask_loss_per_data_all[i].item())
                        fout.write('center_loss: %f\n' %
                                   center_loss_per_data[i].item())
                        fout.write('quat_loss: %f\n' %
                                   quat_loss_per_data[i].item())
                        fout.write('l2_rot_loss: %f\n' %
                                   l2_rot_loss_per_data[i].item())
                        fout.write('shape_chamfer_loss: %f\n' %
                                   shape_chamfer_loss.item())

    return total_loss
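
The visualization step relies on qrot to apply each part's predicted rigid pose to its point cloud. A minimal sketch of a batched quaternion rotation with that calling convention (an assumption; quaternions taken as unit (w, x, y, z)):

import torch

def qrot(q, v):
    # q: (..., 4) unit quaternions, v: (..., 3) points; shapes match elementwise
    qw, qvec = q[..., :1], q[..., 1:]
    uv = torch.cross(qvec, v, dim=-1)
    uuv = torch.cross(qvec, uv, dim=-1)
    return v + 2.0 * (qw * uv + uuv)  # v' = v + 2*(w*(q x v) + q x (q x v))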