Пример #1
0
def get_dist(input_pdb, p1=None, p2=None, cutoff=8.0):
    """Compute distances between atoms lying within ``cutoff`` of each other.

    Atoms parsed from ``input_pdb`` are binned into a 3D mesh of cells whose
    edge length equals the cutoff. A neighbour table of those cells is built
    and handed to the distance routine.

    Args:
        input_pdb: PDB input passed unchanged to ``readpdb.parse_pdb``
            (presumably an already-open file object -- TODO confirm). This
            function neither opens nor closes it.
        p1, p2: Optional selectors forwarded to ``calculate_distance_mm``.
            When both are ``None``, ``calculate_distance_dm`` is used instead.
        cutoff: Distance cutoff in angstrom, also used as the mesh cell edge
            length. Defaults to 8.0, the value previously hard-coded here.

    Returns:
        list[str]: One newline-terminated string per line of the text
        produced by the distance routine.
    """
    from readpdb import parse_pdb
    from mesh import create_mesh
    from assign_coord import assign_atomlist_to_mesh
    from assign_neighbor import assign_ngh
    from calculate_distance_matrix import calculate_distance_dm, calculate_distance_mm

    # parse_pdb returns the atom details plus the min/max coordinate along
    # each direction, all packed into one list.
    param = parse_pdb(input_pdb)

    # The coarse mesh cell edge equals the cutoff.
    cmesh_len = cutoff

    # Build a mesh enclosing every atom; an extra `cutoff` length is added in
    # each direction for the boundary cells.
    cmesh = create_mesh(param, cmesh_len, cutoff)

    # Element 6 of param holds the atom coordinates and identities (atom
    # type, residue name, residue number, chain identifier).
    atom_props = param[6]

    # cmesh[3] * cmesh[4] * cmesh[5] is the total number of cells
    # (presumably the per-axis cell counts).
    totcell = cmesh[3] * cmesh[4] * cmesh[5]

    # One (initially empty) atom list per mesh cell, then fill it.
    cell_list = [[] for _ in range(totcell)]
    cell_list = assign_atomlist_to_mesh(cmesh, atom_props, cell_list)

    # Neighbour table of every cell in the mesh.
    ngh_dict = assign_ngh(cmesh)

    # The distance routines return one text blob with the distances and the
    # atom identities. `is None` also stays safe for array-like selectors,
    # where `== None` would be element-wise.
    if p1 is None and p2 is None:
        outline = calculate_distance_dm(cell_list, cmesh, ngh_dict, cutoff)
    else:
        outline = calculate_distance_mm(cell_list, cmesh, ngh_dict, cutoff, p1, p2)

    # Split into newline-terminated lines. Only the empty piece produced by a
    # final '\n' is discarded, so a last line lacking a newline is kept
    # instead of being silently dropped.
    lines = outline.split('\n')
    if lines and lines[-1] == '':
        lines.pop()
    return [line + '\n' for line in lines]
Пример #2
0
    'plane': [],
    'cabinet': [],
    'car': [],
    'monitor': [],
    'couch': [],
    'cellphone': []
}

if __name__ == '__main__':
    # Command-line / configuration options.
    param = create_parser()

    # Run on the first configured GPU only when CUDA is both requested and
    # available; otherwise stay on the CPU.
    wants_gpu = param.use_cuda and torch.cuda.is_available()
    device_name = 'cuda:{}'.format(param.gpu_ids[0]) if wants_gpu else 'cpu'
    device = torch.device(device_name)

    # Template mesh loaded from the configured .obj file, placed on `device`.
    init_mesh = create_mesh(param.obj_file, device=device)

    # NOTE(review): machine-specific absolute checkpoint path; consider
    # exposing it as a command-line option.
    checkpoint_path = '/home/lihai/workspace/eval/pyg_model/save/model_epoch19_2019_08_30_20_00_25.pth'
    test_model = create_model(device, init_mesh, train=True,
                              checkpoint=checkpoint_path)

    # Multi-GPU wrapping (torch.nn.DataParallel over param.gpu_ids when
    # param.muti_gpu is set) is intentionally disabled here.

    # Running counters; the code that consumes them is not part of this excerpt.
    f_total = 0
    total_count = 0
Пример #3
0
def run_cell_list(arguments):
    """Compute cutoff-limited atom distances with a cell list and save them.

    Args:
        arguments: Indexable sequence with at least five items:
            ``arguments[0]``: path of the input PDB file;
            ``arguments[1]``: path of the output file (overwritten);
            ``arguments[2]``: distance cutoff in angstrom (anything ``float``
            accepts);
            ``arguments[3]``, ``arguments[4]``: extra options forwarded to
            ``parse_pdb``.

    Returns:
        The output of ``calculate_distance``, which is also written to the
        output file.
    """
    from readpdb import parse_pdb
    from mesh import create_mesh
    from assign_coord import assign_atomlist_to_mesh
    from assign_neighbor import assign_ngh
    from calculate_distance_matrix import calculate_distance

    input_file = arguments[0]

    # `with` guarantees the input file is closed even if parsing raises.
    with open(input_file, 'r') as inf2:
        print("Reading file: ", input_file)
        # parse_pdb returns the atom details plus the min/max coordinate
        # along each direction, all packed into one list.
        param = parse_pdb(inf2, arguments[3], arguments[4])

    # Distance cutoff; the coarse mesh cell edge equals the cutoff.
    cutoff = float(arguments[2])
    cmesh_len = cutoff

    # Build a mesh enclosing every atom; an extra `cutoff` length is added in
    # each direction for the boundary cells.
    cmesh = create_mesh(param, cmesh_len, cutoff)

    # Element 6 of param holds the atom coordinates and identities (atom
    # type, residue name, residue number, chain identifier).
    atom_props = param[6]

    # cmesh[3] * cmesh[4] * cmesh[5] is the total number of cells
    # (presumably the per-axis cell counts).
    totcell = cmesh[3] * cmesh[4] * cmesh[5]

    # One (initially empty) atom list per mesh cell, then fill it.
    cell_list = [[] for _ in range(totcell)]
    cell_list = assign_atomlist_to_mesh(cmesh, atom_props, cell_list)

    # Neighbour table of every cell in the mesh.
    ngh_dict = assign_ngh(cmesh)

    print("Calculating distances...\n")
    outline = calculate_distance(cell_list, cmesh, ngh_dict, cutoff)

    output_file = arguments[1]

    print("Writing the output file...\n")
    # writelines accepts either a single string or a list of lines.
    # `with` closes the file even if the write fails.
    with open(output_file, 'w') as outf:
        outf.writelines(outline)

    return outline
Пример #4
0
def _log_training_step(logger, param, model, input_dict, output_dict,
                       total_loss, loss_dict, n_iter):
    """Send the periodic losses, images, projections and meshes to `logger`.

    Each category is gated by its own frequency option in ``param``
    (``loss_freq_iter``, ``check_freq_iter``, ``train_mesh_freq_iter``).
    """
    if n_iter % param.loss_freq_iter == 0:
        logger.add_loss('total loss', total_loss.item(), n_iter)
        logger.add_losses('losses', loss_dict, n_iter)

    if n_iter % param.check_freq_iter == 0:
        logger.add_gradcheck(model.named_parameters(), n_iter)
        logger.add_image('input image', input_dict['render_img'], n_iter)

        gt_sample = input_dict['sample'][0].unsqueeze(0)
        proj_mat = input_dict['proj_mat']
        # Log the projection of every output stage under its own tag. Each
        # `project{i}` entry uses stage i; previously stage 0 was logged for
        # every i.
        if proj_mat is not None:
            logger.add_projectionU('gt_project', gt_sample, proj_mat, n_iter)
            for i, stage_coords in enumerate(output_dict['coords']):
                logger.add_projectionU('project{}'.format(i), stage_coords, proj_mat, n_iter)
        else:
            logger.add_projection('gt_project', gt_sample, n_iter)
            for i, stage_coords in enumerate(output_dict['coords']):
                logger.add_projection('project{}'.format(i), stage_coords, n_iter)

    if n_iter % param.train_mesh_freq_iter == 0:
        for i in range(len(output_dict['coords'])):
            stage_faces = output_dict['faces'][i][0]
            # NOTE(review): pre_coords and coords are saved under the same
            # tag; confirm the logger does not overwrite the first with the second.
            logger.save_mesh('mesh{}'.format(i), output_dict['pre_coords'][i][0], n_iter, stage_faces)
            logger.save_mesh('mesh{}'.format(i), output_dict['coords'][i][0], n_iter, stage_faces)

        logger.save_mesh('gt_pcl', input_dict['sample'][0], n_iter)
        logger.save_mesh('detail_pcl', output_dict['points'][0], n_iter)


def train(param):
    """Train the model configured by ``param``, checkpointing after every epoch.

    Builds the datasets, loaders, optional logger, template meshes, model,
    saver and Adam optimizer from ``param``. Training runs for
    ``param.epoch`` epochs, saves a checkpoint after each epoch, and calls
    ``run_eval`` every ``param.eval_freq_epoch`` epochs.
    """
    # Run on the first configured GPU only when CUDA is requested and available.
    device = torch.device(
        'cuda:{}'.format(param.gpu_ids[0]) if param.use_cuda and torch.cuda.is_available() else 'cpu')

    train_dataset = creare_dataset(
        param.data_type, param.data_root, categories=param.train_category, param=param, train=True)
    test_dataset = creare_dataset(
        param.data_type, param.data_root, categories=param.eval_category, param=param, train=False)

    train_loader = DataLoader(train_dataset.dataset, batch_size=param.batch_size, shuffle=True,
                              num_workers=param.num_workers, pin_memory=param.pin_mem)
    test_loader = DataLoader(test_dataset.dataset, batch_size=param.batch_size, shuffle=False,
                             num_workers=param.num_workers, pin_memory=param.pin_mem)

    logger = create_logger(param.log_path + '_' + param.name) if param.use_logger else None

    # The template mesh is created twice: once on the host (for the model and
    # evaluation) and once on `device` (for the loss).
    host_init_mesh = create_mesh(param)
    device_init_mesh = create_mesh(param, device)

    model = create_model(host_init_mesh, param, train=True)
    saver = create_saver(param.save_path + '_' + param.name)
    optimizer = torch.optim.Adam(model.parameters(), lr=param.lr, weight_decay=param.weight_decay)

    # Global iteration counter, shared by the model call and the logging cadence.
    n_iter = 0

    if param.muti_gpu:
        model = torch.nn.DataParallel(model, param.gpu_ids)
    model = model.to(device)

    for epoch in range(param.epoch):
        print('start traning for epoch ', epoch)
        # run_eval may leave the model in eval mode; switch back to training
        # mode (dropout, batch-norm statistics) at the start of every epoch.
        model.train()

        for data in tqdm(train_loader):
            data = data.to(device)
            optimizer.zero_grad()

            input_dict = make_input_dict(param, data)
            points, pre_coords, coords, edges, faces, vmasks, output_img = model(
                input_dict['render_img'], input_dict['proj_mat'], logger, n_iter)
            output_dict = {
                'pre_coords': pre_coords,
                'points': points,
                'coords': coords,
                'edges': edges,
                'faces': faces,
                'vmasks': vmasks,
                'output_img': output_img,
            }

            # The third return value (ground-truth sample positions) is not used here.
            total_loss, loss_dict, _ = compute_loss(param, device_init_mesh, input_dict, output_dict)

            total_loss.backward()
            optimizer.step()

            n_iter += 1

            if logger is not None:
                _log_training_step(logger, param, model, input_dict, output_dict,
                                   total_loss, loss_dict, n_iter)

        print('finish traning for epoch ', epoch)
        # Save the wrapped module's weights so DataParallel's 'module.' key
        # prefix does not leak into the checkpoint.
        to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
        saver.save_model(to_save.state_dict(), epoch)

        print('finish saving checkpoint for epoch ', epoch)
        # Gate on completed epochs, as the option name says; n_iter counts
        # batches, not epochs.
        if (epoch + 1) % param.eval_freq_epoch == 0:
            run_eval(
                test_loader, model, host_init_mesh, logger, epoch, device, param)