# Example #1
# 0
def train(data_cfg):
    """Train SegPoseNet on YCB, optionally with adversarial domain adaptation.

    Builds the network from ``data_cfg``, optionally loads initial weights,
    wraps it in ``torch.nn.DataParallel`` when ``use_gpu`` is set and trains it
    with SGD plus a ``MultiStepLR`` schedule (decay x0.1 at 50/75/90 % of
    ``num_epoch``). Segmentation, keypoint-regression and confidence losses are
    always optimised; when ``adapt`` is true the three discriminator losses are
    added and the supervised targets are restricted with ``source_only``.
    Scalars are logged to ``writer`` every 20 steps, and weights are saved every
    ``save_interval`` epochs and once at the end.

    Almost every hyper-parameter and path (``num_epoch``, ``batch_size``,
    ``adapt``, ``use_gpu``, ``weight_path``, ...) is a module-level global
    defined outside this function.

    Args:
        data_cfg: path of the data configuration file given to ``read_data_cfg``.
    """
    data_options = read_data_cfg(data_cfg)
    m = SegPoseNet(data_options)

    if load_weight_from_path is not None:
        m.load_weights(load_weight_from_path)
        print("Load weights from ", load_weight_from_path)
    i_w = m.width
    o_h = m.output_h
    o_w = m.output_w
    m.print_network()
    m.train()
    bias_acc = meters()
    optimizer = torch.optim.SGD(m.parameters(), lr=initial_lr, momentum=momentum, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        [int(0.5 * num_epoch), int(0.75 * num_epoch), int(0.9 * num_epoch)],
        gamma=0.1)
    if use_gpu:
        m = torch.nn.DataParallel(m, device_ids=gpu_id)
        m.cuda()
    # ``m.module`` only exists on the DataParallel wrapper; unwrap explicitly so
    # the progress bookkeeping and checkpointing also work when use_gpu is False.
    net = m.module if isinstance(m, torch.nn.DataParallel) else m

    one_syn_per_batch = False
    syn_min_rate = None

    if batch_size > 1 and ngpu > 1 and adapt:
        # NOTE(review): presumably asks the dataset to place synthetic samples so
        # that every per-GPU chunk of size batch_size // ngpu gets one; confirm
        # against YCB_Dataset.
        one_syn_per_batch = True
        syn_min_rate = batch_size // ngpu
        assert syn_min_rate > 1, "For DA (adapt=True), the batch size must be at least the double of number of GPU"

    train_dataset = YCB_Dataset(ycb_data_path, imageset_path, syn_data_path=syn_data_path,
                                target_h=o_h, target_w=o_w, use_real_img=use_real_img,
                                bg_path=bg_path, syn_range=syn_range, num_syn_images=num_syn_img,
                                data_cfg="data/data-YCB.cfg", kp_path=kp_path, use_bg_img=use_bg_img,
                                one_syn_per_batch=one_syn_per_batch, batch_size=syn_min_rate)
    # Only needed by the (disabled) weighted cross-entropy variant below.
    median_balancing_weight = train_dataset.weight_cross_entropy.cuda() if use_gpu \
        else train_dataset.weight_cross_entropy

    print('training on %d images' % len(train_dataset))

    # for multiflow, need to keep track of the training progress
    net.coreModel.total_training_samples = seen + num_epoch * len(train_dataset)
    print('total training samples:', net.coreModel.total_training_samples)
    net.coreModel.seen = seen

    if gen_kp_gt:
        train_dataset.gen_kp_gt(for_syn=True, for_real=False)

    # Loss configurations

    # use balancing weights for crossentropy log (used in Hu. Segmentation-driven-pose, not used here)
    # seg_loss = nn.CrossEntropyLoss(weight=median_balancing_weight)
    seg_loss = nn.CrossEntropyLoss()
    seg_loss_factor = 1

    pos_loss = nn.L1Loss()
    pos_loss_factor = 2.6

    conf_loss = nn.L1Loss()
    conf_loss_factor = 0.8

    disc_loss = nn.CrossEntropyLoss()
    disc_loss_factor = 1

    seg_disc_loss = nn.CrossEntropyLoss()
    seg_disc_loss_factor = 1

    pos_disc_loss = nn.CrossEntropyLoss()
    pos_disc_loss_factor = 1

    # split into train and val; val_loader is built but not used yet, so
    # training runs on the whole dataset
    train_db, val_db = torch.utils.data.random_split(train_dataset, [len(train_dataset) - 2000, 2000])

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,  # not use validation now
                                               batch_size=batch_size, num_workers=num_workers,
                                               shuffle=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(dataset=val_db,
                                             batch_size=batch_size, num_workers=num_workers,
                                             shuffle=True)
    # Train the model
    total_step = len(train_loader)
    for epoch in range(num_epoch):
        for i, (images, seg_label, kp_gt_x, kp_gt_y, mask_front, domains) in enumerate(tqdm(train_loader)):
            if use_gpu:
                images = images.cuda()
                seg_label = seg_label.cuda()
                kp_gt_x = kp_gt_x.cuda()
                kp_gt_y = kp_gt_y.cuda()
                mask_front = mask_front.cuda()
                domains = domains.cuda()

            # one domain label per image (d) and the flattened label tensor (domains)
            d = domains[:, 0, 0].view(-1)
            # true when no image in the batch carries domain label 0
            zero_source = d.bool().all()
            domains = domains.view(-1)

            # if adapt=True, skip the batch if it contains zero source (synthetic) samples
            if adapt and zero_source:
                continue

            # forward pass
            output = m(images, adapt=adapt, domains=d)

            # discriminator outputs (indices 2-4 of the model output)
            pred_domains = output[2]
            seg_pred_domains = output[3]
            pos_pred_domains = output[4]
            l_disc = disc_loss(pred_domains, domains)
            l_seg_disc = seg_disc_loss(seg_pred_domains, d)
            l_pos_disc = pos_disc_loss(pos_pred_domains, d)

            if adapt:
                seg_label = source_only(seg_label, d)

            # segmentation
            pred_seg = output[0]  # (BxOHxOW,C)
            seg_label = seg_label.view(-1)
            l_seg = seg_loss(pred_seg, seg_label)

            # regression
            mask_front = mask_front.repeat(number_point, 1, 1, 1).permute(1, 2, 3, 0).contiguous()  # (B,OH,OW,NV)
            if adapt:
                mask_front = source_only(mask_front, d)
                kp_gt_x = source_only(kp_gt_x, d)
                kp_gt_y = source_only(kp_gt_y, d)
            pred_x = output[1][0] * mask_front  # (B,OH,OW,NV)
            pred_y = output[1][1] * mask_front
            kp_gt_x = kp_gt_x.float() * mask_front
            kp_gt_y = kp_gt_y.float() * mask_front
            l_pos = pos_loss(pred_x, kp_gt_x) + pos_loss(pred_y, kp_gt_y)

            # confidence: the target decays exponentially with the keypoint error
            conf = output[1][2] * mask_front  # (B,OH,OW,NV)
            bias = torch.sqrt((pred_y - kp_gt_y) ** 2 + (pred_x - kp_gt_x) ** 2)
            conf_target = (torch.exp(-modulating_factor * bias) * mask_front).detach()
            l_conf = conf_loss(conf, conf_target)

            # combine all losses
            all_loss = l_seg * seg_loss_factor + l_pos * pos_loss_factor + l_conf * conf_loss_factor
            if adapt:
                all_loss += (l_disc * disc_loss_factor + l_seg_disc * seg_disc_loss_factor
                             + l_pos_disc * pos_disc_loss_factor)

            optimizer.zero_grad()
            all_loss.backward()
            optimizer.step()

            # gradient debug
            avggrad, avgdata = network_grad_ratio(m)
            print('avg gradiant ratio: %f, %f, %f' % (avggrad, avgdata, avggrad/avgdata))

            if (i + 1) % 20 == 0 and not zero_source:
                step = epoch * total_step + i

                # discriminator accuracies in percent (argmax over dim 1 vs. labels)
                _, binary_domains = torch.max(pred_domains, 1)
                acc = (binary_domains == domains).float().sum() / domains.size(0) * 100
                _, seg_binary_domains = torch.max(seg_pred_domains, 1)
                seg_disc_acc = (seg_binary_domains == d).float().sum() / d.size(0) * 100
                _, pos_binary_domains = torch.max(pos_pred_domains, 1)
                pos_disc_acc = (pos_binary_domains == d).float().sum() / d.size(0) * 100

                # compute pixel-wise bias to measure training accuracy
                bias_acc.update(abs(pnz((pred_x - kp_gt_x).cpu()).mean() * i_w))

                print('Epoch [{}/{}], Step [{}/{}]: \n seg loss: {:.4f}, pos loss: {:.4f}, conf loss: {:.4f}, pixel-wise bias:{:.4f} '
                      'disc loss: {:.4f}, disc acc: {:.4f} '
                      'disc seg loss: {:.4f}, disc seg acc: {:.4f} '
                      'disc pos loss: {:.4f}, disc pos acc: {:.4f} '
                      .format(epoch + 1, num_epoch, i + 1, total_step, l_seg.item(), l_pos.item(), l_conf.item(), bias_acc.value,
                              l_disc.item(), acc.item(),
                              l_seg_disc.item(), seg_disc_acc.item(),
                              l_pos_disc.item(), pos_disc_acc.item()))

                writer.add_scalar('seg_loss', l_seg.item(), step)
                writer.add_scalar('pos loss', l_pos.item(), step)
                writer.add_scalar('conf_loss', l_conf.item(), step)
                writer.add_scalar('pixel_wise bias', bias_acc.value, step)

                writer.add_scalar('disc_loss', l_disc.item(), step)
                writer.add_scalar('disc_acc', acc.item(), step)

                writer.add_scalar('seg_disc_loss', l_seg_disc.item(), step)
                writer.add_scalar('seg_disc_acc', seg_disc_acc.item(), step)

                writer.add_scalar('pos_disc_loss', l_pos_disc.item(), step)
                writer.add_scalar('pos_disc_acc', pos_disc_acc.item(), step)

        bias_acc._reset()
        scheduler.step()

        # save weights
        if (epoch + 1) % save_interval == 0:
            print("save weights to: ", weight_path(epoch))
            net.save_weights(weight_path(epoch))

    # Final checkpoint: num_epoch - 1 is the last loop epoch, and indexing it
    # explicitly avoids an unbound ``epoch`` when num_epoch == 0.
    net.save_weights(weight_path(max(num_epoch - 1, 0)))
    writer.close()
# Example #2
# 0
def train(cfg_path):
    """Train SegPoseNet on YCB with per-epoch validation and pose visualisation.

    Builds the network from ``cfg_path``, optionally loads pretrained weights,
    and trains with SGD plus a ``MultiStepLR`` schedule using a focal
    segmentation loss and L1 keypoint / confidence losses. After every epoch
    it does three things:

    * prints the losses summed over a 2000-sample validation split;
    * sends pose visualisations of every 100th validation batch to ``writer``;
    * writes a checkpoint to ``checkpoints_dir``.

    Hyper-parameters and paths are module-level globals defined outside this
    function.

    Args:
        cfg_path: path of the data configuration file; it is read with
            ``read_data_cfg`` for the network and also handed to the dataset.
    """
    # network initialization
    data_options = read_data_cfg(cfg_path)
    model = SegPoseNet(data_options, is_train=True)

    # load pretrained weights
    if pretrained_weights_path is not None:
        model.load_weights(pretrained_weights_path)
        print('Darknet weights loaded from ', pretrained_weights_path)

    # get input/output dimensions
    img_w = model.width
    out_h = model.output_h
    out_w = model.output_w

    # print network graph
    model.print_network()

    model.train()

    bias_acc = meters()

    # optimizer initialization
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=initial_lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        [int(0.5 * num_epoch),
         int(0.75 * num_epoch),
         int(0.9 * num_epoch)],
        gamma=0.1)

    # device selection
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # dataset initialization: the dataset receives the same config file as the
    # network (the ``cfg_path`` parameter, not a same-named global)
    train_dataset = YCBDataset(ycb_data_path,
                               imageset_path,
                               syn_data_path=syn_data_path,
                               target_h=out_h,
                               target_w=out_w,
                               use_real_img=use_real_img,
                               bg_path=bg_path,
                               num_syn_images=num_syn_img,
                               data_cfg=cfg_path,
                               kp_path=kp_path)
    if not os.path.isfile("data/balancing_weight.pkl"):
        train_dataset.gen_balancing_weight()
    train_dataset.set_balancing_weight()
    # Keypoint ground truth is always generated once here, so the ``gen_kp_gt``
    # flag is not consulted; a second call would only redo the same work.
    # NOTE(review): confirm unconditional generation is intended.
    train_dataset.gen_kp_gt()
    median_balancing_weight = train_dataset.weight_cross_entropy.to(device)

    print('training on %d images' % len(train_dataset))

    # loss configurations
    seg_loss = FocalLoss(alpha=1.0,
                         gamma=2.0,
                         weights=median_balancing_weight,
                         reduce=True)
    pos_loss = nn.L1Loss()
    pos_loss_factor = 1.5
    conf_loss = nn.L1Loss()
    conf_loss_factor = 1.0

    def compute_losses(images, seg_label, kp_gt_x, kp_gt_y, mask_front):
        """Forward one batch; return (l_seg, l_pos, l_conf, all_loss, pred_x, masked kp_gt_x)."""
        output = model(images)

        # segmentation; presumably pred_seg is (B*OH*OW, C) to match the flat labels
        pred_seg = output[0]
        l_seg = seg_loss(pred_seg, seg_label.view(-1))

        # regression
        mask_front = mask_front.repeat(number_point, 1, 1, 1).permute(
            1, 2, 3, 0).contiguous()  # (B,OH,OW,NV)
        pred_x = output[1][0] * mask_front  # (B,OH,OW,NV)
        pred_y = output[1][1] * mask_front
        kp_gt_x = kp_gt_x.float() * mask_front
        kp_gt_y = kp_gt_y.float() * mask_front
        l_pos = pos_loss(pred_x, kp_gt_x) + pos_loss(pred_y, kp_gt_y)

        # confidence: the target decays exponentially with the keypoint error
        conf = output[1][2] * mask_front  # (B,OH,OW,NV)
        bias = torch.sqrt((pred_y - kp_gt_y) ** 2 + (pred_x - kp_gt_x) ** 2)
        conf_target = (torch.exp(-modulating_factor * bias) * mask_front).detach()
        l_conf = conf_loss(conf, conf_target)

        # combine all losses
        all_loss = l_seg + l_pos * pos_loss_factor + l_conf * conf_loss_factor
        return l_seg, l_pos, l_conf, all_loss, pred_x, kp_gt_x

    def visualize(images):
        """Run eval-mode inference on ``images`` and return one CHW pose visualisation."""
        model.eval()  # change network to eval mode
        try:
            output = model(images)  # perform inference
            pred_pose = fusion(output, img_width, img_height, intrinsics,
                               conf_thresh, batch_idx, best_cnt)  # output fusion
            image = np.uint8(
                convert2cpu(images[batch_idx]).detach().numpy().transpose(1, 2, 0)
                * 255.0)  # get image
            image = resize(image, (img_height, img_width))  # resize image
            return visualize_predictions(pred_pose, image, vertices,
                                         intrinsics).transpose(2, 0, 1)
        finally:
            model.train()  # always restore train mode, even on error

    # train/val split
    train_db, val_db = torch.utils.data.random_split(
        train_dataset, [len(train_dataset) - 2000, 2000])

    train_loader = torch.utils.data.DataLoader(dataset=train_db,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(dataset=val_db,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             shuffle=True)
    # train model
    total_step = len(train_loader)
    for epoch in range(num_epoch):
        for i, batch in enumerate(tqdm(train_loader)):
            images, seg_label, kp_gt_x, kp_gt_y, mask_front = (t.to(device) for t in batch)

            l_seg, l_pos, l_conf, all_loss, pred_x, kp_x_masked = compute_losses(
                images, seg_label, kp_gt_x, kp_gt_y, mask_front)
            optimizer.zero_grad()
            all_loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                step = epoch * total_step + i
                # compute pixel-wise bias to measure training accuracy
                bias_acc.update(
                    abs(pnz((pred_x - kp_x_masked).cpu()).mean() * img_w))
                # write losses to tensorboard writer
                writer.add_scalar('seg_loss', l_seg.item(), step)
                writer.add_scalar('pos loss', l_pos.item(), step)
                writer.add_scalar('conf_loss', l_conf.item(), step)
                writer.add_scalar('pixel_wise bias', bias_acc.value, step)

        # reset pixel_wise bias meter
        bias_acc._reset()
        # LR scheduler step
        scheduler.step()

        # model validation
        # NOTE(review): losses are computed with the network in train mode, so
        # layers such as BatchNorm/Dropout (if present) behave as in training.
        with torch.no_grad():
            total_seg_loss = 0
            total_pos_loss = 0
            total_conf_loss = 0
            total_loss = 0
            viz_imgs = []
            for j, batch in enumerate(tqdm(val_loader)):
                images, seg_label, kp_gt_x, kp_gt_y, mask_front = (t.to(device) for t in batch)
                l_seg, l_pos, l_conf, all_loss, _, _ = compute_losses(
                    images, seg_label, kp_gt_x, kp_gt_y, mask_front)
                total_seg_loss += l_seg.item()
                total_pos_loss += l_pos.item()
                total_conf_loss += l_conf.item()
                total_loss += all_loss.item()
                # data visualization
                if (j + 1) % 100 == 0:
                    viz_imgs.append(visualize(images))
            # print total validation losses
            print(
                'Epoch [{}/{}], Validation Loss: \n seg loss: {:.4f}, pos loss: {:.4f}, conf loss: {:.4f}, total loss: {:.4f}'
                .format(epoch + 1, num_epoch, total_seg_loss, total_pos_loss,
                        total_conf_loss, total_loss))
            # np.stack raises on an empty list, which happens whenever the
            # validation split has fewer than 100 batches.
            if viz_imgs:
                viz_data = np.stack(viz_imgs, axis=0)
                writer.add_images('pose_viz',
                                  torch.from_numpy(viz_data),
                                  global_step=epoch + 1)

        # save model checkpoint per epoch
        model.save_weights(os.path.join(checkpoints_dir, f'ckpt_{epoch}.pth'))

    # save final model checkpoint
    model.save_weights(os.path.join(checkpoints_dir, 'ckpt_final.pth'))
    writer.close()
def train(data_cfg):
    """Train SegPoseNet on YCB with segmentation, keypoint and confidence losses.

    Builds the network from ``data_cfg``, optionally loads initial weights,
    wraps it in ``torch.nn.DataParallel`` when ``use_gpu`` is set and trains it
    with SGD plus a ``MultiStepLR`` schedule. Scalars are logged to ``writer``
    every 100 steps. Weights are written to ``weight_path`` periodically and at
    the end. Hyper-parameters and paths are module-level globals defined
    outside this function.

    Args:
        data_cfg: path of the data configuration file given to ``read_data_cfg``.
    """
    data_options = read_data_cfg(data_cfg)
    m = SegPoseNet(data_options)
    if load_weight_from_path is not None:
        m.load_weights(load_weight_from_path)
        print("Load weights from ", load_weight_from_path)
    i_w = m.width
    o_h = m.output_h
    o_w = m.output_w
    # m.print_network()
    m.train()
    bias_acc = meters()
    optimizer = torch.optim.SGD(m.parameters(), lr=initial_lr, momentum=momentum, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        [int(0.5 * num_epoch), int(0.75 * num_epoch), int(0.9 * num_epoch)],
        gamma=0.1)
    if use_gpu:
        # NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if CUDA has not
        # already been initialised in this process.
        os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible
        m = torch.nn.DataParallel(m, device_ids=gpu_id)
        m.cuda()
    # ``m.module`` only exists on the DataParallel wrapper; unwrap explicitly so
    # checkpointing also works when use_gpu is False.
    net = m.module if isinstance(m, torch.nn.DataParallel) else m

    train_dataset = YCB_Dataset(ycb_data_path, imageset_path, syn_data_path=syn_data_path,
                                target_h=o_h, target_w=o_w, use_real_img=use_real_img,
                                bg_path=bg_path, num_syn_images=num_syn_img,
                                data_cfg="data/data-YCB.cfg", kp_path=kp_path)
    median_balancing_weight = train_dataset.weight_cross_entropy.cuda() if use_gpu \
        else train_dataset.weight_cross_entropy

    print('training on %d images' % len(train_dataset))
    if gen_kp_gt:
        train_dataset.gen_kp_gt()

    # Loss configurations
    seg_loss = nn.CrossEntropyLoss(weight=median_balancing_weight)
    pos_loss = nn.L1Loss()
    pos_loss_factor = 1.3  # 0.02 in original paper
    conf_loss = nn.L1Loss()
    conf_loss_factor = 0.8  # 0.02 in original paper

    # split into train and val; val_loader is built but not used yet, so
    # training runs on the whole dataset
    train_db, val_db = torch.utils.data.random_split(train_dataset, [len(train_dataset) - 2000, 2000])

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,  # not use validation now
                                               batch_size=batch_size, num_workers=num_workers,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(dataset=val_db,
                                             batch_size=batch_size, num_workers=num_workers,
                                             shuffle=True)
    # Train the model
    total_step = len(train_loader)
    for epoch in range(num_epoch):
        for i, (images, seg_label, kp_gt_x, kp_gt_y, mask_front) in enumerate(tqdm(train_loader)):
            if use_gpu:
                images = images.cuda()
                seg_label = seg_label.cuda()
                kp_gt_x = kp_gt_x.cuda()
                kp_gt_y = kp_gt_y.cuda()
                mask_front = mask_front.cuda()

            # forward pass
            output = m(images)

            # segmentation
            pred_seg = output[0]  # (BxOHxOW,C)
            seg_label = seg_label.view(-1)
            l_seg = seg_loss(pred_seg, seg_label)

            # regression
            mask_front = mask_front.repeat(number_point, 1, 1, 1).permute(1, 2, 3, 0).contiguous()  # (B,OH,OW,NV)
            pred_x = output[1][0] * mask_front  # (B,OH,OW,NV)
            pred_y = output[1][1] * mask_front
            kp_gt_x = kp_gt_x.float() * mask_front
            kp_gt_y = kp_gt_y.float() * mask_front
            l_pos = pos_loss(pred_x, kp_gt_x) + pos_loss(pred_y, kp_gt_y)

            # confidence: the target decays exponentially with the keypoint error
            conf = output[1][2] * mask_front  # (B,OH,OW,NV)
            bias = torch.sqrt((pred_y - kp_gt_y) ** 2 + (pred_x - kp_gt_x) ** 2)
            conf_target = (torch.exp(-modulating_factor * bias) * mask_front).detach()
            l_conf = conf_loss(conf, conf_target)

            # combine all losses
            all_loss = l_seg + l_pos * pos_loss_factor + l_conf * conf_loss_factor
            optimizer.zero_grad()
            all_loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                step = epoch * total_step + i
                # compute pixel-wise bias to measure training accuracy
                bias_acc.update(abs(pnz((pred_x - kp_gt_x).cpu()).mean() * i_w))
                print('Epoch [{}/{}], Step [{}/{}]: \n seg loss: {:.4f}, pos loss: {:.4f}, conf loss: {:.4f}, '
                      'Pixel-wise bias:{:.4f}'
                      .format(epoch + 1, num_epoch, i + 1, total_step, l_seg.item(), l_pos.item(),
                              l_conf.item(), bias_acc.value))

                writer.add_scalar('seg_loss', l_seg.item(), step)
                writer.add_scalar('pos loss', l_pos.item(), step)
                writer.add_scalar('conf_loss', l_conf.item(), step)
                writer.add_scalar('pixel_wise bias', bias_acc.value, step)
        bias_acc._reset()
        scheduler.step()
        # Periodic checkpoint at zero-based epochs 1, 6, 11, ...; weight_path is
        # used as-is (not called), so presumably each save overwrites one file.
        if epoch % 5 == 1:
            net.save_weights(weight_path)
    net.save_weights(weight_path)
    writer.close()