def func_wrapper(*args, **kwargs):
    try:
        return func(*args, **kwargs)
    except:
        # On any error, clean up the data worker processes before re-raising.
        print('Waiting for the data processes to end')
        kill_processes(train_queue, train_processes)
        kill_processes(validation_queue, val_processes)
        raise
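The wrapper above reads like the inner function of a cleanup decorator whose outer def was cut off. Below is a minimal sketch of what such an enclosing decorator could look like; the name cleanup_on_error, the use of functools.wraps, and the module-level queue/process globals are assumptions for illustration, not the original code.

import functools

def cleanup_on_error(func):
    """Hypothetical decorator: kill the data worker processes if `func` raises."""
    @functools.wraps(func)
    def func_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:
            print('Waiting for the data processes to end')
            kill_processes(train_queue, train_processes)
            kill_processes(validation_queue, val_processes)
            raise
    return func_wrapper

# Usage sketch: decorate the training entry point so workers are cleaned up on failure.
# @cleanup_on_error
# def train_net():
#     ...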
Example #2
def train_net():
    '''Main training function'''
    # Set up the model and the solver
    NetClass = load_model(cfg.CONST.NETWORK_CLASS)

    # print('Network definition: \n')
    # print(inspect.getsource(NetClass.network_definition))
    net = NetClass()

    # Check that single view reconstruction net is not used for multi view
    # reconstruction.
    if net.is_x_tensor4 and cfg.CONST.N_VIEWS > 1:
        raise ValueError('Do not set config.CONST.N_VIEWS > 1 when using a '
                         'single-view reconstruction network')

    # Prefetching data processes
    #
    # Create worker and data queue for data processing. For training data, use
    # multiple processes to speed up the loading. For validation data, use 1
    # since the queue will be popped every TRAIN.NUM_VALIDATION_ITERATIONS.
    global train_queue, val_queue, train_processes, val_processes
    train_queue = Queue(cfg.QUEUE_SIZE)
    val_queue = Queue(cfg.QUEUE_SIZE)

    train_processes = make_data_processes(
        train_queue,
        category_model_id_pair(dataset_portion=cfg.TRAIN.DATASET_PORTION),
        cfg.TRAIN.NUM_WORKER,
        repeat=True)
    val_processes = make_data_processes(
        val_queue,
        category_model_id_pair(dataset_portion=cfg.TEST.DATASET_PORTION),
        1,
        repeat=True,
        train=False)

    import torch.cuda
    if torch.cuda.is_available():
        net.cuda()

    # print the queue
    # print(train_queue)
    # print(val_queue)

    # Generate the solver
    solver = Solver(net)

    # Train the network
    solver.train(train_queue, val_queue)

    # Cleanup the processes and the queue.
    kill_processes(train_queue, train_processes)
    kill_processes(val_queue, val_processes)
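Every example in this listing relies on make_data_processes and kill_processes, which are not defined here, and whose signatures differ between examples (the later ones also pass a data-process class and an opts object). The snippet below is a minimal, self-contained sketch of the form used in Example #2, built only on the standard multiprocessing module; the DataProcess class, its shutdown Event, and the exact behavior are assumptions for illustration, not the library's actual implementation.

import time
from multiprocessing import Event, Process
from queue import Empty


class DataProcess(Process):
    """Toy worker: pushes items from `data` into `queue` until told to stop."""

    def __init__(self, queue, data, repeat=True):
        super().__init__()
        self.queue = queue
        self.data = data
        self.repeat = repeat
        self.exit = Event()

    def run(self):
        while not self.exit.is_set():
            for item in self.data:
                if self.exit.is_set():
                    return
                self.queue.put(item)
            if not self.repeat:
                return

    def shutdown(self):
        self.exit.set()


def make_data_processes(queue, data, num_workers, repeat=True, train=True):
    """Start `num_workers` workers feeding `queue` (`train` is unused in this toy version)."""
    processes = []
    for _ in range(num_workers):
        process = DataProcess(queue, data, repeat=repeat)
        process.start()
        processes.append(process)
    return processes


def kill_processes(queue, processes):
    """Signal the workers to stop, drain the queue so blocked put() calls can return, then join."""
    for process in processes:
        process.shutdown()
    # Keep draining until every worker has noticed the flag and exited.
    while any(p.is_alive() for p in processes):
        try:
            while True:
                queue.get_nowait()
        except Empty:
            pass
        time.sleep(0.1)
    for process in processes:
        process.join()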
Example #3
def test_lba_process():
    from multiprocessing import Queue
    from lib.config import cfg
    from lib.utils import open_pickle, print_sentences, get_json_path

    cfg.CONST.BATCH_SIZE = 8
    cfg.CONST.DATASET = 'shapenet'
    cfg.CONST.SYNTH_EMBEDDING = False

    caption_data = open_pickle(cfg.DIR.VAL_DATA_PATH)
    data_queue = Queue(3)
    json_path = get_json_path()

    data_process = LBADataProcess(data_queue, caption_data, repeat=True)
    data_process.start()
    caption_batch = data_queue.get()
    category_list = caption_batch['category_list']
    model_list = caption_batch['model_list']

    for k, v in caption_batch.items():
        if isinstance(v, list):
            print('Key:', k)
            print('Value length:', len(v))
        elif isinstance(v, np.ndarray):
            print('Key:', k)
            print('Value shape:', v.shape)
        else:
            print('Other:', k)
    print('')

    for i in range(len(category_list)):
        print('---------- %03d ------------' % i)
        category = category_list[i]
        model_id = model_list[i]

        # Generate sentence
        for j in range(data_process.n_captions_per_model):
            caption_idx = data_process.n_captions_per_model * i + j
            caption = caption_batch['raw_embedding_batch'][caption_idx]
            # print('Caption:', caption)
            # print('Converted caption:')
            data_list = [{'raw_caption_embedding': caption}]
            print_sentences(json_path, data_list)
            print('Label:', caption_batch['caption_label_batch'][caption_idx])

        print('Category:', category)
        print('Model ID:', model_id)

    kill_processes(data_queue, [data_process])
Example #4
def train_net():

    # Set up the model and the solver
    my_net = My_ResidualGRUNet()

    # Generate the solver
    solver = Solver(my_net)

    # Load the global variables
    global train_queue, validation_queue, train_processes, val_processes

    # Initialize the queues
    train_queue = Queue(
        15)  # maximum number of minibatches that can be put in a data queue
    validation_queue = Queue(15)

    # Train on 80 percent of the data
    train_dataset_portion = [0, 0.8]

    # Validate on 20 percent of the data
    test_dataset_portion = [0.8, 1]

    # Establish the training processes
    train_processes = make_data_processes(
        train_queue,
        category_model_id_pair(dataset_portion=train_dataset_portion),
        1,
        repeat=True)

    # Establish the validation processes
    val_processes = make_data_processes(
        validation_queue,
        category_model_id_pair(dataset_portion=test_dataset_portion),
        1,
        repeat=True,
        train=False)

    # Train the network
    solver.train(train_queue, validation_queue)

    # Cleanup the processes and the queue.
    kill_processes(train_queue, train_processes)
    kill_processes(validation_queue, val_processes)
Example #5
def test_caption_process():
    from multiprocessing import Queue
    from lib.config import cfg
    from lib.utils import open_pickle, print_sentences

    cfg.CONST.DATASET = 'primitives'
    cfg.CONST.SYNTH_EMBEDDING = False

    asdf_captions = open_pickle(cfg.DIR.PRIMITIVES_VAL_DATA_PATH)

    data_queue = Queue(3)

    data_process = CaptionDataProcess(data_queue, asdf_captions, repeat=True)
    data_process.start()
    caption_batch = data_queue.get()
    captions_tensor, category_list, model_list = caption_batch

    assert captions_tensor.shape[0] == len(category_list)
    assert len(category_list) == len(model_list)

    for i in range(len(category_list)):
        print('---------- %03d ------------' % i)
        caption = captions_tensor[i]
        category = category_list[i]
        model_id = model_list[i]
        # print('Caption:', caption)
        # print('Converted caption:')

        # Generate sentence
        # data_list = [{'raw_caption_embedding': caption}]
        # print_sentences(json_path, data_list)

        print('Category:', category)
        # print('Model ID:', model_id)

    kill_processes(data_queue, [data_process])
Example #6
def main():
    global opts
    opts.voxel_size = 32 # 4 x 32 x 32 x 32 

    ###########################################################
    # Dataset and Make Data Loading Process 
    ###########################################################
    inputs_dict, test_inputs_dict = get_inputs_dict(opts)

    # map category to label ('box-teal-h20-r20' -> 755)
    opts.category2label_dict = inputs_dict['class_labels']
    assert inputs_dict['vocab_size'] == test_inputs_dict['vocab_size']
    opts.vocab_size = inputs_dict['vocab_size']

    # Prefetching data processes 
    # Create worker and data queue for data processing. For training data, use
    # multiple processes to speed up the loading. For validation data, use 1
    # since the queue will be popped every TRAIN.NUM_VALIDATION_ITERATIONS.
    # set up data queue and start enqueue
    np.random.seed(123) 
    test_data_process_for_class = models.get_data_process_pairs('LBA1', opts, is_training=False) 
    
    global test_queue, test_processes 
    test_queue = Queue(opts.queue_capacity) 
    opts.num_workers = 1 
    test_processes = make_data_processes(test_data_process_for_class, test_queue, test_inputs_dict, opts, repeat=False) 

    ###########################################################
    ## build network, loading pretrained model, shift to GPU 
    ###########################################################
    print('-------------building network--------------')
    network_class = models.load_model('LBA1')
    text_encoder = network_class['Text_Encoder'](opts.vocab_size, embedding_size=128, encoder_output_normalized=False) 
    shape_encoder = network_class['Shape_Encoder'](num_classes=opts.num_classes, encoder_output_normalized=False) 

    print('text encoder: ')
    print(text_encoder)
    print('shape encoder: ')
    print(shape_encoder)

    print('loading checkpoints....')
    if opts.pretrained_model != '':
        print('loading pretrained model from {0}'.format(opts.pretrained_model))
        checkpoint = torch.load(opts.pretrained_model)
        text_encoder.load_state_dict(checkpoint['text_encoder'])
        shape_encoder.load_state_dict(checkpoint['shape_encoder'])
    else: 
        raise ValueError('Please provide the path to a pretrained model.')

    ###########################################################
    ## Training Criterion 
    ###########################################################
    criterion = {}
    opts.LBA_NO_LBA = False 
    if opts.LBA_NO_LBA is False: 
        # by default, we set the visit loss weight to 0.25
        LBA_loss = loss.LBA_Loss(lmbda=0.25, LBA_model_type=opts.LBA_model_type, batch_size=opts.batch_size)
        criterion['LBA_loss'] = LBA_loss

    # classification loss
    opts.LBA_Classificaiton = False 
    if opts.LBA_Classificaiton is True: 
        pass 

    # metric loss
    opts.LBA_Metric = True 
    if opts.LBA_Metric is True: 
        opts.rho = 1.0 # set opts.rho to be 1.0 for combining LBA_loss and Metric loss 
        Metric_loss = loss.Metric_Loss(opts, LBA_inverted_loss=True, LBA_normalized=False, LBA_max_norm=10.0)
        criterion['Metric_Loss'] = Metric_loss

    ## shift models to cuda 
    if opts.cuda:  
        print('shift model and criterion to GPU .. ')
        text_encoder = text_encoder.cuda() 
        shape_encoder = shape_encoder.cuda() 
        if opts.ngpu > 1:
            text_encoder = nn.DataParallel(text_encoder, device_ids=range(opts.ngpu)) 
            shape_encoder = nn.DataParallel(shape_encoder, device_ids=range(opts.ngpu)) 

        for crit in criterion.values():
            crit.cuda()  # nn.Module.cuda() moves parameters in place
    

    ###########################################################
    ## Now we begin to test 
    ###########################################################         
    print('evaluation...')
    pr_at_k = test(test_processes[0], test_queue, text_encoder, shape_encoder, criterion, opts)

    #################################################################################
    # Finally, we kill all the processes 
    #################################################################################
    kill_processes(test_queue, test_processes)
Example #7
def main():
    """Main text2voxel function.
    """
    args = parse_args()

    print('Called with args:')
    print(args)

    if args.save_outputs is True and args.test is False:
        raise ValueError('Can only save outputs when testing, not training.')
    if args.validation:
        assert not args.test
    if args.test:
        assert args.ckpt_path is not None

    modify_args(args)

    print('----------------- CONFIG -------------------')
    pprint.pprint(cfg)

    # Save yaml
    os.makedirs(cfg.DIR.LOG_PATH, exist_ok=True)
    with open(os.path.join(cfg.DIR.LOG_PATH, 'run_cfg.yaml'), 'w') as out_yaml:
        yaml.dump(cfg, out_yaml, default_flow_style=False)

    # set up logger
    tf.logging.set_verbosity(tf.logging.INFO)

    try:
        with tf.Graph().as_default() as g:  # create graph
            # Load data
            inputs_dict, val_inputs_dict = get_inputs_dict(args)

            # Build network
            is_training = not args.test
            print('------------ BUILDING NETWORK -------------')
            network_class = models.load_model(cfg.NETWORK)
            net = network_class(inputs_dict, is_training)

            # Prefetching data processes
            #
            # Create worker and data queue for data processing. For training data, use
            # multiple processes to speed up the loading. For validation data, use 1
            # since the queue will be popped every TRAIN.NUM_VALIDATION_ITERATIONS.
            # set up data queue and start enqueue
            np.random.seed(123)
            data_process_class = models.get_data_process_pairs(
                cfg.NETWORK, is_training)
            val_data_process_class = models.get_data_process_pairs(
                cfg.NETWORK, is_training=False)
            if is_training:
                global train_queue, train_processes
                train_queue = Queue(cfg.CONST.QUEUE_CAPACITY)
                train_processes = make_data_processes(data_process_class,
                                                      train_queue,
                                                      inputs_dict,
                                                      cfg.CONST.NUM_WORKERS,
                                                      repeat=True)
                if args.validation:
                    global val_queue, val_processes
                    val_queue = Queue(cfg.CONST.QUEUE_CAPACITY)
                    val_processes = make_data_processes(val_data_process_class,
                                                        val_queue,
                                                        val_inputs_dict,
                                                        1,
                                                        repeat=True)
            else:
                global test_queue, test_processes
                test_inputs_dict = val_inputs_dict
                test_queue = Queue(cfg.CONST.QUEUE_CAPACITY)
                test_processes = make_data_processes(val_data_process_class,
                                                     test_queue,
                                                     test_inputs_dict,
                                                     1,
                                                     repeat=False)

            # Create solver
            solver = get_solver(g, net, args, is_training)

            # Run solver
            if is_training:
                if args.validation:
                    if cfg.DIR.VAL_CKPT_PATH is not None:
                        assert train_processes[0].iters_per_epoch != 0
                        assert val_processes[0].iters_per_epoch != 0
                        solver.train(train_processes[0].iters_per_epoch,
                                     train_queue,
                                     val_processes[0].iters_per_epoch,
                                     val_queue=val_queue,
                                     val_inputs_dict=val_inputs_dict)
                    else:
                        if isinstance(net, LBA):
                            assert cfg.LBA.TEST_MODE is not None
                            assert cfg.LBA.TEST_MODE == 'shape'
                            assert train_processes[0].iters_per_epoch != 0
                            assert val_processes[0].iters_per_epoch != 0
                            solver.train(train_processes[0].iters_per_epoch,
                                         train_queue,
                                         val_processes[0].iters_per_epoch,
                                         val_queue=val_queue,
                                         val_inputs_dict=val_inputs_dict)
                        else:
                            assert train_processes[0].iters_per_epoch != 0
                            assert val_processes[0].iters_per_epoch != 0
                            solver.train(train_processes[0].iters_per_epoch,
                                         train_queue,
                                         val_processes[0].iters_per_epoch,
                                         val_queue=val_queue)
                else:
                    solver.train(train_processes[0].iters_per_epoch,
                                 train_queue)
            else:
                solver.test(test_processes[0],
                            test_queue,
                            num_minibatches=cfg.CONST.N_MINIBATCH_TEST,
                            save_outputs=args.save_outputs)
    finally:
        # Clean up the processes and queues
        if is_training:
            kill_processes(train_queue, train_processes)
            if args.validation:
                kill_processes(val_queue, val_processes)
        else:
            kill_processes(test_queue, test_processes)
Example #8
def test_lba_process():
    from multiprocessing import Queue
    from lib.utils import print_sentences
    parser = argparse.ArgumentParser(description='test data process')
    parser.add_argument('--dataset',
                        dest='dataset',
                        help='dataset',
                        default='shapenet',
                        type=str)
    opts = parser.parse_args()
    opts.batch_size = 8
    opts.LBA_n_captions_per_model = 5
    opts.synth_embedding = False
    opts.probablematic_nrrd_path = '/home/hxw/project_work_on/shape_research/datasets/text2shape/shapenet_dataset/shapenet_info/problematic_nrrds_shapenet_unverified_256_filtered_div_with_err_textures.p'
    opts.LBA_model_type = 'STS'
    opts.val_data_path = '/home/hxw/project_work_on/shape_research/datasets/text2shape/shapenet_dataset/shapenet_info/processed_captions_val.p'
    opts.data_dir = '/home/hxw/project_work_on/shape_research/datasets/text2shape/shapenet_dataset/nrrd_256_filter_div_32_solid/%s/%s.nrrd'

    caption_data = open_pickle(opts.val_data_path)
    data_queue = Queue(3)  # 3 is the queue capacity; once full, put() blocks until items are consumed
    json_path = '/home/hxw/project_work_on/shape_research/datasets/text2shape/shapenet_dataset/shapenet_info/shapenet.json'

    pdb.set_trace()
    data_process = LBADataProcess(data_queue, caption_data, opts, repeat=True)
    data_process.start()
    caption_batch = data_queue.get()
    category_list = caption_batch['category_list']
    model_list = caption_batch['model_list']

    for k, v in caption_batch.items():
        if isinstance(v, list):
            print('key: ', k)
            print('value length: ', len(v))
        elif isinstance(v, np.ndarray):
            print('key: ', k)
            print('Value shape: ', v.shape)
        else:
            print('Other: ', k)
    print('')
    pdb.set_trace()
    """
    for i in range(len(category_list)):
        print('-------%03d------'%i)
        category = category_list[i] 
        model_id = model_list[i] 

        # generate sentence
        for j in range(data_process.n_captions_per_model):
            caption_idx = data_process.n_captions_per_model * i + j 
            caption = caption_batch['raw_embedding_batch'][caption_idx] 

            # print('caption:', caption)
            # print('converted caption: ')
            data_list = [{'raw_caption_embedding': caption}] 
            print_sentences(json_path, data_list)
            print('label: ', caption_batch['caption_label_batch'][caption_idx].item()) 

        print('category: ', category) 
        print('model id: ', model_id) 
    """
    pdb.set_trace()

    kill_processes(data_queue, [data_process])
Example #9
def main():
    global opts
    opts.max_epochs = 1000
    opts.voxel_size = 32  # 4 x 32 x 32 x 32

    ###########################################################
    # Dataset and Make Data Loading Process
    ###########################################################
    inputs_dict, val_inputs_dict = get_inputs_dict(opts)

    # map category to label ('box-teal-h20-r20' -> 755)
    opts.category2label_dict = inputs_dict['class_labels']
    assert inputs_dict['vocab_size'] == val_inputs_dict['vocab_size']
    opts.vocab_size = inputs_dict['vocab_size']

    # Prefetching data processes
    # Create worker and data queue for data processing. For training data, use
    # multiple processes to speed up the loading. For validation data, use 1
    # since the queue will be popped every TRAIN.NUM_VALIDATION_ITERATIONS.
    # set up data queue and start enqueue
    np.random.seed(123)
    data_process_for_class = models.get_data_process_pairs('LBA1',
                                                           opts,
                                                           is_training=True)
    val_data_process_for_class = models.get_data_process_pairs(
        'LBA1', opts, is_training=False)

    is_training = True
    if is_training:
        global train_queue, train_processes
        global val_queue, val_processes
        train_queue = Queue(opts.queue_capacity)
        train_processes = make_data_processes(data_process_for_class,
                                              train_queue,
                                              inputs_dict,
                                              opts,
                                              repeat=True)
        # set number of iterations for training
        opts.train_iters_per_epoch = train_processes[0].iters_per_epoch

        val_queue = Queue(opts.queue_capacity)
        # now we set number of workers to be 1
        opts.num_workers = 1
        val_processes = make_data_processes(val_data_process_for_class,
                                            val_queue,
                                            val_inputs_dict,
                                            opts,
                                            repeat=True)
        # set number of iterations for validation
        opts.val_iters_per_epoch = val_processes[0].iters_per_epoch

        #########################################################
        ## minibatch generator for the val/test phase for TEXT only.
        #########################################################
        opts.val_inputs_dict = val_inputs_dict
        # text_minibatch_generator_val = utils.val_phase_text_minibatch_generator(opts.val_inputs_dict, opts)

    else:
        global test_queue, test_processes
        test_inputs_dict = val_inputs_dict
        test_queue = Queue(opts.queue_capacity)
        opts.num_workers = 1
        test_processes = make_data_processes(val_data_process_for_class,
                                             test_queue,
                                             test_inputs_dict,
                                             opts,
                                             repeat=False)
        # set number of iterations for test
        opts.test_iters_per_epoch = test_processes[0].iters_per_epoch

    ###########################################################
    ## build network
    ###########################################################
    print('-------------building network--------------')
    network_class = models.load_model('LBA1')
    text_encoder = network_class['Text_Encoder'](
        opts.vocab_size, embedding_size=128, encoder_output_normalized=False)
    shape_encoder = network_class['Shape_Encoder'](
        num_classes=opts.num_classes, encoder_output_normalized=False)

    print('text encoder: ')
    print(text_encoder)
    print('shape encoder: ')
    print(shape_encoder)
    ###########################################################
    ## Training Criterion
    ###########################################################
    criterion = {}
    opts.LBA_NO_LBA = False
    if opts.LBA_NO_LBA is False:
        # by default, we set the visit loss weight to 0.25
        LBA_loss = loss.LBA_Loss(lmbda=0.25,
                                 LBA_model_type=opts.LBA_model_type,
                                 batch_size=opts.batch_size)
        criterion['LBA_loss'] = LBA_loss

    # classification loss
    opts.LBA_Classificaiton = False
    if opts.LBA_Classificaiton is True:
        pass

    # metric loss
    opts.LBA_Metric = True
    if opts.LBA_Metric is True:
        opts.rho = 1.0  # set opts.rho to be 1.0 for combining LBA_loss and Metric loss
        Metric_loss = loss.Metric_Loss(opts,
                                       LBA_inverted_loss=True,
                                       LBA_normalized=False,
                                       LBA_max_norm=10.0)
        criterion['Metric_Loss'] = Metric_loss

    ## shift models to cuda
    if opts.cuda:
        print('shift model and criterion to GPU .. ')
        text_encoder = text_encoder.cuda()
        shape_encoder = shape_encoder.cuda()
        if opts.ngpu > 1:
            text_encoder = nn.DataParallel(text_encoder,
                                           device_ids=range(opts.ngpu))
            shape_encoder = nn.DataParallel(shape_encoder,
                                            device_ids=range(opts.ngpu))

        for crit in criterion.values():
            crit.cuda()  # nn.Module.cuda() moves parameters in place

    ###########################################################
    ## optimizer
    ###########################################################
    optimizer_text_encoder = optim.Adam(text_encoder.parameters(),
                                        lr=opts.learning_rate)
    optimizer_shape_encoder = optim.Adam(shape_encoder.parameters(),
                                         lr=opts.learning_rate)

    ################################################################################
    ## we begin to train our network
    ################################################################################
    # while True:
    #    caption_batch = train_queue.get()
    #    caption_batch = val_queue.get()
    best_val_acc = 0
    for epoch in range(opts.max_epochs):

        print('--------epoch {0}/{1}--------'.format(epoch, opts.max_epochs))
        # train for one epoch
        train(train_queue, text_encoder, shape_encoder, optimizer_text_encoder,
              optimizer_shape_encoder, criterion, epoch, opts)

        # validation for one epoch
        if epoch % 25 == 0:
            print('evaluation...')
            cur_val_acc = validation(val_queue, text_encoder, shape_encoder,
                                     criterion, epoch, opts)

            if cur_val_acc > best_val_acc:
                print('current val acc is better than the previous best; '
                      'saving a checkpoint ...')
                path_checkpoint = '{0}/model_best.pth'.format(
                    opts.checkpoint_folder)
                checkpoint = {}
                if opts.ngpu > 1:
                    checkpoint[
                        'text_encoder'] = text_encoder.module.state_dict()
                    checkpoint[
                        'shape_encoder'] = shape_encoder.module.state_dict()
                else:
                    checkpoint['text_encoder'] = text_encoder.state_dict()
                    checkpoint['shape_encoder'] = shape_encoder.state_dict()

                print('save checkpoint to: ')
                print(path_checkpoint)
                torch.save(checkpoint, path_checkpoint)

    #################################################################################
    # Finally, we kill all the processes
    #################################################################################
    kill_processes(train_queue, train_processes)
    # kill validation process
    kill_processes(val_queue, val_processes)
    # if there is test process, also kill it
    if not is_training:
        kill_processes(test_queue, test_processes)