# --- Variant 1: ComaVAE training / evaluation with TensorBoard and colored logging ---
import datetime
import os

# third-party
import torch
from termcolor import colored
from torch.utils.tensorboard import SummaryWriter
import pytorch_model_summary as pms
from psbody.mesh import Mesh
from torch_geometric.data import DataLoader

# project-local helpers (module paths are assumed; adjust to the repo layout)
import mesh_operations
from config_parser import read_config
from data import ComaDataset
from model import ComaVAE
from utils import (scipy_to_torch_sparse, train, evaluate,
                   sample_latent_space, save_model, adjust_learning_rate)


def main(args):
    if not os.path.exists(args.conf):
        print('Config not found: ' + args.conf)
        return
    config = read_config(args.conf)
    print(colored(str(config), 'cyan'))

    eval_flag = config['eval']
    if not eval_flag:  # train mode: fresh or reload
        current_log_dir = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
        current_log_dir = os.path.join('../Experiments/', current_log_dir)
    else:  # eval mode: save result plys
        if args.load_checkpoint_dir:
            current_log_dir = '../Eval'
        else:
            print(colored('***** please provide a checkpoint file path to reload! *****', 'red'))
            return
    print(colored('logs will be saved in: {}'.format(current_log_dir), 'yellow'))

    if args.load_checkpoint_dir:
        # directory holding the checkpoints to reload
        load_checkpoint_dir = os.path.join('../Experiments/', args.load_checkpoint_dir, 'chkpt')
        print(colored('load_checkpoint_dir: {}'.format(load_checkpoint_dir), 'red'))
    save_checkpoint_dir = os.path.join(current_log_dir, 'chkpt')
    print(colored('save_checkpoint_dir: {}\n'.format(save_checkpoint_dir), 'yellow'))
    if not os.path.exists(save_checkpoint_dir):
        os.makedirs(save_checkpoint_dir)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)
    print(template_file_path)

    visualize = config['visualize']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']
    val_losses, accs, durations = [], [], []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        print(colored('\n...cuda is available...\n', 'green'))
    else:
        print(colored('\n...cuda is NOT available...\n', 'red'))

    ds_factors = config['downsampling_factors']
    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(template_mesh, ds_factors)
    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]
    print(colored('number of nodes in encoder : {}'.format(num_nodes), 'blue'))

    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']
    print('*** data loaded from {} ***'.format(data_dir))
    dataset = ComaDataset(data_dir, dtype='train', split=args.split,
                          split_term=args.split_term)
    dataset_test = ComaDataset(data_dir, dtype='test', split=args.split,
                               split_term=args.split_term)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                              num_workers=workers_thread)
    test_loader = DataLoader(dataset_test, batch_size=1, shuffle=False,
                             num_workers=workers_thread)
    print('dataset[0] element:\n{}'.format(dataset[0]))
    print(colored(train_loader, 'red'))

    print('Loading Model :\n')
    start_epoch = 1
    coma = ComaVAE(dataset, config, D_t, U_t, A_t, num_nodes)
    tbSummWriter = SummaryWriter(current_log_dir)
    print_model_summary = False
    if print_model_summary:
        print(coma)
        mrkdwn = '<pre><code>' + str(coma) + '</code></pre>'
        tbSummWriter.add_text('tag2', mrkdwn)

    # write the network architecture into a text file
    logfile = os.path.join(current_log_dir, 'coma.txt')
    with open(logfile, 'w') as my_data_file:
        my_data_file.write(str(coma))

    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(), lr=lr, weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(), lr=lr,
                                    weight_decay=weight_decay, momentum=0.9)
    else:
        raise Exception('No optimizer provided')

    if args.load_checkpoint_dir:
        # reload the newest saved checkpoint (sorted by creation time)
        chkpt_list = sorted(os.listdir(load_checkpoint_dir),
                            key=lambda f: os.path.getctime(os.path.join(load_checkpoint_dir, f)))
        checkpoint_file = chkpt_list[-1]
        logfile = os.path.join(current_log_dir, 'loadedfrom.txt')
        with open(logfile, 'w') as my_data_file:
            my_data_file.write(str(load_checkpoint_dir))
        print(colored('\n\nloading newest checkpoint : {}\n'.format(checkpoint_file), 'red'))
        if checkpoint_file:
            checkpoint = torch.load(os.path.join(load_checkpoint_dir, checkpoint_file))
            start_epoch = checkpoint['epoch_num']
            coma.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # move optimizer state to the right device
            # (check whether this is fixed in pytorch)
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)

    coma.to(device)

    # run a single batch through pytorch_model_summary and log the result
    for i, dt in enumerate(train_loader):
        dt = dt.to(device)
        graphstr = pms.summary(coma, dt, batch_size=-1,
                               show_input=True, show_hierarchical=False)
        if print_model_summary:
            print(graphstr)
            print(colored('dt in enumerate(train_loader): {}'.format(dt), 'green'))
        logfile = os.path.join(current_log_dir, 'pms.txt')
        with open(logfile, 'w') as my_data_file:
            my_data_file.write(graphstr)
        mrkdwn = '<pre><code>' + graphstr + '</code></pre>'
        tbSummWriter.add_text('tag', mrkdwn)
        break  # one sample only

    if eval_flag and args.load_checkpoint_dir:
        evaluatedFrom = 'predictedPlys_' + checkpoint_file
        output_dir = os.path.join('../Experiments/', args.load_checkpoint_dir, evaluatedFrom)
        val_loss = evaluate(coma, test_loader, dataset_test, template_mesh, device,
                            visualize=True, output_dir=output_dir)
        print('val loss', val_loss)
        return

    best_val_loss = float('inf')
    val_loss_history = []
    for epoch in range(start_epoch, total_epochs + 1):
        print('Training for epoch ', epoch)
        print('dataset.len : {}'.format(len(dataset)))
        train_loss = train(coma, train_loader, len(dataset), optimizer, device)
        # validate without visualization during training
        val_loss = evaluate(coma, test_loader, dataset_test, template_mesh, device,
                            visualize=False, output_dir='')
        sample_latent_space(coma, epoch, device, template_mesh, current_log_dir)
        tbSummWriter.add_scalar('Loss/train', train_loss, epoch)
        tbSummWriter.add_scalar('Loss/val', val_loss, epoch)
        tbSummWriter.add_scalar('learning_rate', lr, epoch)
        print('epoch ', epoch, ' Train loss ', train_loss, ' Val loss ', val_loss)
        if val_loss < best_val_loss:
            save_model(coma, optimizer, epoch, train_loss, val_loss, save_checkpoint_dir)
            best_val_loss = val_loss
        val_loss_history.append(val_loss)
        val_losses.append(best_val_loss)
        if opt == 'sgd':
            adjust_learning_rate(optimizer, lr_decay)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    tbSummWriter.flush()
    tbSummWriter.close()
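# Every variant in this listing converts the scipy sparse sampling/adjacency
# matrices with `scipy_to_torch_sparse`, which is not defined here. A minimal
# sketch of such a converter, assuming the inputs are scipy sparse matrices
# (the project's actual helper may differ in dtype handling):
import numpy as np
import torch

def scipy_to_torch_sparse(mat):
    """Convert a scipy sparse matrix to a torch sparse COO tensor (sketch)."""
    coo = mat.tocoo()
    indices = torch.from_numpy(np.vstack((coo.row, coo.col))).long()
    values = torch.from_numpy(coo.data.astype(np.float32))
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))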
# --- Variant 2: Coma autoencoder training (imports as in Variant 1, with Coma
#     and Normalize in place of ComaVAE, plus `from datetime import datetime`
#     for the TensorBoard run name) ---
def main(args):
    if not os.path.exists(args.conf):
        print('Config not found: ' + args.conf)
        return
    config = read_config(args.conf)
    for k, v in config.items():
        print(k, v)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print('No visual output directory is provided. '
              'The checkpoint directory will be used to store the visual results.')
        output_dir = checkpoint_dir
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    eval_flag = config['eval']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']
    val_losses, accs, durations = [], [], []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(template_mesh,
                                                             config['downsampling_factors'])
    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir, dtype='train', split=args.split,
                          split_term=args.split_term, pre_transform=normalize_transform)
    dataset_val = ComaDataset(data_dir, dtype='val', split=args.split,
                              split_term=args.split_term, pre_transform=normalize_transform)
    dataset_test = ComaDataset(data_dir, dtype='test', split=args.split,
                               split_term=args.split_term, pre_transform=normalize_transform)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                              num_workers=workers_thread)
    val_loader = DataLoader(dataset_val, batch_size=1, shuffle=True,
                            num_workers=workers_thread)
    test_loader = DataLoader(dataset_test, batch_size=1, shuffle=False,
                             num_workers=workers_thread)

    print('Loading model')
    start_epoch = 1
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(), lr=lr, weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(), lr=lr,
                                    weight_decay=weight_decay, momentum=0.9)
    else:
        raise Exception('No optimizer provided')

    checkpoint_file = config['checkpoint_file']
    print(checkpoint_file)
    if checkpoint_file:
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch_num']
        coma.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # move optimizer state to the right device
        # (check whether this is fixed in pytorch)
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    coma.to(device)

    if eval_flag:
        val_loss = evaluate(coma, output_dir, test_loader, dataset_test,
                            template_mesh, device, visualize)
        print('val loss', val_loss)
        return

    best_val_loss = float('inf')
    val_loss_history = []

    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_dir = os.path.join('runs/ae', current_time)
    writer = SummaryWriter(log_dir + '-ds2_lr0.04')  # hard-coded run suffix

    for epoch in range(start_epoch, total_epochs + 1):
        print('Training for epoch ', epoch)
        train_loss = train(coma, train_loader, len(dataset), optimizer, device)
        val_loss = evaluate(coma, output_dir, val_loader, dataset_val,
                            template_mesh, device, epoch, visualize=visualize)
        writer.add_scalar('data/train_loss', train_loss, epoch)
        writer.add_scalar('data/val_loss', val_loss, epoch)
        print('epoch ', epoch, ' Train loss ', train_loss, ' Val loss ', val_loss)
        if val_loss < best_val_loss:
            save_model(coma, optimizer, epoch, train_loss, val_loss, checkpoint_dir)
            best_val_loss = val_loss
        if epoch == total_epochs or epoch % 100 == 0:
            # periodic snapshot every 100 epochs and at the end of training
            save_model(coma, optimizer, epoch, train_loss, val_loss, checkpoint_dir)
        val_loss_history.append(val_loss)
        val_losses.append(best_val_loss)
        if opt == 'sgd':
            adjust_learning_rate(optimizer, lr_decay)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    writer.close()
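# `adjust_learning_rate` (called after each epoch when SGD is used) is also
# not defined in this listing. A minimal sketch, assuming it simply scales
# every parameter group's learning rate by the configured decay factor:
def adjust_learning_rate(optimizer, lr_decay):
    # multiplicative decay on each param group (assumed semantics)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * lr_decay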
# --- Variant 3 (fragment): test-set loading for evaluation (imports as in
#     Variant 1; the enclosing function definition and the guard around the
#     first message are assumed) ---
def main(args):
    if not args.conf:
        # assumed guard: the original fragment begins at this message
        args.conf = './default.cfg'
        print('configuration file not specified, trying to load '
              'it from current directory', args.conf)
    if not os.path.exists(args.conf):
        print('Config not found: ' + args.conf)
        return
    config = read_config(args.conf)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(template_mesh,
                                                             config['downsampling_factors'])
    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir, dtype='test', split=args.split,
                          split_term=args.split_term, pre_transform=normalize_transform)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
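# `Normalize` is passed to `ComaDataset` as a `pre_transform` in the variants
# above and below, and Variant 4 reads and writes its `mean`/`std` attributes
# directly. A sketch of a transform consistent with that usage (the real
# class may instead compute the statistics inside the dataset):
class Normalize(object):
    def __init__(self, mean=None, std=None):
        self.mean = mean
        self.std = std

    def __call__(self, data):
        # standardize vertex coordinates; assumes mean/std were set beforehand
        if self.mean is not None and self.std is not None:
            data.x = (data.x - self.mean) / self.std
        return data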
# --- Variant 4: extracting latent codes for processed sequences (imports as
#     in Variant 1, plus numpy as np, tqdm, Data from torch_geometric.data,
#     and the mesh helpers MeshViewers, get_vert_connectivity and save_obj) ---
def main(args):
    if not os.path.exists(args.conf):
        print('Config not found: ' + args.conf)
        return
    config = read_config(args.conf)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print('No visual output directory is provided. '
              'The checkpoint directory will be used to store the visual results.')
        output_dir = checkpoint_dir
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    eval_flag = config['eval']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']
    val_losses, accs, durations = [], [], []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(template_mesh,
                                                             config['downsampling_factors'])
    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir, dtype='train', split=args.split,
                          split_term=args.split_term, pre_transform=normalize_transform)

    print('Loading model')
    start_epoch = 1
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(), lr=lr, weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(), lr=lr,
                                    weight_decay=weight_decay, momentum=0.9)
    else:
        raise Exception('No optimizer provided')

    checkpoint_file = config['checkpoint_file']
    print(checkpoint_file)
    if checkpoint_file:
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch_num']
        coma.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # move optimizer state to the right device
        # (check whether this is fixed in pytorch)
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    coma.to(device)

    print('making...')
    # reuse the normalization statistics computed for the training split
    norm = torch.load('../processed_data/processed/sliced_norm.pt')
    normalize_transform.mean = norm['mean']
    normalize_transform.std = norm['std']

    # '0512','0901','0516','0509','0507','9305','0503','4919','4902',
    files = ['0514', '0503', '0507', '0509', '0512', '0501', '0901',
             '1001', '4902', '4913', '4919', '9302', '9305', '12411']
    coma.eval()
    meshviewer = MeshViewers(shape=(1, 2))
    for seq in files:
        # mat = np.load('../Dress Dataset/'+seq+'/'+seq+'_pose.npz')
        mesh_dir = os.listdir('../processed_data/' + seq + '/mesh/')
        latent = []
        print(len(mesh_dir))
        for i in tqdm(range(len(mesh_dir))):
            data_file = '../processed_data/' + seq + '/mesh/' + str(i) + '.obj'
            mesh = Mesh(filename=data_file)
            adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
            edge_index = torch.Tensor(np.vstack((adjacency.row, adjacency.col))).long()
            mesh_verts = (torch.Tensor(mesh.v) - normalize_transform.mean) / normalize_transform.std
            data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)
            data = data.to(device)
            with torch.no_grad():
                out, feature = coma(data)
            latent.append(feature.cpu().detach().numpy())
            if i % 50 == 0:
                # periodically dump reconstruction vs. input for inspection
                expected_out = data.x
                out = out.cpu().detach() * normalize_transform.std + normalize_transform.mean
                expected_out = expected_out.cpu().detach() * normalize_transform.std + normalize_transform.mean
                out = out.numpy()
                expected_out = expected_out.numpy()
                save_obj(out, template_mesh.f + 1, './vis/reconstruct_' + str(i) + '.obj')
                save_obj(expected_out, template_mesh.f + 1, './vis/ori_' + str(i) + '.obj')
        np.save('./processed/0820/' + seq, latent)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
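# `save_obj` is used above to dump reconstructions. The call sites pass
# `template_mesh.f + 1`, which suggests the helper writes face indices
# verbatim (OBJ faces are 1-based). A minimal sketch under that assumption:
def save_obj(verts, faces, path):
    # write a Wavefront OBJ: one 'v' line per vertex, one 'f' line per face
    with open(path, 'w') as fh:
        for v in verts:
            fh.write('v {} {} {}\n'.format(v[0], v[1], v[2]))
        for face in faces:
            fh.write('f {} {} {}\n'.format(face[0], face[1], face[2]))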
# --- Variant 5: latent-space traversal and snapshot rendering (imports as in
#     Variant 1, plus trange from tqdm and MeshViewers) ---
def main(checkpoint_path, config_path, output_dir):
    config = read_config(config_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(template_mesh,
                                                             config['downsampling_factors'])
    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Preparing dataset')
    data_dir = config['data_dir']
    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir, dtype='test', split='sliced',
                          split_term='sliced', pre_transform=normalize_transform)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    print('Loading model')
    model = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    # the parameter was renamed from `checkpoint` to avoid shadowing it here
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.to(device)

    print('Generating latent')
    data = next(iter(loader))
    with torch.no_grad():
        data = data.to(device)
        x = data.x.reshape(data.num_graphs, -1, model.filters[0])
        z = model.encoder(x)

    print('View meshes')
    meshviewer = MeshViewers(shape=(1, 1))
    for feature_index in range(z.size(1)):
        # sweep one latent dimension from -4 to 4 in steps of 0.1;
        # torch.arange replaces the deprecated torch.range (inclusive end)
        j = torch.arange(-4, 4.05, step=0.1, device=device)
        new_z = z.expand(j.size(0), z.size(1)).clone()
        new_z[:, feature_index] *= 1 + 0.3 * j
        with torch.no_grad():
            out = model.decoder(new_z)
        out = out.detach().cpu() * dataset.std + dataset.mean
        for i in trange(out.shape[0]):
            mesh = Mesh(v=out[i], f=template_mesh.f)
            meshviewer[0][0].set_dynamic_meshes([mesh])
            f = os.path.join(output_dir, 'z{}'.format(feature_index), '{:04d}.png'.format(i))
            os.makedirs(os.path.dirname(f), exist_ok=True)
            meshviewer[0][0].save_snapshot(f, blocking=True)
# --- Variant 6: Coma training with transform-shape debugging and a final
#     loss-curve plot (imports as in Variant 1, plus matplotlib.pyplot as plt) ---
def main(args):
    if not os.path.exists(args.conf):
        print('Config not found: ' + args.conf)
        return
    config = read_config(args.conf)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print('No visual output directory is provided. '
              'The checkpoint directory will be used to store the visual results.')
        output_dir = checkpoint_dir
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    eval_flag = config['eval']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']
    val_losses, accs, durations = [], [], []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(template_mesh,
                                                             config['downsampling_factors'])

    # debug output: sizes of the mesh hierarchy and of the transform matrices
    print(len(M))
    for i in range(len(M)):
        print(M[i].v.shape)
    print('************A****************')
    for a in A:
        print(a.shape)
    print('************D****************')
    for d in D:
        print(d.shape)
    print('************U****************')
    for u in U:
        print(u.shape)

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir, dtype='train', split=args.split,
                          split_term=args.split_term, pre_transform=normalize_transform)
    dataset_test = ComaDataset(data_dir, dtype='test', split=args.split,
                               split_term=args.split_term, pre_transform=normalize_transform)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                              num_workers=workers_thread)
    test_loader = DataLoader(dataset_test, batch_size=1, shuffle=False,
                             num_workers=workers_thread)

    print('Loading model')
    start_epoch = 1
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(), lr=lr, weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(), lr=lr,
                                    weight_decay=weight_decay, momentum=0.9)
    else:
        raise Exception('No optimizer provided')

    checkpoint_file = config['checkpoint_file']
    if checkpoint_file:
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch_num']
        coma.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # move optimizer state to the right device
        # (check whether this is fixed in pytorch)
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    coma.to(device)

    if eval_flag:
        val_loss = evaluate(coma, output_dir, test_loader, dataset_test,
                            template_mesh, device, visualize)
        print('val loss', val_loss)
        return

    best_val_loss = float('inf')
    val_loss_history = []
    train_loss_history = []
    for epoch in range(start_epoch, total_epochs + 1):
        print('Training for epoch ', epoch)
        train_loss = train(coma, train_loader, len(dataset), optimizer, device)
        val_loss = evaluate(coma, output_dir, test_loader, dataset_test,
                            template_mesh, device, visualize=visualize)
        val_loss_history.append(val_loss)
        train_loss_history.append(train_loss)
        print('epoch ', epoch, ' Train loss ', train_loss, ' Val loss ', val_loss)
        if val_loss < best_val_loss:
            save_model(coma, optimizer, epoch, train_loss, val_loss, checkpoint_dir)
            best_val_loss = val_loss
        val_losses.append(best_val_loss)
        if opt == 'sgd':
            adjust_learning_rate(optimizer, lr_decay)

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # plot train/val loss curves over epochs
    times = list(range(len(train_loss_history)))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(times, train_loss_history, label='train')
    ax.plot(times, val_loss_history, label='val')
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss')
    ax.legend()
    plt.savefig(os.path.join(checkpoint_dir, 'result.png'))
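# `save_model` is never shown, but the checkpoint-loading code in every
# variant reads the keys 'epoch_num', 'state_dict' and 'optimizer'. A sketch
# consistent with those keys (the per-epoch file naming is an assumption):
def save_model(model, optimizer, epoch, train_loss, val_loss, checkpoint_dir):
    checkpoint = {
        'epoch_num': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }
    torch.save(checkpoint, os.path.join(checkpoint_dir, 'checkpoint_{}.pt'.format(epoch)))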
# --- Variant 7: reproducible training/evaluation with cached transform
#     matrices (imports as in Variant 1, plus numpy as np, pickle and
#     datetime; the model class, e.g. ComaAtt, is selected by name) ---
def main(args):
    if not os.path.exists(args.conf):
        print('Config not found: ' + args.conf)
        return
    config = read_config(args.conf)

    # reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.set_num_threads(args.num_threads)
    if args.rep_cudnn:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    checkpoint_dir = os.path.join(checkpoint_dir, args.modelname)
    print(datetime.datetime.now())
    print('checkpoint_dir', checkpoint_dir)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    # command-line flags override the config where given
    visualize = config['visualize'] if args.visualize is None else args.visualize
    output_dir = config['visual_output_dir']
    if output_dir:
        output_dir = os.path.join(output_dir, args.modelname)
    if visualize is True and not output_dir:
        print('No visual output directory is provided. '
              'The checkpoint directory will be used to store the visual results.')
        output_dir = checkpoint_dir
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if not args.train:
        eval_flag = True
    else:
        eval_flag = config['eval']

    if args.learning_rate:
        config['learning_rate'] = args.learning_rate
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread'] if args.num_workers is None else args.num_workers
    opt = config['optimizer']
    batch_size = config['batch_size'] if args.batch is None else args.batch
    val_losses, accs, durations = [], [], []

    if args.device_idx is None:
        device = torch.device('cuda:' + str(config['device_idx'])
                              if torch.cuda.is_available() else 'cpu')
    elif args.device_idx >= 0:
        device = torch.device('cuda:' + str(args.device_idx)
                              if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device('cpu')
    print(config)

    # cache the (expensive) transform matrices on disk, keyed by dataset name
    ds_fname = os.path.join('./template/',
                            data_dir.split('/')[-1] + '_' + args.hier_matrices + '.pkl')
    if not os.path.exists(ds_fname):
        print('Generating Transform Matrices ..')
        M, A, D, U = mesh_operations.generate_transform_matrices(
            template_mesh, config['downsampling_factors'])
        with open(ds_fname, 'wb') as fp:
            M_verts_faces = [(M[i].v, M[i].f) for i in range(len(M))]
            pickle.dump({'M_verts_faces': M_verts_faces, 'A': A, 'D': D, 'U': U}, fp)
    else:
        print('Loading Transform Matrices ..')
        with open(ds_fname, 'rb') as fp:
            downsampling_matrices = pickle.load(fp)
        M_verts_faces = downsampling_matrices['M_verts_faces']
        M = [Mesh(v=M_verts_faces[i][0], f=M_verts_faces[i][1])
             for i in range(len(M_verts_faces))]
        A = downsampling_matrices['A']
        D = downsampling_matrices['D']
        U = downsampling_matrices['U']

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    # normalized reference vertices for each hierarchy level
    nV_ref = []
    ref_mean = np.mean(M[0].v, axis=0)
    ref_std = np.std(M[0].v, axis=0)
    for i in range(len(M)):
        nv = 0.1 * (M[i].v - ref_mean) / ref_std
        nV_ref.append(nv)
    tV_ref = [torch.from_numpy(s).float().to(device) for s in nV_ref]

    print('Loading Dataset')
    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir, dtype='train', split=args.split,
                          split_term=args.split_term, pre_transform=normalize_transform)
    dataset_val = ComaDataset(data_dir, dtype='val', split=args.split,
                              split_term=args.split_term, pre_transform=normalize_transform)
    dataset_test = ComaDataset(data_dir, dtype='test', split=args.split,
                               split_term=args.split_term, pre_transform=normalize_transform)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                              num_workers=workers_thread)
    val_loader = DataLoader(dataset_val, batch_size=batch_size, shuffle=False,
                            num_workers=workers_thread)
    # batch size 1 when visualizing so meshes can be rendered one at a time
    test_loader = DataLoader(dataset_test, batch_size=1 if visualize else batch_size,
                             shuffle=False, num_workers=workers_thread)

    print('Loading model')
    start_epoch = 1
    # the model class is looked up by name; ComaAtt additionally takes the
    # reference vertices (note: eval() on user input is unsafe in general)
    if args.modelname in {'ComaAtt'}:
        gcn_model = eval(args.modelname)(dataset, config, D_t, U_t, A_t, num_nodes, tV_ref)
    else:
        gcn_model = eval(args.modelname)(dataset, config, D_t, U_t, A_t, num_nodes)
    gcn_params = gcn_model.parameters()
    params = sum(p.numel() for p in gcn_model.parameters() if p.requires_grad)
    print('Total number of parameters is: {}'.format(params))
    print(gcn_model)

    if opt == 'adam':
        optimizer = torch.optim.Adam(gcn_params, lr=lr, weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(gcn_params, lr=lr, weight_decay=weight_decay,
                                    momentum=0.9)
    else:
        raise Exception('No optimizer provided')

    if args.checkpoint_file:
        checkpoint_file = os.path.join(checkpoint_dir, str(args.checkpoint_file) + '.pt')
    else:
        checkpoint_file = config['checkpoint_file']
    if eval_flag and not checkpoint_file:
        checkpoint_file = os.path.join(checkpoint_dir, 'checkpoint.pt')
    print(checkpoint_file)
    if checkpoint_file:
        print('Loading checkpoint file: {}.'.format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file, map_location=device)
        start_epoch = checkpoint['epoch_num']
        gcn_model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # move optimizer state to the right device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    gcn_model.to(device)

    if eval_flag:
        val_loss, euclidean_loss = evaluate(gcn_model, output_dir, test_loader,
                                            dataset_test, template_mesh, device, visualize)
        print('val loss', val_loss)
        print('euclidean error is {} mm'.format(1000 * euclidean_loss))
        return

    best_val_loss = float('inf')
    val_loss_history = []
    for epoch in range(start_epoch, total_epochs + 1):
        print('Training for epoch ', epoch)
        train_loss = train(gcn_model, train_loader, len(dataset), optimizer, device)
        val_loss, _ = evaluate(gcn_model, output_dir, val_loader, dataset_val,
                               template_mesh, device, visualize=visualize)
        print('epoch {}, Train loss {:.8f}, Val loss {:.8f}'.format(
            epoch, train_loss, val_loss))
        if val_loss < best_val_loss:
            save_model(gcn_model, optimizer, epoch, train_loss, val_loss, checkpoint_dir)
            best_val_loss = val_loss
        val_loss_history.append(val_loss)
        val_losses.append(best_val_loss)
        if opt == 'sgd':
            adjust_learning_rate(optimizer, lr_decay)
        # run a test-set evaluation at selected epochs and near the end of training
        if epoch in args.epochs_eval or (val_loss <= best_val_loss
                                         and epoch > int(total_epochs * 3 / 4)):
            val_loss, euclidean_loss = evaluate(gcn_model, output_dir, test_loader,
                                                dataset_test, template_mesh, device,
                                                visualize)
            print('epoch {} with val loss {}'.format(epoch, val_loss))
            print('euclidean error is {} mm'.format(1000 * euclidean_loss))

    if torch.cuda.is_available():
        torch.cuda.synchronize()
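# No variant in this listing includes its argument parser. Below is a sketch
# of an entry point matching the flags Variant 7 actually reads; the short
# option names, defaults and help strings are assumptions:
if __name__ == '__main__':
    import argparse

    def str2bool(s):
        # argparse's type=bool treats any non-empty string as True, so parse explicitly
        return s.lower() in ('1', 'true', 'yes')

    parser = argparse.ArgumentParser(description='mesh autoencoder training / evaluation')
    parser.add_argument('-c', '--conf', default='./default.cfg', help='path to the config file')
    parser.add_argument('-d', '--data_dir', default=None, help='override data_dir from the config')
    parser.add_argument('--checkpoint_dir', default=None, help='override checkpoint_dir from the config')
    parser.add_argument('--checkpoint_file', default=None, help='checkpoint name (without .pt) inside checkpoint_dir')
    parser.add_argument('-s', '--split', default='sliced', help='dataset split type')
    parser.add_argument('-st', '--split_term', default='sliced', help='split term for the dataset')
    parser.add_argument('-m', '--modelname', default='Coma', help='model class to instantiate by name')
    parser.add_argument('--train', action='store_true', help='train; evaluate otherwise')
    parser.add_argument('--visualize', type=str2bool, default=None, help='override the config visualize flag')
    parser.add_argument('--learning_rate', type=float, default=None)
    parser.add_argument('--batch', type=int, default=None)
    parser.add_argument('--num_workers', type=int, default=None)
    parser.add_argument('--device_idx', type=int, default=None, help='CUDA device index; negative for CPU')
    parser.add_argument('--seed', type=int, default=2)
    parser.add_argument('--num_threads', type=int, default=1)
    parser.add_argument('--rep_cudnn', action='store_true', help='make cudnn deterministic')
    parser.add_argument('--hier_matrices', default='hier_matrices', help='cache-file suffix for transform matrices')
    parser.add_argument('--epochs_eval', type=int, nargs='*', default=[], help='epochs at which to run a test-set evaluation')
    main(parser.parse_args())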