# Example no. 1 (scraped snippet header; original label "Esempio n. 1")
def main(args):
    """Train the pointnet2 regression/segmentation heatmap network (v3).

    Creates the experiment directory tree, configures file logging, loads and
    80/20-splits the training dataset, builds the model and losses, resumes
    from ``checkpoints/best_model.pth`` when present, then runs the
    train/validate loop, checkpointing whenever the summed validation error
    improves.

    Args:
        args: parsed CLI namespace; reads ``log_dir``, ``model``,
            ``batch_size``, ``out_channel``, ``use_cpu``, ``optimizer``,
            ``learning_rate``, ``decay_rate`` and ``epoch``.
    """
    def log_string(msg):
        # Echo every message to both the run log file and stdout.
        logger.info(msg)
        print(msg)

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('reg_seg_heatmap_v3')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    # Fix: the original re-parsed the CLI here (``args = parse_args()``),
    # silently clobbering the caller-supplied namespace that was already
    # consulted for ``args.log_dir`` above; the redundant re-parse is removed.
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    # Construct the dataset
    train_dataset, train_config = construct_dataset(is_train=True)
    # Random 80/20 train/validation split
    train_set_size = int(len(train_dataset) * 0.8)
    valid_set_size = len(train_dataset) - train_set_size
    train_dataset, valid_dataset = torch.utils.data.random_split(
        train_dataset, [train_set_size, valid_set_size])
    # And the dataloaders
    trainDataLoader = DataLoader(dataset=train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=4)
    validDataLoader = DataLoader(dataset=valid_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4)
    '''MODEL LOADING'''
    out_channel = args.out_channel
    model = importlib.import_module(args.model)
    # Snapshot the model/training sources into the experiment directory for
    # reproducibility.
    shutil.copy('./models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_pointnet2_reg_seg_heatmap_stepsize.py', str(exp_dir))

    network = model.get_model(out_channel)
    criterion_rmse = RMSELoss()
    criterion_cos = torch.nn.CosineSimilarity(dim=1)
    criterion_bce = torch.nn.BCELoss()

    network.apply(inplace_relu)

    if not args.use_cpu:
        network = network.cuda()
        criterion_rmse = criterion_rmse.cuda()
        criterion_cos = criterion_cos.cuda()
        # Keep all criteria on the same device (BCELoss holds no parameters,
        # but this matches the handling in the sibling training script).
        criterion_bce = criterion_bce.cuda()
    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        network.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        # Narrowed from a bare ``except``: a missing/corrupt checkpoint means
        # "start from scratch", but KeyboardInterrupt/SystemExit still
        # propagate.
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(network.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    # Decay the learning rate by 0.7 every 20 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_rot_error = 99.9
    best_xyz_error = 99.9
    best_heatmap_error = 99.9
    best_step_size_error = 99.9
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        train_rot_error = []
        train_xyz_error = []
        train_heatmap_error = []
        train_step_size_error = []
        network = network.train()

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            optimizer.zero_grad()

            # Point-cloud augmentation (numpy side): normalize, random point
            # dropout, then random scale and shift on the xyz channels only.
            points = data[parameter.pcd_key].numpy()
            points = provider.normalize_data(points)
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :,
                                                                         0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            heatmap_target = data[parameter.heatmap_key]
            segmentation_target = data[parameter.segmentation_key]
            delta_rot = data[parameter.delta_rot_key]
            delta_xyz = data[parameter.delta_xyz_key]
            unit_delta_xyz = data[parameter.unit_delta_xyz_key]
            step_size = data[parameter.step_size_key]

            if not args.use_cpu:
                points = points.cuda()
                delta_rot = delta_rot.cuda()
                delta_xyz = delta_xyz.cuda()
                heatmap_target = heatmap_target.cuda()
                unit_delta_xyz = unit_delta_xyz.cuda()
                step_size = step_size.cuda()

            heatmap_pred, action_pred, step_size_pred = network(points)
            # Action head layout: first 6 values form an ortho6d rotation,
            # the next 3 the translation direction.
            delta_rot_pred_6d = action_pred[:, 0:6]
            delta_rot_pred = compute_rotation_matrix_from_ortho6d(
                delta_rot_pred_6d, args.use_cpu)  # batch*3*3
            delta_xyz_pred = action_pred[:, 6:9].view(-1, 3)  # batch*3

            # Loss computation. Translation is supervised by direction only
            # (cosine distance to the unit displacement); the magnitude is
            # learned separately via the step-size head.
            loss_heatmap = criterion_rmse(heatmap_pred, heatmap_target)
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)
            loss_t = (1 - criterion_cos(delta_xyz_pred, unit_delta_xyz)).mean()
            loss_step_size = criterion_bce(step_size_pred, step_size)
            loss = loss_r + loss_t + loss_heatmap + loss_step_size
            loss.backward()
            optimizer.step()
            global_step += 1

            train_rot_error.append(loss_r.item())
            train_xyz_error.append(loss_t.item())
            train_heatmap_error.append(loss_heatmap.item())
            train_step_size_error.append(loss_step_size.item())

        # Fix: step the scheduler after the epoch's optimizer updates, per
        # the PyTorch >=1.1 convention (the original stepped it before any
        # optimizer.step(), which skips the initial learning rate).
        scheduler.step()

        train_rot_error = sum(train_rot_error) / len(train_rot_error)
        train_xyz_error = sum(train_xyz_error) / len(train_xyz_error)
        train_heatmap_error = sum(train_heatmap_error) / len(
            train_heatmap_error)
        train_step_size_error = sum(train_step_size_error) / len(
            train_step_size_error)
        log_string('Train Rotation Error: %f' % train_rot_error)
        log_string('Train Translation Error: %f' % train_xyz_error)
        # Fix: the original logged train_xyz_error under this label.
        log_string('Train Heatmap Error: %f' % train_heatmap_error)
        log_string('Train Step size Error: %f' % train_step_size_error)

        with torch.no_grad():
            rot_error, xyz_error, heatmap_error, step_size_error = test(
                network.eval(), validDataLoader, out_channel, criterion_rmse,
                criterion_cos, criterion_bce)

            log_string(
                'Test Rotation Error: %f, Translation Error: %f, Heatmap Error: %f, Step size Error: %f'
                % (rot_error, xyz_error, heatmap_error, step_size_error))
            log_string(
                'Best Rotation Error: %f, Translation Error: %f, Heatmap Error: %f, Step size Error: %f'
                % (best_rot_error, best_xyz_error, best_heatmap_error,
                   best_step_size_error))

            # Checkpoint whenever the summed validation error improves.
            if (rot_error + xyz_error + heatmap_error + step_size_error) < (
                    best_rot_error + best_xyz_error + best_heatmap_error +
                    best_step_size_error):
                best_rot_error = rot_error
                best_xyz_error = xyz_error
                best_heatmap_error = heatmap_error
                best_step_size_error = step_size_error
                best_epoch = epoch + 1
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'rot_error': rot_error,
                    'xyz_error': xyz_error,
                    'heatmap_error': heatmap_error,
                    'step_size_error': step_size_error,
                    'model_state_dict': network.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
def main(args):
    """Train the DSAE (deep spatial autoencoder) relative-pose network.

    Creates the experiment directory tree, configures file logging, loads and
    80/20-splits the training dataset, builds the model and losses, resumes
    from ``checkpoints/best_model.pth`` when present, then runs the
    train/validate loop, checkpointing whenever the summed validation error
    improves.

    Args:
        args: parsed CLI namespace; reads ``log_dir``, ``model``,
            ``batch_size``, ``out_channel``, ``use_cpu``, ``optimizer``,
            ``learning_rate``, ``decay_rate`` and ``epoch``.
    """
    def log_string(msg):
        # Echo every message to both the run log file and stdout.
        logger.info(msg)
        print(msg)

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('dsae')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    # Fix: the original re-parsed the CLI here (``args = parse_args()``),
    # silently clobbering the caller-supplied namespace that was already
    # consulted for ``args.log_dir`` above; the redundant re-parse is removed.
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    # Construct the dataset
    train_dataset, train_config = construct_dataset(is_train=True)
    # Random 80/20 train/validation split
    train_set_size = int(len(train_dataset) * 0.8)
    valid_set_size = len(train_dataset) - train_set_size
    train_dataset, valid_dataset = torch.utils.data.random_split(
        train_dataset, [train_set_size, valid_set_size])
    # And the dataloaders
    trainDataLoader = DataLoader(dataset=train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=4)
    validDataLoader = DataLoader(dataset=valid_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4)
    '''MODEL LOADING'''
    out_channel = args.out_channel
    model = importlib.import_module(args.model)
    # Snapshot the model/training sources into the experiment directory for
    # reproducibility.
    shutil.copy('./models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_dsae.py', str(exp_dir))

    network = model.get_model()
    criterion_rmse = RMSELoss()
    criterion_cos = torch.nn.CosineSimilarity(dim=1)
    criterion_bce = torch.nn.BCELoss()
    criterion_kptof = OFLoss()
    network.apply(inplace_relu)

    if not args.use_cpu:
        network = network.cuda()
        criterion_rmse = criterion_rmse.cuda()
        criterion_cos = criterion_cos.cuda()
        criterion_bce = criterion_bce.cuda()
        criterion_kptof = criterion_kptof.cuda()
    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        network.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        # Narrowed from a bare ``except``: a missing/corrupt checkpoint means
        # "start from scratch", but KeyboardInterrupt/SystemExit still
        # propagate.
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(network.parameters(),
                                    lr=0.01,
                                    momentum=0.9)

    # Halve the learning rate every 10 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=10,
                                                gamma=0.5)
    global_epoch = 0
    global_step = 0
    best_rot_error = 99.9
    best_xyz_error = 99.9
    best_recon_error = 99.9
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        train_xyz_error = []
        train_rot_error = []
        train_recon_error = []
        network = network.train()

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            optimizer.zero_grad()
            rgbd = data[parameter.rgbd_image_key]
            rgb_pair = data[parameter.rgb_pair_key]
            depth_pair = data[parameter.depth_pair_key]
            delta_rot = data[parameter.delta_rot_key]
            delta_xyz = data[parameter.delta_xyz_key]

            if not args.use_cpu:
                delta_rot = delta_rot.cuda()
                delta_xyz = delta_xyz.cuda()
                rgbd = rgbd.cuda()
                rgb_pair = rgb_pair.cuda()
                depth_pair = depth_pair.cuda()
            # rgb_pair (B,6,H,W), depth_pair (B,2,H,W). The forward pass
            # consumes only rgb_pair; ``rgbd`` is kept on-device for parity
            # with the original code, and the dead ``rgb``/``depth`` slices
            # of it were removed.
            delta_xyz_pred, delta_rot_pred, depth_pred = network(rgb_pair)
            # Loss computation: translation uses direction (cosine) plus
            # magnitude (RMSE); depth reconstruction supervises the decoder.
            loss_t = (1 - criterion_cos(delta_xyz_pred, delta_xyz)
                      ).mean() + criterion_rmse(delta_xyz_pred, delta_xyz)
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)
            loss_recon = criterion_rmse(depth_pred, depth_pair)
            loss = loss_t + loss_r + loss_recon
            loss.backward()
            optimizer.step()
            global_step += 1

            train_xyz_error.append(loss_t.item())
            train_rot_error.append(loss_r.item())
            train_recon_error.append(loss_recon.item())

        # Fix: step the scheduler after the epoch's optimizer updates, per
        # the PyTorch >=1.1 convention (the original stepped it before any
        # optimizer.step(), which skips the initial learning rate).
        scheduler.step()

        train_xyz_error = sum(train_xyz_error) / len(train_xyz_error)
        train_rot_error = sum(train_rot_error) / len(train_rot_error)
        train_recon_error = sum(train_recon_error) / len(train_recon_error)

        log_string('Train Translation Error: %f' % train_xyz_error)
        log_string('Train Rotation Error: %f' % train_rot_error)
        log_string('Train Reconstruction Error: %f' % train_recon_error)
        with torch.no_grad():
            xyz_error, rot_error, recon_error = test(
                network.eval(), validDataLoader, out_channel, criterion_rmse,
                criterion_cos, criterion_bce, criterion_kptof)

            log_string(
                'Test Translation Error: %f, Rotation Error: %f, Reconstruction Error: %f'
                % (xyz_error, rot_error, recon_error))
            log_string(
                'Best Translation Error: %f, Rotation Error: %f, Reconstruction Error: %f'
                % (best_xyz_error, best_rot_error, best_recon_error))

            # Checkpoint whenever the summed validation error improves.
            if (xyz_error + rot_error + recon_error) < (
                    best_xyz_error + best_rot_error + best_recon_error):
                best_xyz_error = xyz_error
                best_rot_error = rot_error
                best_recon_error = recon_error
                best_epoch = epoch + 1
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model_e_' + str(
                    best_epoch) + '.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    'xyz_error': xyz_error,
                    'rot_error': rot_error,
                    'recon_error': recon_error,
                    'model_state_dict': network.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')