Example #1
def train_net():
    '''Main training function'''
    # Set up the model and the solver
    NetClass = load_model(cfg.CONST.NETWORK_CLASS)

    # print('Network definition: \n')
    # print(inspect.getsource(NetClass.network_definition))
    net = NetClass()

    # Check that single view reconstruction net is not used for multi view
    # reconstruction.
    if net.is_x_tensor4 and cfg.CONST.N_VIEWS > 1:
        raise ValueError('Do not set config.CONST.N_VIEWS > 1 when using a '
                         'single-view reconstruction network')

    # Prefetching data processes
    #
    # Create worker and data queue for data processing. For training data, use
    # multiple processes to speed up the loading. For validation data, use 1
    # since the queue will be popped every TRAIN.NUM_VALIDATION_ITERATIONS.
    global train_queue, val_queue, train_processes, val_processes
    train_queue = Queue(cfg.QUEUE_SIZE)
    val_queue = Queue(cfg.QUEUE_SIZE)

    train_processes = make_data_processes(
        train_queue,
        category_model_id_pair(dataset_portion=cfg.TRAIN.DATASET_PORTION),
        cfg.TRAIN.NUM_WORKER,
        repeat=True)
    val_processes = make_data_processes(
        val_queue,
        category_model_id_pair(dataset_portion=cfg.TEST.DATASET_PORTION),
        1,
        repeat=True,
        train=False)

    import torch.cuda
    if torch.cuda.is_available():
        net.cuda()

    # print the queue
    # print(train_queue)
    # print(val_queue)

    # Generate the solver
    solver = Solver(net)

    # Train the network
    solver.train(train_queue, val_queue)

    # Cleanup the processes and the queue.
    kill_processes(train_queue, train_processes)
    kill_processes(val_queue, val_processes)
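
The helpers make_data_processes, get_while_running, and kill_processes are not defined in this listing. A minimal sketch of the producer/consumer prefetching pattern they appear to implement, built on the standard multiprocessing module (load_minibatch is a hypothetical loader, not part of the original code):

from multiprocessing import Process, Queue

def producer(queue, paths, repeat=True):
    # Queue(maxsize) blocks put() when full, which throttles the workers
    # to the consumer's speed.
    while True:
        for path in paths:
            queue.put(load_minibatch(path))  # load_minibatch: hypothetical
        if not repeat:
            break

def make_workers(queue, paths, num_workers, repeat=True):
    procs = [Process(target=producer, args=(queue, paths, repeat))
             for _ in range(num_workers)]
    for p in procs:
        p.start()
    return procs

def stop_workers(queue, procs):
    # Drain the queue first so producers blocked on put() can make
    # progress, then terminate and join each worker.
    while not queue.empty():
        queue.get_nowait()
    for p in procs:
        p.terminate()
        p.join()
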
Example #2
def test_net():
    ''' Evaluate the network '''
    # Make result directory and the result file.
    result_dir = os.path.join(cfg.DIR.OUT_PATH, cfg.TEST.EXP_NAME)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    result_fn = os.path.join(result_dir, 'result.mat')

    print("Exp file will be written to: " + result_fn)

    # Make a network and load weights
    NetworkClass = load_model(cfg.CONST.NETWORK_CLASS)
    print('Network definition: \n')
    print(inspect.getsource(NetworkClass.network_definition))
    net = NetworkClass(compute_grad=False)
    net.load(cfg.CONST.WEIGHTS)
    solver = Solver(net)

    # set constants
    batch_size = cfg.CONST.BATCH_SIZE

    # set up testing data process. We make only one prefetching process. The
    # process will return one batch at a time.
    queue = Queue(cfg.QUEUE_SIZE)
    data_pair = category_model_id_pair(
        dataset_portion=cfg.TEST.DATASET_PORTION)
    processes = make_data_processes(queue,
                                    data_pair,
                                    1,
                                    repeat=False,
                                    train=False)
    num_data = len(processes[0].data_paths)
    num_batch = int(num_data / batch_size)

    # prepare result container
    results = {'cost': np.zeros(num_batch)}
    for thresh in cfg.TEST.VOXEL_THRESH:
        results[str(thresh)] = np.zeros((num_batch, batch_size, 5))

    # Get all test data
    batch_idx = 0
    for batch_img, batch_voxel in get_while_running(processes[0], queue):
        if batch_idx == num_batch:
            break

        pred, loss, activations = solver.test_output(batch_img, batch_voxel)
        print('%d/%d, cost is: %f' % (batch_idx, num_batch, loss))

        for i, thresh in enumerate(cfg.TEST.VOXEL_THRESH):
            for j in range(batch_size):
                r = evaluate_voxel_prediction(pred[j, ...],
                                              batch_voxel[j, ...], thresh)
                results[str(thresh)][batch_idx, j, :] = r

        # record result for the batch
        results['cost'][batch_idx] = float(loss)
        batch_idx += 1

    print('Total loss: %f' % np.mean(results['cost']))
    sio.savemat(result_fn, results)
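
evaluate_voxel_prediction is not shown here; it evidently returns five numbers per (sample, threshold) pair. For intuition, a thresholded voxel IoU can be computed as in this sketch, assuming channel 1 of pred holds occupancy probabilities and the ground truth is binary (the channel layout is an assumption):

import numpy as np

def voxel_iou(pred, gt, thresh):
    occupied = pred[1] > thresh   # channel layout is an assumption
    gt_occ = gt[1] > 0.5
    intersection = np.logical_and(occupied, gt_occ).sum()
    union = np.logical_or(occupied, gt_occ).sum()
    return intersection / max(union, 1)  # guard against an empty union
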
Example #3
def train_net():
    '''Main training function.'''

    # Set up the model and the solver
    my_net = My_ResidualGRUNet()

    # Generate the solver
    solver = Solver(my_net)

    # Declare the global queue and process handles
    global train_queue, validation_queue, train_processes, val_processes

    # Initialize the queues
    train_queue = Queue(
        15)  # maximum number of minibatches that can be put in a data queue
    validation_queue = Queue(15)

    # Train on 80 percent of the data
    train_dataset_portion = [0, 0.8]

    # Validate on 20 percent of the data
    test_dataset_portion = [0.8, 1]

    # Establish the training processes
    train_processes = make_data_processes(
        train_queue,
        category_model_id_pair(dataset_portion=train_dataset_portion),
        1,
        repeat=True)

    # Establish the validation processes
    val_processes = make_data_processes(
        validation_queue,
        category_model_id_pair(dataset_portion=test_dataset_portion),
        1,
        repeat=True,
        train=False)

    # Train the network
    solver.train(train_queue, validation_queue)

    # Cleanup the processes and the queue.
    kill_processes(train_queue, train_processes)
    kill_processes(validation_queue, val_processes)
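
category_model_id_pair is not defined in this listing, but dataset_portion reads as a fractional [start, end) slice over the model list, so [0, 0.8] and [0.8, 1] above give a disjoint 80/20 train/validation split. A hypothetical sketch of that slicing:

def slice_portion(items, portion):
    # portion=[0, 0.8] -> first 80%; portion=[0.8, 1] -> last 20%.
    start = int(len(items) * portion[0])
    end = int(len(items) * portion[1])
    return items[start:end]
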
Example #4
def demo(args):
    ''' Evaluate the network '''

    # Make a network and load weights
    NetworkClass = load_model(cfg.CONST.NETWORK_CLASS)
    print('Network definition: \n')
    print(inspect.getsource(NetworkClass.network_definition))
    net = NetworkClass(compute_grad=False)
    net.load(cfg.CONST.WEIGHTS)
    solver = Solver(net)

    # set up testing data process. We make only one prefetching process. The
    # process will return one batch at a time.
    queue = Queue(cfg.QUEUE_SIZE)
    data_pair = category_model_id_pair(dataset_portion=cfg.TEST.DATASET_PORTION)
    processes = make_data_processes(queue, data_pair, 1, repeat=False, train=False)
    num_data = len(processes[0].data_paths)
    num_batch = int(num_data / args.batch_size)

    # Get all test data
    batch_idx = 0
    for batch_img, batch_voxel in get_while_running(processes[0], queue):
        if batch_idx == num_batch:
            break

        pred, loss, activations = solver.test_output(batch_img, batch_voxel)

        if batch_idx < args.exportNum:
            # Save the prediction to an OBJ file (mesh file).
            print('saving {}/{}'.format(batch_idx, args.exportNum - 1))
            voxel2obj('out/prediction_{}b_{}.obj'.format(args.batch_size, batch_idx),
                      pred[0, :, 1, :, :] > cfg.TEST.VOXEL_THRESH)
        else:
            break

        batch_idx += 1

    if args.file:
        # Use meshlab or other mesh viewers to visualize the prediction.
        # For Ubuntu>=14.04, you can install meshlab using
        # `sudo apt-get install meshlab`
        if cmd_exists('meshlab'):
            call(['meshlab', 'obj/{}.obj'.format(args.file)])
        else:
            print('Meshlab not found: please use a mesh viewer of your choice to view %s' %
                  args.file)
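
voxel2obj is not shown in this listing; a naive stand-in that writes one unit cube per occupied voxel (no merging of shared faces, winding not guaranteed) could look like the sketch below. The boolean grid produced by pred[0, :, 1, :, :] > cfg.TEST.VOXEL_THRESH above is exactly the kind of input it expects:

import numpy as np

# Vertex offsets and quad faces of a unit cube (OBJ indices are 1-based).
_CUBE_VERTS = np.array([(x, y, z) for x in (0, 1) for y in (0, 1) for z in (0, 1)])
_CUBE_FACES = [(1, 2, 4, 3), (5, 7, 8, 6), (1, 5, 6, 2),
               (3, 4, 8, 7), (1, 3, 7, 5), (2, 6, 8, 4)]

def voxels_to_obj(path, voxels):
    # Emit one unit cube per occupied voxel.
    with open(path, 'w') as f:
        n = 0
        for idx in np.argwhere(voxels):
            for v in _CUBE_VERTS + idx:
                f.write('v {} {} {}\n'.format(*v))
            for face in _CUBE_FACES:
                f.write('f {} {} {} {}\n'.format(*[n + i for i in face]))
            n += 8
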
Example #5
def demo(args):
    ''' Evaluate the network '''

    # Make a network and load weights
    NetworkClass = load_model(cfg.CONST.NETWORK_CLASS)
    print('Network definition: \n')
    print(inspect.getsource(NetworkClass.network_definition))
    net = NetworkClass(compute_grad=False)
    net.load(cfg.CONST.WEIGHTS)
    solver = Solver(net)

    # set up testing data process. We make only one prefetching process. The
    # process will return one batch at a time.
    queue = Queue(cfg.QUEUE_SIZE)
    data_pair = category_model_id_pair(
        dataset_portion=cfg.TEST.DATASET_PORTION)
    processes = make_data_processes(queue,
                                    data_pair,
                                    1,
                                    repeat=False,
                                    train=False)
    num_data = len(processes[0].data_paths)
    num_batch = int(num_data / args.batch_size)

    # Get all test data
    batch_idx = 0
    for batch_img, batch_voxel in get_while_running(processes[0], queue):
        if batch_idx == num_batch:
            break

        pred, loss, activations = solver.test_output(batch_img, batch_voxel)

        if batch_idx < args.exportNum:
            # Save the prediction to an OBJ file (mesh file).
            print('saving {}/{}'.format(batch_idx, args.exportNum - 1))
            # voxel2obj('out/prediction_{}b_{}.obj'.format(args.batch_size, batch_idx),
            #           pred[0, :, 1, :, :] > cfg.TEST.VOXEL_THRESH)
        else:
            break

        batch_idx += 1
Example #6
def main():
    global opts
    opts.voxel_size = 32 # 4 x 32 x 32 x 32 

    ###########################################################
    # Dataset and Make Data Loading Process 
    ###########################################################
    inputs_dict, test_inputs_dict = get_inputs_dict(opts)

    # map category to label ('box-teal-h20-r20' -> 755)
    opts.category2label_dict = inputs_dict['class_labels']
    assert inputs_dict['vocab_size'] == test_inputs_dict['vocab_size']
    opts.vocab_size = inputs_dict['vocab_size']

    # Prefetching data processes 
    # Create worker and data queue for data processing. For training data, use
    # multiple processes to speed up the loading. For validation data, use 1
    # since the queue will be popped every TRAIN.NUM_VALIDATION_ITERATIONS.
    # set up data queue and start enqueue
    np.random.seed(123) 
    test_data_process_for_class = models.get_data_process_pairs('LBA1', opts, is_training=False) 
    
    global test_queue, test_processes 
    test_queue = Queue(opts.queue_capacity) 
    opts.num_workers = 1 
    test_processes = make_data_processes(test_data_process_for_class, test_queue, test_inputs_dict, opts, repeat=False) 

    ###########################################################
    ## build network, loading pretrained model, shift to GPU 
    ###########################################################
    print('-------------building network--------------')
    network_class = models.load_model('LBA1')
    text_encoder = network_class['Text_Encoder'](opts.vocab_size, embedding_size=128, encoder_output_normalized=False) 
    shape_encoder = network_class['Shape_Encoder'](num_classes=opts.num_classes, encoder_output_normalized=False) 

    print('text encoder: ')
    print(text_encoder)
    print('shape encoder: ')
    print(shape_encoder)

    print('loading checkpoints....')
    if opts.pretrained_model != '':
        print('loading pretrained model from {0}'.format(opts.pretrained_model))
        checkpoint = torch.load(opts.pretrained_model)
        text_encoder.load_state_dict(checkpoint['text_encoder'])
        shape_encoder.load_state_dict(checkpoint['shape_encoder'])
    else: 
        raise ValueError('please provide the path to a pretrained model.')

    ###########################################################
    ## Training Criterion 
    ###########################################################
    criterion = {}
    opts.LBA_NO_LBA = False 
    if opts.LBA_NO_LBA is False: 
        # by default, we set the visit loss weight to 0.25
        LBA_loss = loss.LBA_Loss(lmbda=0.25, LBA_model_type=opts.LBA_model_type, batch_size=opts.batch_size)
        criterion['LBA_loss'] = LBA_loss

    # classification loss
    opts.LBA_Classificaiton = False 
    if opts.LBA_Classificaiton is True: 
        pass 

    # metric loss
    opts.LBA_Metric = True 
    if opts.LBA_Metric is True: 
        opts.rho = 1.0 # set opts.rho to be 1.0 for combining LBA_loss and Metric loss 
        Metric_loss = loss.Metric_Loss(opts, LBA_inverted_loss=True, LBA_normalized=False, LBA_max_norm=10.0)
        criterion['Metric_Loss'] = Metric_loss

    ## shift models to cuda 
    if opts.cuda:  
        print('shift model and criterion to GPU .. ')
        text_encoder = text_encoder.cuda() 
        shape_encoder = shape_encoder.cuda() 
        if opts.ngpu > 1:
            text_encoder = nn.DataParallel(text_encoder, device_ids=range(opts.ngpu)) 
            shape_encoder = nn.DataParallel(shape_encoder, device_ids=range(opts.ngpu)) 

        # nn.Module.cuda() moves parameters in place; rebinding the dict
        # entries keeps the references explicit.
        for key in criterion:
            criterion[key] = criterion[key].cuda()
    

    ###########################################################
    ## Now we begin to test 
    ###########################################################         
    print('evaluation...')
    pr_at_k = test(test_processes[0], test_queue, text_encoder, shape_encoder, criterion, opts)

    #################################################################################
    # Finally, we kill all the processes 
    #################################################################################
    kill_processes(test_queue, test_processes)
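
test() and its pr_at_k return value are not defined in this listing; pr_at_k presumably reports retrieval precision at rank k. A minimal sketch, assuming a binary relevance vector ordered by descending similarity:

import numpy as np

def precision_at_k(relevance, k):
    relevance = np.asarray(relevance)[:k]
    return relevance.mean() if relevance.size else 0.0

print(precision_at_k([1, 0, 1, 1, 0], k=3))  # 2/3
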
Example #7
File: main.py Project: zehuiw/text2shape
def main():
    """Main text2voxel function.
    """
    args = parse_args()

    print('Called with args:')
    print(args)

    if args.save_outputs is True and args.test is False:
        raise ValueError('Can only save outputs when testing, not training.')
    if args.validation:
        assert not args.test
    if args.test:
        assert args.ckpt_path is not None

    modify_args(args)

    print('----------------- CONFIG -------------------')
    pprint.pprint(cfg)

    # Save yaml
    os.makedirs(cfg.DIR.LOG_PATH, exist_ok=True)
    with open(os.path.join(cfg.DIR.LOG_PATH, 'run_cfg.yaml'), 'w') as out_yaml:
        yaml.dump(cfg, out_yaml, default_flow_style=False)

    # set up logger
    tf.logging.set_verbosity(tf.logging.INFO)

    try:
        with tf.Graph().as_default() as g:  # create graph
            # Load data
            inputs_dict, val_inputs_dict = get_inputs_dict(args)

            # Build network
            is_training = not args.test
            print('------------ BUILDING NETWORK -------------')
            network_class = models.load_model(cfg.NETWORK)
            net = network_class(inputs_dict, is_training)

            # Prefetching data processes
            #
            # Create worker and data queue for data processing. For training data, use
            # multiple processes to speed up the loading. For validation data, use 1
            # since the queue will be popped every TRAIN.NUM_VALIDATION_ITERATIONS.
            # set up data queue and start enqueue
            np.random.seed(123)
            data_process_class = models.get_data_process_pairs(
                cfg.NETWORK, is_training)
            val_data_process_class = models.get_data_process_pairs(
                cfg.NETWORK, is_training=False)
            if is_training:
                global train_queue, train_processes
                train_queue = Queue(cfg.CONST.QUEUE_CAPACITY)
                train_processes = make_data_processes(data_process_class,
                                                      train_queue,
                                                      inputs_dict,
                                                      cfg.CONST.NUM_WORKERS,
                                                      repeat=True)
                if args.validation:
                    global val_queue, val_processes
                    val_queue = Queue(cfg.CONST.QUEUE_CAPACITY)
                    val_processes = make_data_processes(val_data_process_class,
                                                        val_queue,
                                                        val_inputs_dict,
                                                        1,
                                                        repeat=True)
            else:
                global test_queue, test_processes
                test_inputs_dict = val_inputs_dict
                test_queue = Queue(cfg.CONST.QUEUE_CAPACITY)
                test_processes = make_data_processes(val_data_process_class,
                                                     test_queue,
                                                     test_inputs_dict,
                                                     1,
                                                     repeat=False)

            # Create solver
            solver = get_solver(g, net, args, is_training)

            # Run solver
            if is_training:
                if args.validation:
                    if cfg.DIR.VAL_CKPT_PATH is not None:
                        assert train_processes[0].iters_per_epoch != 0
                        assert val_processes[0].iters_per_epoch != 0
                        solver.train(train_processes[0].iters_per_epoch,
                                     train_queue,
                                     val_processes[0].iters_per_epoch,
                                     val_queue=val_queue,
                                     val_inputs_dict=val_inputs_dict)
                    else:
                        if isinstance(net, LBA):
                            assert cfg.LBA.TEST_MODE is not None
                            assert cfg.LBA.TEST_MODE == 'shape'
                            assert train_processes[0].iters_per_epoch != 0
                            assert val_processes[0].iters_per_epoch != 0
                            solver.train(train_processes[0].iters_per_epoch,
                                         train_queue,
                                         val_processes[0].iters_per_epoch,
                                         val_queue=val_queue,
                                         val_inputs_dict=val_inputs_dict)
                        else:
                            assert train_processes[0].iters_per_epoch != 0
                            assert val_processes[0].iters_per_epoch != 0
                            solver.train(train_processes[0].iters_per_epoch,
                                         train_queue,
                                         val_processes[0].iters_per_epoch,
                                         val_queue=val_queue)
                else:
                    solver.train(train_processes[0].iters_per_epoch,
                                 train_queue)
            else:
                solver.test(test_processes[0],
                            test_queue,
                            num_minibatches=cfg.CONST.N_MINIBATCH_TEST,
                            save_outputs=args.save_outputs)
    finally:
        # Clean up the processes and queues
        if is_training:
            kill_processes(train_queue, train_processes)
            if args.validation:
                kill_processes(val_queue, val_processes)
        else:
            kill_processes(test_queue, test_processes)
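
One caveat with the cleanup above: if an exception fires before is_training or the queues are bound (for example inside get_inputs_dict), the finally block itself raises a NameError. A defensive variant, sketched with the same helpers, initializes the handles first:

train_queue = val_queue = test_queue = None
train_processes = val_processes = test_processes = None
try:
    ...  # build the graph, queues, and solver as above
finally:
    for q, procs in ((train_queue, train_processes),
                     (val_queue, val_processes),
                     (test_queue, test_processes)):
        if q is not None and procs is not None:
            kill_processes(q, procs)
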
Example #8
def test_net():
    ''' Evaluate the network '''
    # Make result directory and the result file.
    result_dir = os.path.join(cfg.DIR.OUT_PATH, cfg.TEST.EXP_NAME)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    result_fn = os.path.join(result_dir, 'result.mat')

    print("Exp file will be written to: " + result_fn)

    # Make a network and load weights
    NetworkClass = load_model(cfg.CONST.NETWORK_CLASS)

    #print('Network definition: \n')
    #print(inspect.getsource(NetworkClass.network_definition))

    net = NetworkClass()
    
    net.cuda()
    
    solver = Solver(net)
    solver.load(cfg.CONST.WEIGHTS)

    # set constants
    batch_size = cfg.CONST.BATCH_SIZE

    # set up testing data process. We make only one prefetching process. The
    # process will return one batch at a time.
    queue = Queue(cfg.QUEUE_SIZE)
    data_pair = category_model_id_pair(dataset_portion=cfg.TEST.DATASET_PORTION)
    processes = make_data_processes(queue, data_pair, 1, repeat=False, train=False)

    num_data = len(processes[0].data_paths)
    num_batch = int(num_data / batch_size)

    # prepare result container
    results = {'cost': np.zeros(num_batch),
               'mAP': np.zeros((num_batch, batch_size))}
    # Save results for various thresholds
    for thresh in cfg.TEST.VOXEL_THRESH:
        results[str(thresh)] = np.zeros((num_batch, batch_size, 5))

    # Get all test data
    batch_idx = 0
    for batch_img, batch_voxel in get_while_running(processes[0], queue):
        if batch_idx == num_batch:
            break

        # activations is a list of torch.cuda.FloatTensor
        pred, loss, activations = solver.test_output(batch_img, batch_voxel)
        
        # convert PyTorch tensors to NumPy arrays
        pred = pred.data.cpu().numpy()
        loss = loss.data.cpu().numpy()

        for j in range(batch_size):
            # Save IoU per thresh
            for i, thresh in enumerate(cfg.TEST.VOXEL_THRESH):
                r = evaluate_voxel_prediction(pred[j, ...], batch_voxel[j, ...], thresh)
                results[str(thresh)][batch_idx, j, :] = r

            # Compute AP
            precision = sklearn.metrics.average_precision_score(
                batch_voxel[j, 1].flatten(), pred[j, 1].flatten())

            results['mAP'][batch_idx, j] = precision

        # record result for the batch
        results['cost'][batch_idx] = float(loss)
        print('%d/%d, costs: %f, mAP: %f' %
              (batch_idx, num_batch, loss, np.mean(results['mAP'][batch_idx])))
        batch_idx += 1


    print('Total loss: %f' % np.mean(results['cost']))
    print('Total mAP: %f' % np.mean(results['mAP']))

    sio.savemat(result_fn, results)
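
sklearn.metrics.average_precision_score takes binary ground-truth labels and continuous scores, which is why the occupancy channel is flattened above. A toy call for reference:

import numpy as np
import sklearn.metrics

y_true = np.array([0, 0, 1, 1])            # flattened binary voxels
y_score = np.array([0.1, 0.4, 0.35, 0.8])  # flattened predicted probabilities
print(sklearn.metrics.average_precision_score(y_true, y_score))  # ~0.83
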
Example #9
def main():
    global opts
    opts.max_epochs = 1000
    opts.voxel_size = 32  # 4 x 32 x 32 x 32

    ###########################################################
    # Dataset and Make Data Loading Process
    ###########################################################
    inputs_dict, val_inputs_dict = get_inputs_dict(opts)

    # map category to label ('box-teal-h20-r20' -> 755)
    opts.category2label_dict = inputs_dict['class_labels']
    assert inputs_dict['vocab_size'] == val_inputs_dict['vocab_size']
    opts.vocab_size = inputs_dict['vocab_size']

    # Prefetching data processes
    # Create worker and data queue for data processing. For training data, use
    # multiple processes to speed up the loading. For validation data, use 1
    # since the queue will be popped every TRAIN.NUM_VALIDATION_ITERATIONS.
    # set up data queue and start enqueue
    np.random.seed(123)
    data_process_for_class = models.get_data_process_pairs('LBA1',
                                                           opts,
                                                           is_training=True)
    val_data_process_for_class = models.get_data_process_pairs(
        'LBA1', opts, is_training=False)

    is_training = True
    if is_training:
        global train_queue, train_processes
        global val_queue, val_processes
        train_queue = Queue(opts.queue_capacity)
        train_processes = make_data_processes(data_process_for_class,
                                              train_queue,
                                              inputs_dict,
                                              opts,
                                              repeat=True)
        # set number of iterations for training
        opts.train_iters_per_epoch = train_processes[0].iters_per_epoch

        val_queue = Queue(opts.queue_capacity)
        # now we set number of workers to be 1
        opts.num_workers = 1
        val_processes = make_data_processes(val_data_process_for_class,
                                            val_queue,
                                            val_inputs_dict,
                                            opts,
                                            repeat=True)
        # set number of iterations for validation
        opts.val_iters_per_epoch = val_processes[0].iters_per_epoch

        #########################################################
        ## minibatch generator for the val/test phase for TEXT only.
        #########################################################
        opts.val_inputs_dict = val_inputs_dict
        # text_minibatch_generator_val = utils.val_phase_text_minibatch_generator(opts.val_inputs_dict, opts)

    else:
        global test_queue, test_processes
        test_inputs_dict = val_inputs_dict
        test_queue = Queue(opts.queue_capacity)
        opts.num_workers = 1
        test_processes = make_data_processes(val_data_process_for_class,
                                             test_queue,
                                             test_inputs_dict,
                                             opts,
                                             repeat=False)
        # set number of iterations for test
        opts.test_iters_per_epoch = test_processes[0].iters_per_epoch

    ###########################################################
    ## build network
    ###########################################################
    print('-------------building network--------------')
    network_class = models.load_model('LBA1')
    text_encoder = network_class['Text_Encoder'](
        opts.vocab_size, embedding_size=128, encoder_output_normalized=False)
    shape_encoder = network_class['Shape_Encoder'](
        num_classes=opts.num_classes, encoder_output_normalized=False)

    print('text encoder: ')
    print(text_encoder)
    print('shape encoder: ')
    print(shape_encoder)
    ###########################################################
    ## Training Criterion
    ###########################################################
    criterion = {}
    opts.LBA_NO_LBA = False
    if opts.LBA_NO_LBA is False:
        # by default, we set the visit loss weight to 0.25
        LBA_loss = loss.LBA_Loss(lmbda=0.25,
                                 LBA_model_type=opts.LBA_model_type,
                                 batch_size=opts.batch_size)
        criterion['LBA_loss'] = LBA_loss

    # classification loss
    opts.LBA_Classificaiton = False
    if opts.LBA_Classificaiton is True:
        pass

    # metric loss
    opts.LBA_Metric = True
    if opts.LBA_Metric is True:
        opts.rho = 1.0  # set opts.rho to be 1.0 for combining LBA_loss and Metric loss
        Metric_loss = loss.Metric_Loss(opts,
                                       LBA_inverted_loss=True,
                                       LBA_normalized=False,
                                       LBA_max_norm=10.0)
        criterion['Metric_Loss'] = Metric_loss

    ## shift models to cuda
    if opts.cuda:
        print('shift model and criterion to GPU .. ')
        text_encoder = text_encoder.cuda()
        shape_encoder = shape_encoder.cuda()
        if opts.ngpu > 1:
            text_encoder = nn.DataParallel(text_encoder,
                                           device_ids=range(opts.ngpu))
            shape_encoder = nn.DataParallel(shape_encoder,
                                            device_ids=range(opts.ngpu))

        # nn.Module.cuda() moves parameters in place; rebinding the dict
        # entries keeps the references explicit.
        for key in criterion:
            criterion[key] = criterion[key].cuda()

    ###########################################################
    ## optimizer
    ###########################################################
    optimizer_text_encoder = optim.Adam(text_encoder.parameters(),
                                        lr=opts.learning_rate)
    optimizer_shape_encoder = optim.Adam(shape_encoder.parameters(),
                                         lr=opts.learning_rate)

    ################################################################################
    ## we begin to train our network
    ################################################################################
    # while True:
    #    caption_batch = train_queue.get()
    #    caption_batch = val_queue.get()
    best_val_acc = 0
    for epoch in range(opts.max_epochs):

        print('--------epoch {0}/{1}--------'.format(epoch, opts.max_epochs))
        # train for one epoch
        train(train_queue, text_encoder, shape_encoder, optimizer_text_encoder,
              optimizer_shape_encoder, criterion, epoch, opts)

        # run validation every 25 epochs
        if epoch % 25 == 0:
            print('evaluation...')
            cur_val_acc = validation(val_queue, text_encoder, shape_encoder,
                                     criterion, epoch, opts)

            if cur_val_acc > best_val_acc:
                print('current val acc beats the previous best; checkpointing ...')
                path_checkpoint = '{0}/model_best.pth'.format(
                    opts.checkpoint_folder)
                checkpoint = {}
                if opts.ngpu > 1:
                    checkpoint[
                        'text_encoder'] = text_encoder.module.state_dict()
                    checkpoint[
                        'shape_encoder'] = shape_encoder.module.state_dict()
                else:
                    checkpoint['text_encoder'] = text_encoder.state_dict()
                    checkpoint['shape_encoder'] = shape_encoder.state_dict()

                print('save checkpoint to: ')
                print(path_checkpoint)
                torch.save(checkpoint, path_checkpoint)
                best_val_acc = cur_val_acc

    #################################################################################
    # Finally, we kill all the processes
    #################################################################################
    kill_processes(train_queue, train_processes)
    # kill validation process
    kill_processes(val_queue, val_processes)
    # if there is test process, also kill it
    if not is_training:
        kill_processes(test_queue, test_processes)
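
The opts.ngpu branch when building the checkpoint matters because nn.DataParallel prefixes every parameter key with 'module.'; saving model.module.state_dict(), as above, keeps the keys loadable by an unwrapped model (as the evaluation script in Example #6 does). A minimal round trip:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

print(list(wrapped.state_dict())[0])         # 'module.weight'
print(list(wrapped.module.state_dict())[0])  # 'weight'

# Saving the unwrapped keys lets a single-GPU run load them directly.
torch.save(wrapped.module.state_dict(), 'model_best.pth')
nn.Linear(4, 2).load_state_dict(torch.load('model_best.pth'))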