Example #1
def main():
    args = parse_args()
    logging.info(args)
    demo = BaseDemo(args)
    if args.train:
        demo.train()
    if args.test:
        demo.test()
    if args.test_gt:
        demo.test_gt()
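Example #1 above and Example #2 below both rely on a module-level parse_args() helper that is not shown. A minimal sketch of what they appear to assume, built on the standard argparse module, might look like this (flag names are taken from the snippets; defaults and help strings are assumptions):

import argparse

def parse_args():
    # Minimal argparse helper matching the boolean flags read in Examples #1
    # and #2. Flag names come from the snippets; everything else is assumed.
    parser = argparse.ArgumentParser(description='demo driver')
    parser.add_argument('--train', action='store_true', help='run training')
    parser.add_argument('--test', action='store_true', help='run evaluation')
    parser.add_argument('--test_gt', action='store_true',
                        help='evaluate against ground-truth annotations')
    return parser.parse_args()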
Example #2
def main():
    args = parse_args()
    logging.info(args)
    demo = Demo(args)
    if args.train:
        demo.train_unsupervised()
    if args.test:
        demo.test_unsupervised()
    if args.test_gt:
        demo.test_gt_unsupervised()
Example #3
def main():
    logging.info('----------------------------------------------------------------')
    logging.info('****************************************************************')
    args = parse_args()
    logging.info(args)

    if args.data == 'kitti':
        data = KittiData(args.data_path, args.train_proportion, args.test_proportion)
        train_meta = data.train_meta
        train_data = KittiDataLoader(train_meta, args.batch_size, args.image_heights,
                                     args.image_widths, args.output_heights, args.output_widths,
                                     args.num_scale, data_augment=False, shuffle=False)
        test_meta = data.test_meta
        test_data = KittiDataLoader(test_meta, args.batch_size, args.image_heights,
                                    args.image_widths, args.output_heights, args.output_widths,
                                    args.num_scale)
    else:
        print('Data not implemented yet')
        return

    if args.model == 'base':
        model = BaseNet(args.image_channel, args.num_class)
    elif args.model == 'base_3d':
        model = Base3DNet(args.image_channel, args.depth_channel, args.num_class)
    elif args.model == 'base_2stream':
        model = Base2StreamNet(args.image_channel, args.depth_channel, args.num_class)
    elif args.model == 'seg_3d':
        model = Seg3DNet(args.image_channel, args.depth_channel, args.num_class)
    else:
        print('Model not implemented yet')
        return

    interface = DetectInterface(data, train_data, test_data, model, args.learning_rate,
                                args.train_epoch, args.test_interval, args.test_iteration,
                                args.save_interval, args.init_model_path, args.save_model_path,
                                args.tensorboard_path)

    if args.train:
        logging.info('Experiment: %s, training', args.exp_name)
        interface.train()
    elif args.test:
        logging.info('Experiment: %s, testing all', args.exp_name)
        interface.test_all()
    elif args.visualize:
        logging.info('Experiment: %s, visualizing', args.exp_name)
        interface.visualize(args.image_name, args.depth_name, args.flow_name, args.box_name,
                            args.figure_path)
    elif args.visualize_all:
        logging.info('Experiment: %s, visualizing all', args.exp_name)
        interface.visualize_all(args.image_list, args.figure_path)
    else:
        print('Unknown command')
        return
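A side note on the run-mode flags: Example #3 (and Example #4 below) treats --train, --test, --visualize and --visualize_all as mutually exclusive commands. If the project's parse_args() is built on argparse, that constraint can be enforced at parse time; the following is only a sketch, not the repository's actual parser:

import argparse

def parse_mode_args():
    # Hypothetical variant of parse_args(): the run modes used in Examples #3
    # and #4 become a required, mutually exclusive group.
    parser = argparse.ArgumentParser()
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument('--train', action='store_true')
    mode.add_argument('--test', action='store_true')
    mode.add_argument('--visualize', action='store_true')
    mode.add_argument('--visualize_all', action='store_true')
    return parser.parse_args()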
Example #4
def main():
    logging.info(
        '----------------------------------------------------------------')
    logging.info(
        '****************************************************************')
    args = parse_args()
    logging.info(args)

    if args.data == 'vdrift':
        data = VDriftData(args.data_path, args.batch_size, args.image_heights,
                          args.image_widths, args.output_height,
                          args.output_width, args.num_scale,
                          args.train_proportion, args.test_proportion,
                          args.show_statistics)
    else:
        print('Data not implemented yet')
        return

    if args.model == 'base':
        model = BaseNet(args.image_channel, args.num_class)
    elif args.model == 'base_3d':
        model = Base3DNet(args.image_channel, args.depth_channel,
                          args.num_class)
    elif args.model == 'base_2stream':
        model = Base2StreamNet(args.image_channel, args.depth_channel,
                               args.num_class)
    else:
        print('Model not implemented yet')
        return

    interface = SegmentInterface(data, model, args.learning_rate,
                                 args.train_iteration, args.test_iteration,
                                 args.test_interval, args.save_interval,
                                 args.init_model_path, args.save_model_path,
                                 args.tensorboard_path)

    if args.train:
        logging.info('Experiment: %s, training', args.exp_name)
        interface.train()
    elif args.test:
        logging.info('Experiment: %s, testing all', args.exp_name)
        interface.test_all()
    elif args.visualize:
        logging.info('Experiment: %s, visualizing', args.exp_name)
        interface.visualize(args.image_name, args.depth_name, args.flow_x_name,
                            args.flow_y_name, args.seg_name, args.figure_path)
    elif args.visualize_all:
        logging.info('Experiment: %s, visualizing all', args.exp_name)
        interface.visualize_all(args.image_list, args.figure_path)
    else:
        print('Unknown command')
        return
Example #5
def main():
    args = parse_args()
    logging.info(args)
    if args.data == 'mlt':
        data = MLTData(args.batch_size, args.image_size, args.direction_type,
                       args.train_proportion, args.test_proportion,
                       args.show_statistics)
    elif args.data == 'viper':
        print('Not Implemented Yet')
        return
    else:
        # Unknown dataset: bail out so `data` is never used uninitialized.
        print('Data not implemented yet')
        return

    data_test = DataTest(data)
    data_test.test()
Example #6
def main():
    args = parse_args()
    logging.info(args)

    if args.data == 'kitti':
        data = KittiData(args.data_path, args.train_proportion,
                         args.test_proportion)
        train_meta = data.train_meta
        train_data = KittiDataLoader(train_meta,
                                     args.batch_size,
                                     args.image_heights,
                                     args.image_widths,
                                     args.output_heights,
                                     args.output_widths,
                                     args.num_scale,
                                     data_augment=True,
                                     shuffle=True)
        data_test = DataTest(train_data)
        data_test.test()

        test_meta = data.test_meta
        test_data = KittiDataLoader(test_meta, args.batch_size,
                                    args.image_heights, args.image_widths,
                                    args.output_heights, args.output_widths,
                                    args.num_scale)
        data_test = DataTest(test_data)
        data_test.test()

        meta = {
            'image': [args.image_name],
            'depth': [args.depth_name],
            'flow': [args.flow_name],
            'box': [read_box(args.box_name)]
        }
        data = KittiDataLoader(meta, args.batch_size, args.image_heights,
                               args.image_widths, args.output_heights,
                               args.output_widths, args.num_scale)
        data_test = DataTest(data)
        data_test.test()
    else:
        print('Not Implemented Yet')
        return
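The single-sample meta dictionary at the end of Example #6 is a handy pattern for debugging one frame. A small helper that wraps it could look like the sketch below; it reuses only names that already appear in the snippet (KittiDataLoader, read_box), so treat it as an illustration rather than project code:

def make_single_sample_loader(args):
    # Wrap a single image/depth/flow/box tuple in the meta format that
    # KittiDataLoader consumes in Example #6 (illustrative helper only).
    meta = {
        'image': [args.image_name],
        'depth': [args.depth_name],
        'flow': [args.flow_name],
        'box': [read_box(args.box_name)]
    }
    return KittiDataLoader(meta, args.batch_size, args.image_heights,
                           args.image_widths, args.output_heights,
                           args.output_widths, args.num_scale)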
Example #7
def main():
    args = parse_args()
    logging.info(args)
    # args.data_path = '/media/yi/DATA/data-orig/kitti/training'
    # args.image_name = '/media/yi/DATA/data-orig/kitti/training/image_2/007480.png'
    # args.depth_name = '/media/yi/DATA/data-orig/kitti/training/disp_unsup/007480.png'
    # args.flow_name = '/media/yi/DATA/data-orig/kitti/training/flow_unsup/007480.png'
    # args.box_name = '/media/yi/DATA/data-orig/kitti/training/label_2/007480.txt'

    if args.data == 'kitti':
        data = KittiData(args.data_path, args.batch_size, args.image_heights, args.image_widths,
                         args.output_heights, args.output_widths, args.num_scale,
                         args.train_proportion, args.test_proportion, args.show_statistics)
    else:
        print('Not Implemented Yet')
        return

    data_test = DataTest(data)
    data_test.test()
    data_test.test_one_image(args.image_name, args.depth_name, args.flow_name, args.box_name)
Example #8
def unit_test():
    args = learning_args.parse_args()
    logging.info(args)
    data = MpiiData(args)
    im = data.get_next_batch(data.train_images)
    data.display(im)
Example #9
def main():
    logging.info('----------------------------------------------------------------')
    logging.info('****************************************************************')
    args = parse_args()
    logging.info(args)

    if args.data == 'mlt':
        data = MLTData(args.batch_size, args.image_size, args.direction_type, args.train_proportion,
                       args.test_proportion, args.show_statistics)
    elif args.data == 'viper':
        print('Not Implemented Yet')
        return
    else:
        # Unknown dataset: bail out so `data` is never used uninitialized.
        print('Data not implemented yet')
        return

    if args.model == 'base':
        model = BaseNet(args.image_channel, args.num_class)
    elif args.model == 'base_direct':
        model = BaseDirectNet(args.image_channel, args.direction_dim, args.num_class)
    elif args.model == 'base_3d':
        model = Base3DNet(args.image_channel, args.depth_channel, args.num_class)
    elif args.model == 'base_direct_3d':
        model = BaseDirect3DNet(args.image_channel, args.depth_channel, args.direction_dim, args.num_class)
    elif args.model == 'base_2stream':
        model = Base2StreamNet(args.image_channel, args.depth_channel, args.num_class)
    elif args.model == 'base_direct_2stream':
        model = BaseDirect2StreamNet(args.image_channel, args.depth_channel, args.direction_dim, args.num_class)
    elif args.model == 'hard_gt_attn':
        model = HardGtAttnNet(args.image_size[0], args.image_channel, args.num_class)
    elif args.model == 'hard_direct':
        model = HardDirectNet(args.image_size[0], args.image_channel, args.direction_dim, args.num_class)
    elif args.model == 'hard_gt_attn_3d':
        model = HardGtAttn3DNet(args.image_size[0], args.image_channel, args.depth_channel, args.num_class)
    elif args.model == 'hard_gt_attn_2stream':
        model = HardGtAttn2StreamNet(args.image_size[0], args.image_channel, args.depth_channel, args.num_class)
    elif args.model == 'soft_attn':
        model = SoftAttnNet(args.attention_size, args.image_channel, args.num_class)
    elif args.model == 'soft_direct':
        model = SoftDirectNet(args.attention_size, args.image_channel, args.direction_dim, args.num_class)
    elif args.model == 'soft_comb':
        model = SoftCombNet(args.attention_size, args.image_channel, args.direction_dim, args.num_class)
    elif args.model == 'soft_attn_3d':
        model = SoftAttn3DNet(args.attention_size, args.image_channel, args.depth_channel, args.num_class)
    elif args.model == 'soft_direct_3d':
        model = SoftDirect3DNet(args.attention_size, args.image_channel, args.depth_channel, args.direction_dim, args.num_class)
    elif args.model == 'soft_comb_3d':
        model = SoftComb3DNet(args.attention_size, args.image_channel, args.depth_channel, args.direction_dim, args.num_class)
    elif args.model == 'soft_attn_2stream':
        model = SoftAttn2StreamNet(args.attention_size, args.image_channel, args.depth_channel, args.num_class)
    elif args.model == 'soft_direct_2stream':
        model = SoftDirect2StreamNet(args.attention_size, args.image_channel, args.depth_channel, args.direction_dim, args.num_class)
    elif args.model == 'soft_comb_2stream':
        model = SoftComb2StreamNet(args.attention_size, args.image_channel, args.depth_channel, args.direction_dim, args.num_class)
    else:
        # Unknown model name: bail out so `model` is never used uninitialized.
        print('Model not implemented yet')
        return

    if args.attention_type == 'soft':
        interface = SoftAttnInterface(data, model, args.learning_rate, args.train_iteration, args.test_iteration,
                                      args.test_interval, args.save_interval, args.init_model_path,
                                      args.save_model_path, args.tensorboard_path)
    elif args.attention_type == 'hard':
        interface = HardAttnInterface(data, model, args.learning_rate, args.train_iteration, args.test_iteration,
                                      args.test_interval, args.save_interval, args.init_model_path,
                                      args.save_model_path, args.tensorboard_path)
    else:
        # Unknown attention type: bail out so `interface` is never used uninitialized.
        print('Attention type not implemented yet')
        return

    if args.train:
        logging.info('Experiment: %s, training', args.exp_name)
        interface.train()
    elif args.test:
        logging.info('Experiment: %s, testing all', args.exp_name)
        interface.test_all()
    elif args.visualize:
        logging.info('Experiment: %s, visualizing', args.exp_name)
        interface.visualize(args.image_name, args.depth_name, args.box_name, args.figure_path)
    elif args.visualize_all:
        logging.info('Experiment: %s, visualizing all', args.exp_name)
        interface.visualize_all(args.image_list, args.figure_path)
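The twenty-branch model selection in Example #9 can also be expressed as a dictionary dispatch, which keeps main() short and makes the unknown-model case explicit. The sketch below covers only a few of the entries and assumes the constructor signatures shown above; it is an alternative illustration, not the repository's code:

MODEL_REGISTRY = {
    # Constructor arguments copied from the corresponding branches in Example #9.
    'base': lambda a: BaseNet(a.image_channel, a.num_class),
    'base_3d': lambda a: Base3DNet(a.image_channel, a.depth_channel, a.num_class),
    'soft_attn': lambda a: SoftAttnNet(a.attention_size, a.image_channel, a.num_class),
    'soft_comb_2stream': lambda a: SoftComb2StreamNet(
        a.attention_size, a.image_channel, a.depth_channel,
        a.direction_dim, a.num_class),
}

def build_model(args):
    # Raise a clear error instead of falling through with `model` undefined.
    try:
        return MODEL_REGISTRY[args.model](args)
    except KeyError:
        raise ValueError('Model not implemented yet: %s' % args.model)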
Example #10
def unit_test():
    args = learning_args.parse_args()
    logging.info(args)
    data = BoxDataComplex(args)
    im, motion, motion_r, motion_label, motion_label_r, seg_layer = data.get_next_batch()
    data.display(im, motion, motion_r, seg_layer)
Example #11
def unit_test():
    args = learning_args.parse_args()
    logging.info(args)
    data = Kitti128Sample(args)
    im = data.get_next_batch(data.test_images)
    data.display(im)
Example #12
def unit_test():
    args = learning_args.parse_args()
    logging.info(args)
    data = ChairData(args)
    im, motion, motion_label, seg_layer = data.get_next_batch(data.train_images)
    data.display(im, motion, seg_layer)
Example #13
def main():
    args = parse_args()
    logging.info(args)
    demo = BaseDemo(args)
    demo.compare()
Example #14
def unit_test():
    args = learning_args.parse_args()
    logging.info(args)
    data = Robot64Data(args)
    im = data.get_next_batch(data.train_meta)
    data.display(im)
Example #15
def unit_test():
    args = learning_args.parse_args()
    logging.info(args)
    data = MnistDataBidirect(args)
    im, motion, motion_r, motion_label, motion_label_r, seg_layer = data.get_next_batch(data.train_images)
    data.display(im, motion, motion_r, seg_layer)
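Finally, none of the snippets show the script entry point or the logging setup they rely on; every example calls logging.info() before any handler is configured. A typical, assumed wrapper for drivers like these would be:

import logging

if __name__ == '__main__':
    # Assumed entry-point boilerplate: configure the root logger so the
    # logging.info() calls in the examples are actually printed, then run.
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(message)s')
    main()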