def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--manualSeed', type=int, help='manual seed')
    parser.add_argument('--batch_size', type=int, default=30)
    parser.add_argument('--epochs', type=int, default=241)
    parser.add_argument('--workers', type=int, default=6,
                        help='num of workers to load data for each DataLoader')
    parser.add_argument('--checkpoints_dir', '-CDIR', default='experiments_deco',
                        help='folder where all experiments get stored')
    parser.add_argument('--exp_name', '-EXP', default='exp',
                        help='will create an exp_name folder under checkpoints_dir')
    parser.add_argument('--config', '-C', required=True, help='path to valid configuration file')
    parser.add_argument('--parallel', action='store_true', help='multi-GPU training')
    parser.add_argument('--it_test', type=int, default=10,
                        help='at each it_test epoch: perform test and checkpoint')
    parser.add_argument('--restart_from', default='',
                        help='restart interrupted training from checkpoint')
    parser.add_argument('--class_choice',
                        default="Airplane,Bag,Cap,Car,Chair,Guitar,Lamp,Laptop,Motorbike,Mug,Pistol,Skateboard,Table",
                        help='classes to train on: default is the 13 classes used in PF-Net')
    parser.add_argument('--data_root',
                        default="/home/antonioa/data/shapenetcore_partanno_segmentation_benchmark_v0")

    # crop params
    parser.add_argument('--crop_point_num', type=int, default=512, help='number of points to crop')
    parser.add_argument('--context_point_num', type=int, default=512,
                        help='number of points of the frame region')
    parser.add_argument('--num_holes', type=int, default=1, help='number of crop_point_num holes')
    parser.add_argument('--pool1_points', '-P1', type=int, default=1280,
                        help='points selected at pooling layer 1, we use 1280 in all experiments')
    parser.add_argument('--pool2_points', '-P2', type=int, default=512,
                        help='points selected at pooling layer 2, should match crop_point_num, i.e. 512')
    # parser.add_argument('--fps_centroids', '-FPS', action='store_true', help='different crop logic than pfnet')
    parser.add_argument('--raw_weight', '-RW', type=float, default=1,
                        help='weights the intermediate (frame region) prediction loss; use 0 to disable regularization')

    args = parser.parse_args()
    args.fps_centroids = False

    # make experiment dirs
    args.save_dir = os.path.join(args.checkpoints_dir, args.exp_name)
    args.models_dir = os.path.join(args.save_dir, 'models')
    args.vis_dir = os.path.join(args.save_dir, 'train_visz')
    safe_make_dirs([args.save_dir, args.models_dir, args.vis_dir,
                    os.path.join(args.save_dir, 'backup_code')])

    # instantiate loggers
    io_logger = IOStream(os.path.join(args.save_dir, 'log.txt'))
    tb_logger = SummaryWriter(logdir=args.save_dir)
    return args, io_logger, tb_logger
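# NOTE (editor's sketch): `safe_make_dirs` and `IOStream` are imported helpers that are not
# defined in this file. Below is a minimal sketch of `safe_make_dirs`, assuming it only has to
# create each requested directory if it does not already exist; the project's actual helper
# may behave differently.
import os

def safe_make_dirs(dir_list):
    """Create every directory in `dir_list`, ignoring the ones that already exist."""
    for d in dir_list:
        os.makedirs(d, exist_ok=True)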
def prepare_midi(midi_paths, max_num_files, output_base_dir, inst_classes, defs_dict,
                 pgm0_is_piano=False, rerender_existing=False, band_classes_def=None,
                 same_pgms_diff=False, separate_drums=False, zero_based_midi=False):
    """
    Loops through a list of `midi_paths` until `max_num_files` have been flagged for synthesis.
    For each file flagged for synthesis, the MIDI file is copied to the output directory and each
    individual track is split off into its own MIDI file. Each MIDI instrument track is also
    assigned a synthesis patch.

    Args:
        midi_paths (list): List of paths to MIDI files.
        max_num_files (int): Total number of files to render.
        output_base_dir (str): Base directory where output will be stored.
        inst_classes: Instrument class definitions used to classify each MIDI track.
        defs_dict: Patch definitions; maps synthesis patches to the MIDI program numbers they can render.
        pgm0_is_piano: If True, program number 0 is treated as piano when classifying instruments.
        rerender_existing: If False, stems whose MIDI file already exists are skipped.
        band_classes_def: Optional band/class definition passed to `check_midi_file`.
        same_pgms_diff: If True, instruments sharing a program number may get different patches.
        separate_drums: Passed to `check_midi_file` to control drum handling.
        zero_based_midi: If True, program numbers are zero-based (drums become 128 instead of 129).

    Returns:
        dict: `srcs_by_inst`, mapping each selected patch to the render info of the stems assigned to it.
    """
    midi_files_read = 0
    defs_dict = defs_dict if not zero_based_midi else make_zero_based_midi(defs_dict)
    srcs_by_inst = make_src_by_inst(defs_dict)
    inv_defs_dict = invert_defs_dict(defs_dict)

    for path in midi_paths:
        logger.info('Starting {}'.format(path))
        pm = check_midi_file(path, inst_classes, pgm0_is_piano, band_classes_def, separate_drums)
        if not pm:
            continue

        # Okay, we're all good to continue now
        logger.info('({}/{}) Selected {}'.format(midi_files_read, max_num_files, path))
        midi_files_read += 1

        # Make a whole bunch of paths to store everything
        uuid = os.path.splitext(os.path.basename(path))[0]
        out_dir_name = 'Track{:05d}'.format(midi_files_read)
        output_dir = os.path.join(output_base_dir, out_dir_name)
        utils.safe_make_dirs(output_dir)
        shutil.copy(path, output_dir)
        os.rename(os.path.join(output_dir, uuid + '.mid'),
                  os.path.join(output_dir, 'all_src.mid'))
        midi_out_dir = os.path.join(output_dir, 'MIDI')
        audio_out_dir = os.path.join(output_dir, 'stems')
        utils.safe_make_dirs(midi_out_dir)
        utils.safe_make_dirs(audio_out_dir)

        # Set up metadata
        metadata = {
            'lmd_midi_dir': os.path.sep.join(path.split(os.path.sep)[-6:]),
            'midi_dir': midi_out_dir,
            'audio_dir': audio_out_dir,
            'UUID': uuid,
            'stems': {}
        }
        seen_pgms = {}

        # Loop through instruments in this MIDI file
        for j, inst in enumerate(pm.instruments):

            # Name it and figure out what instrument class this is
            key = 'S{:02d}'.format(j)
            inst_cls = utils.get_inst_class(inst_classes, inst, pgm0_is_piano)

            # Set up metadata
            metadata['stems'][key] = {}
            metadata['stems'][key]['inst_class'] = inst_cls
            metadata['stems'][key]['is_drum'] = inst.is_drum
            metadata['stems'][key]['midi_program_name'] = utils.get_inst_program_name(
                inst_classes, inst, pgm0_is_piano)

            if inst.is_drum:
                # Drums use this special flag, but not the program number, so the pgm # is always
                # set to 0. But program number 0 is usually piano, so we define program number
                # 129/128 for drums to avoid collisions.
                program_num = 129 if not zero_based_midi else 128
            else:
                program_num = int(inst.program)

            metadata['stems'][key]['program_num'] = program_num
            metadata['stems'][key]['midi_saved'] = False
            metadata['stems'][key]['audio_rendered'] = False

            if program_num not in inv_defs_dict or len(inv_defs_dict[program_num]) < 1:
                metadata['stems'][key]['plugin_name'] = 'None'
                logger.info('No instrument loaded for \'{}\' (skipping).'.format(inst_cls))
                continue

            # If we've seen this program # before, use the previously selected patch
            if not same_pgms_diff and program_num in seen_pgms.keys():
                selected_patch = seen_pgms[program_num]
            else:
                selected_patch = select_patch_rand(inv_defs_dict, program_num)
                seen_pgms[program_num] = selected_patch

            metadata['stems'][key]['plugin_name'] = selected_patch

            # Save the info we need for the next stages
            render_info = {
                'metadata': os.path.join(output_dir, 'metadata.yaml'),
                'source_key': key,
                'end_time': pm.get_end_time() + 5.0
            }
            srcs_by_inst[selected_patch].append(render_info)

            # Make the output path
            midi_out_path = os.path.join(midi_out_dir, '{}.mid'.format(key))
            if os.path.exists(midi_out_path) and not rerender_existing:
                logger.info('Found {}. Skipping...'.format(midi_out_path))
                continue

            # Save a MIDI file with just that source
            midi_stem = copy.copy(pm)
            midi_stem.name = key
            midi_stem.instruments = []
            inst = midi_inst_rules.apply_midi_rules(inst, inst_cls)
            midi_stem.instruments.append(inst)
            midi_stem.write(midi_out_path)

            if os.path.isfile(midi_out_path):
                metadata['stems'][key]['midi_saved'] = True

            logger.info('Wrote {}.mid. Selected patch \'{}\''.format(key, selected_patch))

        if not rerender_existing:
            with open(os.path.join(output_dir, 'metadata.yaml'), 'w') as f:
                f.write(yaml.safe_dump(metadata, default_flow_style=False, allow_unicode=True))

        logger.info('Finished {}'.format(path))

        if midi_files_read >= max_num_files:
            logger.info('Finished reading MIDI')
            break

    return srcs_by_inst
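# NOTE (editor's sketch): `make_src_by_inst`, `invert_defs_dict` and `select_patch_rand` are
# assumed helpers whose behaviour is only inferred from how they are used above: `defs_dict`
# appears to map patch names to the program numbers they can render, the inverted dict maps a
# program number back to its candidate patches, and one candidate is picked at random.
# Minimal versions under those assumptions could look like this; the real helpers may differ.
import random

def make_src_by_inst(defs_dict):
    """One (initially empty) render list per patch."""
    return {patch: [] for patch in defs_dict}

def invert_defs_dict(defs_dict):
    """Map each MIDI program number to the list of patches that can render it."""
    inv = {}
    for patch, program_nums in defs_dict.items():
        for pgm in program_nums:
            inv.setdefault(pgm, []).append(patch)
    return inv

def select_patch_rand(inv_defs_dict, program_num):
    """Pick one of the candidate patches for this program number at random."""
    return random.choice(inv_defs_dict[program_num])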
def main_worker():
    opt, io, tb = get_args()
    start_epoch = -1
    start_time = time.time()
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # python script folder

    ckt = None
    if len(opt.restart_from) > 0:
        ckt = torch.load(opt.restart_from)
        start_epoch = ckt['epoch'] - 1

    # load configuration from file
    try:
        with open(opt.config) as cf:
            config = json.load(cf)
    except IOError as error:
        print(error)

    # backup relevant files
    shutil.copy(src=os.path.abspath(__file__), dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'models', 'model_deco.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'shape_utils.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=opt.config, dst=os.path.join(opt.save_dir, 'backup_code', 'config.json.backup'))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed_all(opt.manualSeed)

    io.cprint(f"Arguments: {str(opt)}")
    io.cprint(f"Configuration: {str(config)}")

    pnum = config['completion_trainer']['num_points']
    class_choice = opt.class_choice

    # datasets + loaders
    if len(class_choice) > 0:
        class_choice = ''.join(opt.class_choice.split()).split(",")  # sanitize + split(",")
        io.cprint("Class choice list: {}".format(str(class_choice)))
    else:
        class_choice = None  # train on all classes (if opt.class_choice == '')

    tr_dataset = shapenet_part_loader.PartDataset(root=opt.data_root, classification=True,
                                                  class_choice=class_choice, npoints=pnum,
                                                  split='train')
    te_dataset = shapenet_part_loader.PartDataset(root=opt.data_root, classification=True,
                                                  class_choice=class_choice, npoints=pnum,
                                                  split='test')
    tr_loader = torch.utils.data.DataLoader(tr_dataset, batch_size=opt.batch_size, shuffle=True,
                                            num_workers=opt.workers, drop_last=True)
    te_loader = torch.utils.data.DataLoader(te_dataset, batch_size=64, shuffle=True,
                                            num_workers=opt.workers)

    num_holes = int(opt.num_holes)
    crop_point_num = int(opt.crop_point_num)
    context_point_num = int(opt.context_point_num)
    # io.cprint("Num holes: {}".format(num_holes))
    # io.cprint("Crop points num: {}".format(crop_point_num))
    # io.cprint("Context points num: {}".format(context_point_num))
    # io.cprint("Pool1 num points selected: {}".format(opt.pool1_points))
    # io.cprint("Pool2 num points selected: {}".format(opt.pool2_points))

    # Models
    gl_encoder = Encoder(conf=config)
    generator = Generator(conf=config, pool1_points=int(opt.pool1_points),
                          pool2_points=int(opt.pool2_points))
    gl_encoder.apply(weights_init_normal)  # affects only the non-pretrained layers
    generator.apply(weights_init_normal)  # not pretrained
    print("Encoder: ", gl_encoder)
    print("Generator: ", generator)

    if ckt is not None:
        io.cprint(f"Restart Training from epoch {start_epoch}.")
        gl_encoder.load_state_dict(ckt['gl_encoder_state_dict'])
        generator.load_state_dict(ckt['generator_state_dict'])
    else:
        io.cprint("Training Completion Task...")
        local_fe_fn = config['completion_trainer']['checkpoint_local_enco']
        global_fe_fn = config['completion_trainer']['checkpoint_global_enco']

        if len(local_fe_fn) > 0:
            local_enco_dict = torch.load(local_fe_fn)['model_state_dict']
            # refactoring pretext-trained local DGCNN encoder state dict keys
            local_enco_dict = remove_prefix_dict(state_dict=local_enco_dict,
                                                 to_remove_str='local_encoder.')
            loc_load_result = gl_encoder.local_encoder.load_state_dict(local_enco_dict,
                                                                       strict=False)
            io.cprint(f"Local FE pretrained weights - loading res: {str(loc_load_result)}")
        else:
            # Ablation experiments only
            io.cprint("Local FE pretrained weights - NOT loaded", color='r')

        if len(global_fe_fn) > 0:
            global_enco_dict = torch.load(global_fe_fn)['global_encoder']
            glob_load_result = gl_encoder.global_encoder.load_state_dict(global_enco_dict,
                                                                         strict=True)
            io.cprint(f"Global FE pretrained weights - loading res: {str(glob_load_result)}",
                      color='b')
        else:
            # Ablation experiments only
            io.cprint("Global FE pretrained weights - NOT loaded", color='r')

    io.cprint("Num GPUs: " + str(torch.cuda.device_count()) +
              ", Parallelism: {}".format(opt.parallel))
    if opt.parallel and torch.cuda.device_count() > 1:
        gl_encoder = torch.nn.DataParallel(gl_encoder)
        generator = torch.nn.DataParallel(generator)
    gl_encoder.to(device)
    generator.to(device)

    # Optimizers + schedulers
    opt_E = torch.optim.Adam(gl_encoder.parameters(),
                             lr=config['completion_trainer']['enco_lr'],  # default: 10e-4
                             betas=(0.9, 0.999), eps=1e-05, weight_decay=0.001)
    sched_E = torch.optim.lr_scheduler.StepLR(opt_E,
                                              step_size=config['completion_trainer']['enco_step'],  # default: 25
                                              gamma=0.5)
    opt_G = torch.optim.Adam(generator.parameters(),
                             lr=config['completion_trainer']['gen_lr'],  # default: 10e-4
                             betas=(0.9, 0.999), eps=1e-05, weight_decay=0.001)
    sched_G = torch.optim.lr_scheduler.StepLR(opt_G,
                                              step_size=config['completion_trainer']['gen_step'],  # default: 40
                                              gamma=0.5)

    if ckt is not None:
        opt_E.load_state_dict(ckt['optimizerE_state_dict'])
        opt_G.load_state_dict(ckt['optimizerG_state_dict'])
        sched_E.load_state_dict(ckt['schedulerE_state_dict'])
        sched_G.load_state_dict(ckt['schedulerG_state_dict'])

    if not opt.fps_centroids:
        # 5 viewpoints to crop around - same as in PF-Net
        centroids = np.asarray([[1, 0, 0], [0, 0, 1], [1, 0, 1], [-1, 0, 0], [-1, 1, 0]])
    else:
        raise NotImplementedError('experimental')
        centroids = None

    io.cprint("Training..\n")
    best_test = sys.float_info.max
    best_ep = -1
    it = 0  # global iteration counter
    vis_folder = None

    for epoch in range(start_epoch + 1, opt.epochs):
        start_ep_time = time.time()
        count = 0.0
        tot_loss = 0.0
        tot_fine_loss = 0.0
        tot_raw_loss = 0.0
        gl_encoder = gl_encoder.train()
        generator = generator.train()

        for i, data in enumerate(tr_loader, 0):
            it += 1
            points, _ = data
            B, N, dim = points.size()
            count += B

            partials = []
            fine_gts, raw_gts = [], []
            N_partial_points = N - (crop_point_num * num_holes)

            for m in range(B):
                # points[m]: complete shape of size (N, 3)
                # partial: partial point cloud to complete
                # fine_gt: missing part ground truth
                # raw_gt: missing part ground truth + frame points
                #         (frame points are points included in partial)
                partial, fine_gt, raw_gt = crop_shape(
                    points[m], centroids=centroids,
                    scales=[crop_point_num, (crop_point_num + context_point_num)],
                    n_c=num_holes)

                if partial.size(0) > N_partial_points:
                    assert num_holes > 1, "Should be no need to resample if not multiple holes case"
                    # sampling without replacement
                    choice = torch.randperm(partial.size(0))[:N_partial_points]
                    partial = partial[choice]

                partials.append(partial)
                fine_gts.append(fine_gt)
                raw_gts.append(raw_gt)

            if i == 1 and epoch % opt.it_test == 0:
                # make some visualizations
                vis_folder = os.path.join(opt.vis_dir, "epoch_{}".format(epoch))
                safe_make_dirs([vis_folder])
                print(f"ep {epoch} - Saving visualizations into: {vis_folder}")
                for j in range(len(partials)):
                    np.savetxt(X=partials[j], fname=os.path.join(vis_folder, '{}_cropped.txt'.format(j)),
                               fmt='%.5f', delimiter=';')
                    np.savetxt(X=fine_gts[j], fname=os.path.join(vis_folder, '{}_fine_gt.txt'.format(j)),
                               fmt='%.5f', delimiter=';')
                    np.savetxt(X=raw_gts[j], fname=os.path.join(vis_folder, '{}_raw_gt.txt'.format(j)),
                               fmt='%.5f', delimiter=';')

            partials = torch.stack(partials).to(device).permute(0, 2, 1)  # [B, 3, N-512]
            fine_gts = torch.stack(fine_gts).to(device)  # [B, 512, 3]
            raw_gts = torch.stack(raw_gts).to(device)  # [B, 512 + context, 3]

            if i == 1:
                # sanity check
                print("[dbg]: partials: ", partials.size(), ' ', partials.device)
                print("[dbg]: fine grained gts: ", fine_gts.size(), ' ', fine_gts.device)
                print("[dbg]: raw grained gts: ", raw_gts.size(), ' ', raw_gts.device)

            gl_encoder.zero_grad()
            generator.zero_grad()
            feat = gl_encoder(partials)
            # pred_fine (only missing part), pred_intermediate (missing + frame)
            fake_fine, fake_raw = generator(feat)

            # pytorch 1.2 compiled Chamfer (C2C) dist.
            assert fake_fine.size() == fine_gts.size(), "Wrong input shapes to Chamfer module"
            if i == 0:
                if fake_raw.size() != raw_gts.size():
                    warnings.warn("size mismatch for: raw_pred: {}, raw_gt: {}".format(
                        str(fake_raw.size()), str(raw_gts.size())))

            # fine grained prediction + gt
            fake_fine = fake_fine.contiguous()
            fine_gts = fine_gts.contiguous()
            # raw prediction + gt
            fake_raw = fake_raw.contiguous()
            raw_gts = raw_gts.contiguous()

            dist1, dist2, _, _ = NND.nnd(fake_fine, fine_gts)  # fine grained loss computation
            dist1_raw, dist2_raw, _, _ = NND.nnd(fake_raw, raw_gts)  # raw grained loss computation

            # standard C2C distance loss
            fine_loss = 100 * (0.5 * torch.mean(dist1) + 0.5 * torch.mean(dist2))
            # raw loss: missing part + frame
            raw_loss = 100 * (0.5 * torch.mean(dist1_raw) + 0.5 * torch.mean(dist2_raw))
            # missing part pred loss + alpha * raw reconstruction loss
            loss = fine_loss + opt.raw_weight * raw_loss
            loss.backward()
            opt_E.step()
            opt_G.step()

            tot_loss += loss.item() * B
            tot_fine_loss += fine_loss.item() * B
            tot_raw_loss += raw_loss.item() * B

            if it % 10 == 0:
                io.cprint('[%d/%d][%d/%d]: loss: %.4f, fine CD: %.4f, interm. CD: %.4f' %
                          (epoch, opt.epochs, i, len(tr_loader), loss.item(), fine_loss.item(),
                           raw_loss.item()))

            # make visualizations
            if i == 1 and epoch % opt.it_test == 0:
                assert (vis_folder is not None and os.path.exists(vis_folder))
                fake_fine = fake_fine.cpu().detach().data.numpy()
                fake_raw = fake_raw.cpu().detach().data.numpy()
                for j in range(len(fake_fine)):
                    np.savetxt(X=fake_fine[j], fname=os.path.join(vis_folder, '{}_pred_fine.txt'.format(j)),
                               fmt='%.5f', delimiter=';')
                    np.savetxt(X=fake_raw[j], fname=os.path.join(vis_folder, '{}_pred_raw.txt'.format(j)),
                               fmt='%.5f', delimiter=';')

        sched_E.step()
        sched_G.step()

        io.cprint('[%d/%d] Ep Train - loss: %.5f, fine cd: %.5f, interm. cd: %.5f' %
                  (epoch, opt.epochs, tot_loss * 1.0 / count, tot_fine_loss * 1.0 / count,
                   tot_raw_loss * 1.0 / count))
        tb.add_scalar('Train/tot_loss', tot_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_fine', tot_fine_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_interm', tot_raw_loss * 1.0 / count, epoch)

        if epoch % opt.it_test == 0:
            torch.save(
                {
                    'type_exp': 'dgccn at local encoder',
                    'epoch': epoch + 1,
                    'epoch_train_loss': tot_loss * 1.0 / count,
                    'epoch_train_loss_raw': tot_raw_loss * 1.0 / count,
                    'epoch_train_loss_fine': tot_fine_loss * 1.0 / count,
                    'gl_encoder_state_dict': gl_encoder.module.state_dict()
                    if isinstance(gl_encoder, nn.DataParallel) else gl_encoder.state_dict(),
                    'generator_state_dict': generator.module.state_dict()
                    if isinstance(generator, nn.DataParallel) else generator.state_dict(),
                    'optimizerE_state_dict': opt_E.state_dict(),
                    'optimizerG_state_dict': opt_G.state_dict(),
                    'schedulerE_state_dict': sched_E.state_dict(),
                    'schedulerG_state_dict': sched_G.state_dict(),
                },
                os.path.join(opt.models_dir, 'checkpoint_' + str(epoch) + '.pth'))

        if epoch % opt.it_test == 0:
            test_cd, count = 0.0, 0.0
            for i, data in enumerate(te_loader, 0):
                points, _ = data
                B, N, dim = points.size()
                count += B

                partials = []
                fine_gts = []
                N_partial_points = N - (crop_point_num * num_holes)

                for m in range(B):
                    partial, fine_gt, _ = crop_shape(
                        points[m], centroids=centroids,
                        scales=[crop_point_num, (crop_point_num + context_point_num)],
                        n_c=num_holes)

                    if partial.size(0) > N_partial_points:
                        assert num_holes > 1
                        # sampling without replacement
                        choice = torch.randperm(partial.size(0))[:N_partial_points]
                        partial = partial[choice]

                    partials.append(partial)
                    fine_gts.append(fine_gt)

                partials = torch.stack(partials).to(device).permute(0, 2, 1)  # [B, 3, N-512]
                fine_gts = torch.stack(fine_gts).to(device).contiguous()  # [B, 512, 3]

                # TEST FORWARD
                # considering only missing part prediction at test time
                gl_encoder.eval()
                generator.eval()
                with torch.no_grad():
                    feat = gl_encoder(partials)
                    fake_fine, _ = generator(feat)

                fake_fine = fake_fine.contiguous()
                assert fake_fine.size() == fine_gts.size()
                dist1, dist2, _, _ = NND.nnd(fake_fine, fine_gts)
                cd_loss = 100 * (0.5 * torch.mean(dist1) + 0.5 * torch.mean(dist2))
                test_cd += cd_loss.item() * B

            test_cd = test_cd * 1.0 / count
            io.cprint('Ep Test [%d/%d] - cd loss: %.5f ' % (epoch, opt.epochs, test_cd), color="b")
            tb.add_scalar('Test/cd_loss', test_cd, epoch)

            is_best = test_cd < best_test
            best_test = min(best_test, test_cd)
            if is_best:
                # best model case
                best_ep = epoch
                io.cprint("New best test %.5f at epoch %d" % (best_test, best_ep))
                shutil.copyfile(src=os.path.join(opt.models_dir, 'checkpoint_' + str(epoch) + '.pth'),
                                dst=os.path.join(opt.models_dir, 'best_model.pth'))

        io.cprint('[%d/%d] Epoch time: %s' %
                  (epoch, opt.epochs,
                   time.strftime("%M:%S", time.gmtime(time.time() - start_ep_time))))

    # Script ends
    hours, rem = divmod(time.time() - start_time, 3600)
    minutes, seconds = divmod(rem, 60)
    io.cprint("### Training ended in {:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))
    io.cprint("### Best val %.6f at epoch %d" % (best_test, best_ep))
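# NOTE (editor's sketch): `remove_prefix_dict` is used above to strip the 'local_encoder.'
# prefix from the keys of the pretext-trained state dict before loading it into the local
# encoder. It is not defined in this file; a minimal version consistent with that usage
# (name and signature taken from the call site, the real helper may differ) could be:
def remove_prefix_dict(state_dict, to_remove_str):
    """Return a copy of `state_dict` with `to_remove_str` stripped from the start of each key."""
    return {
        (k[len(to_remove_str):] if k.startswith(to_remove_str) else k): v
        for k, v in state_dict.items()
    }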
if classname.find("Conv2d") != -1: torch.nn.init.normal_(m.weight.data, 0.0, 0.02) elif classname.find("Conv1d") != -1: torch.nn.init.normal_(m.weight.data, 0.0, 0.02) elif classname.find("BatchNorm2d") != -1: torch.nn.init.normal_(m.weight.data, 1.0, 0.02) torch.nn.init.constant_(m.bias.data, 0.0) elif classname.find("BatchNorm1d") != -1: torch.nn.init.normal_(m.weight.data, 1.0, 0.02) torch.nn.init.constant_(m.bias.data, 0.0) args = parse_args() exp_dir = os.path.join(args.checkpoints_dir, args.exp_name + '_' + str(int(time.time()))) tb_dir, models_dir = osp.join(exp_dir, "tb_logs"), osp.join(exp_dir, "models") safe_make_dirs([tb_dir, models_dir]) io = IOStream(osp.join(exp_dir, "log.txt")) io.cprint(f"Arguments: {str(args)} \n") tb_writer = SummaryWriter(logdir=tb_dir) centroids = np.asarray([[1, 0, 0], [0, 0, 1], [1, 0, 1], [-1, 0, 0], [-1, 1, 0]]) # same as PFNet if args.num_positive_samples > 2: criterion = SupConLoss(temperature=args.temp, base_temperature=1, contrast_mode='all') else: criterion = SimCLRLoss(temperature=args.temp) io.cprint("Contrastive learning params: ") io.cprint(f"criterion: {str(criterion)}") io.cprint(f"num positive samples: {args.num_positive_samples}") io.cprint(f"centroids cropping: {str(centroids)}")
def main(opt):
    exp_dir = osp.join(opt.checkpoints_dir, opt.exp_name)
    tb_dir, models_dir = osp.join(exp_dir, "tb_logs"), osp.join(exp_dir, "models")
    safe_make_dirs([tb_dir, models_dir])
    io = IOStream(osp.join(exp_dir, "log.txt"))
    tb_logger = SummaryWriter(logdir=tb_dir)

    assert os.path.exists(opt.config), "wrong config path"
    with open(opt.config) as cf:
        config = json.load(cf)

    io.cprint(f"Arguments: {str(opt)}")
    io.cprint(f"Config: {str(config)} \n")

    if len(opt.class_choice) > 0:
        class_choice = ''.join(opt.class_choice.split()).split(",")  # sanitize + split(",")
        io.cprint("Class choice: {}".format(str(class_choice)))
    else:
        class_choice = None

    train_dataset = PretextDataset(root=opt.data_root, task='denoise', class_choice=class_choice,
                                   npoints=config["num_points"], split='train', normalize=True,
                                   noise_mean=config["noise_mean"], noise_std=config["noise_std"])
    test_dataset = PretextDataset(root=opt.data_root, task='denoise', class_choice=class_choice,
                                  npoints=config["num_points"], split='test', normalize=True,
                                  noise_mean=config["noise_mean"], noise_std=config["noise_std"])
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True,
                              drop_last=True, num_workers=opt.workers)
    test_loader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=False,
                             drop_last=False, num_workers=opt.workers)

    criterion = nn.MSELoss()  # loss function for denoising
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # MODEL
    model = GPDLocalFE(config)
    if opt.parallel:
        io.cprint(f"DataParallel training with {torch.cuda.device_count()} GPUs")
        model = nn.DataParallel(model)
    model = model.to(device)
    io.cprint(f'model: {str(model)}')

    # OPTIMIZER + SCHEDULER
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    train_start = time.time()
    for epoch in range(opt.epochs):
        # TRAIN
        # We compute both the MSE and the Chamfer Distance between the cleaned point cloud and
        # the clean GT, where cleaned = model(noised). MSE is used as the loss function; the
        # Chamfer Distance is just an additional metric.
        ep_start = time.time()
        train_mse, train_cd = train_one_epoch(train_loader, model, optimizer, criterion, device)
        train_time = time.strftime("%M:%S", time.gmtime(time.time() - ep_start))
        io.cprint("Train %d, time: %s, MSE (loss): %.6f, CD (dist): %.6f" %
                  (epoch, train_time, train_mse, train_cd))
        tb_logger.add_scalar("Train/MSE_loss", train_mse, epoch)
        tb_logger.add_scalar("Train/CD_dist", train_cd, epoch)

        # TEST
        mse_test, cd_test = test(test_loader, model, criterion, device)
        io.cprint("Test %d, MSE (loss): %.6f, CD (dist): %.6f" % (epoch, mse_test, cd_test))
        tb_logger.add_scalar("Test/MSE", mse_test, epoch)
        tb_logger.add_scalar("Test/CD", cd_test, epoch)

        # LR SCHEDULING
        scheduler.step()

        if epoch % 10 == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict() if not opt.parallel else model.module.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                },
                osp.join(models_dir, "local_denoise_{}.pth".format(epoch)))

    hours, rem = divmod(time.time() - train_start, 3600)
    minutes, seconds = divmod(rem, 60)
    io.cprint("Training ended in {:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))
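# NOTE (editor's sketch): `train_one_epoch` and `test` are not defined in this file. The sketch
# below shows one plausible implementation of the training epoch described in the comment above
# (MSE as the driving loss, Chamfer Distance only as a monitoring metric). The assumption that
# each batch yields a (noised, clean) pair, and the brute-force `chamfer_distance` helper, are
# illustrative and may not match the project's actual dataset or Chamfer implementation.
import torch

def chamfer_distance(a, b):
    """Symmetric Chamfer distance between two [B, N, 3] point clouds (brute force, for monitoring)."""
    d = torch.cdist(a, b)  # [B, N, M] pairwise distances
    return d.min(dim=2)[0].mean() + d.min(dim=1)[0].mean()

def train_one_epoch(loader, model, optimizer, criterion, device):
    model.train()
    tot_mse, tot_cd, n_batches = 0.0, 0.0, 0
    for noised, clean in loader:
        noised, clean = noised.to(device), clean.to(device)
        optimizer.zero_grad()
        cleaned = model(noised)
        loss = criterion(cleaned, clean)  # MSE loss drives the optimization
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            tot_cd += chamfer_distance(cleaned, clean).item()  # CD tracked as an extra metric
        tot_mse += loss.item()
        n_batches += 1
    return tot_mse / n_batches, tot_cd / n_batches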
def main_worker():
    opt, io, tb = get_args()
    start_epoch = -1
    start_time = time.time()

    ckt = None
    if len(opt.restart_from) > 0:
        ckt = torch.load(opt.restart_from)
        start_epoch = ckt['epoch'] - 1

    # load configuration from file
    try:
        with open(opt.config) as cf:
            config = json.load(cf)
    except IOError as error:
        print(error)

    # backup relevant files
    shutil.copy(src=os.path.abspath(__file__), dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'models', 'model_deco.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=os.path.join(BASE_DIR, 'shape_utils.py'),
                dst=os.path.join(opt.save_dir, 'backup_code'))
    shutil.copy(src=opt.config, dst=os.path.join(opt.save_dir, 'backup_code', 'config.json.backup'))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed_all(opt.manualSeed)

    io.cprint(f"Arguments: {str(opt)}")
    io.cprint(f"Configuration: {str(config)}")

    pnum = config['completion_trainer']['num_points']  # number of points of the complete pointcloud
    class_choice = opt.class_choice  # config['completion_trainer']['class_choice']

    # datasets + loaders
    if len(class_choice) > 0:
        class_choice = ''.join(opt.class_choice.split()).split(",")  # sanitize + split(",")
        io.cprint("Class choice list: {}".format(str(class_choice)))
    else:
        class_choice = None  # training on all snpart classes

    tr_dataset = shapenet_part_loader.PartDataset(root=opt.data_root, classification=True,
                                                  class_choice=class_choice, npoints=pnum,
                                                  split='train')
    te_dataset = shapenet_part_loader.PartDataset(root=opt.data_root, classification=True,
                                                  class_choice=class_choice, npoints=pnum,
                                                  split='test')
    tr_loader = torch.utils.data.DataLoader(tr_dataset, batch_size=opt.batch_size, shuffle=True,
                                            num_workers=opt.workers, drop_last=True)
    te_loader = torch.utils.data.DataLoader(te_dataset, batch_size=64, shuffle=True,
                                            num_workers=opt.workers)

    num_holes = int(opt.num_holes)
    crop_point_num = int(opt.crop_point_num)
    context_point_num = int(opt.context_point_num)
    # io.cprint(f"Completion Setting:\n num classes {len(tr_dataset.cat.keys())}, num holes: {num_holes}, "
    #           f"crop point num: {crop_point_num}, frame/context point num: {context_point_num},\n"
    #           f"num points at pool1: {opt.pool1_points}, num points at pool2: {opt.pool2_points} ")

    # Models
    gl_encoder = Encoder(conf=config)
    generator = Generator(conf=config, pool1_points=int(opt.pool1_points),
                          pool2_points=int(opt.pool2_points))
    gl_encoder.apply(weights_init_normal)  # affects only the non-pretrained layers
    generator.apply(weights_init_normal)
    print("Encoder: ", gl_encoder)
    print("Generator: ", generator)

    if ckt is not None:
        # resuming training from an intermediate checkpoint:
        # restore both encoder and generator state
        io.cprint(f"Restart Training from epoch {start_epoch}.")
        gl_encoder.load_state_dict(ckt['gl_encoder_state_dict'])
        generator.load_state_dict(ckt['generator_state_dict'])
        io.cprint("Whole model loaded from {}\n".format(opt.restart_from))
    else:
        # training the completion model:
        # load local and global encoder pretrained (ssl pretexts) weights
        io.cprint("Training Completion Task...")
        local_fe_fn = config['completion_trainer']['checkpoint_local_enco']
        global_fe_fn = config['completion_trainer']['checkpoint_global_enco']

        if len(local_fe_fn) > 0:
            local_enco_dict = torch.load(local_fe_fn)['model_state_dict']
            loc_load_result = gl_encoder.local_encoder.load_state_dict(local_enco_dict,
                                                                       strict=False)
            io.cprint(f"Local FE pretrained weights - loading res: {str(loc_load_result)}")
        else:
            # Ablation experiments only
            io.cprint("Local FE pretrained weights - NOT loaded", color='r')

        if len(global_fe_fn) > 0:
            global_enco_dict = torch.load(global_fe_fn)['global_encoder']
            glob_load_result = gl_encoder.global_encoder.load_state_dict(global_enco_dict,
                                                                         strict=True)
            io.cprint(f"Global FE pretrained weights - loading res: {str(glob_load_result)}",
                      color='b')
        else:
            # Ablation experiments only
            io.cprint("Global FE pretrained weights - NOT loaded", color='r')

    io.cprint("Num GPUs: " + str(torch.cuda.device_count()) +
              ", Parallelism: {}".format(opt.parallel))
    if opt.parallel:
        # TODO: implement DistributedDataParallel training
        assert torch.cuda.device_count() > 1
        gl_encoder = torch.nn.DataParallel(gl_encoder)
        generator = torch.nn.DataParallel(generator)
    gl_encoder.to(device)
    generator.to(device)

    # Optimizers + schedulers
    opt_E = torch.optim.Adam(gl_encoder.parameters(),
                             lr=config['completion_trainer']['enco_lr'],  # default is 10e-4
                             betas=(0.9, 0.999), eps=1e-05, weight_decay=0.001)
    sched_E = torch.optim.lr_scheduler.StepLR(opt_E,
                                              step_size=config['completion_trainer']['enco_step'],  # default is 25
                                              gamma=0.5)
    opt_G = torch.optim.Adam(generator.parameters(),
                             lr=config['completion_trainer']['gen_lr'],  # default is 10e-4
                             betas=(0.9, 0.999), eps=1e-05, weight_decay=0.001)
    sched_G = torch.optim.lr_scheduler.StepLR(opt_G,
                                              step_size=config['completion_trainer']['gen_step'],  # default is 40
                                              gamma=0.5)

    if ckt is not None:
        # resuming training from an intermediate checkpoint: restore optimizers state
        opt_E.load_state_dict(ckt['optimizerE_state_dict'])
        opt_G.load_state_dict(ckt['optimizerG_state_dict'])
        sched_E.load_state_dict(ckt['schedulerE_state_dict'])
        sched_G.load_state_dict(ckt['schedulerG_state_dict'])

    # crop centroids
    if not opt.fps_centroids:
        # 5 viewpoints to crop around - same crop procedure as PF-Net (main paper)
        centroids = np.asarray([[1, 0, 0], [0, 0, 1], [1, 0, 1], [-1, 0, 0], [-1, 1, 0]])
    else:
        raise NotImplementedError('experimental')
        centroids = None

    io.cprint('Centroids: ' + str(centroids))

    # training loop
    io.cprint("Training..\n")
    best_test = sys.float_info.max
    best_ep, glob_it = -1, 0
    vis_folder = None

    for epoch in range(start_epoch + 1, opt.epochs):
        start_ep_time = time.time()
        count = 0.0
        tot_loss = 0.0
        tot_fine_loss = 0.0
        tot_interm_loss = 0.0
        gl_encoder = gl_encoder.train()
        generator = generator.train()

        for i, data in enumerate(tr_loader, 0):
            glob_it += 1
            points, _ = data
            B, N, dim = points.size()
            count += B

            partials = []
            fine_gts, interm_gts = [], []
            N_partial_points = N - (crop_point_num * num_holes)

            for m in range(B):
                partial, fine_gt, interm_gt = crop_shape(
                    points[m], centroids=centroids,
                    scales=[crop_point_num, (crop_point_num + context_point_num)],
                    n_c=num_holes)

                if partial.size(0) > N_partial_points:
                    assert num_holes > 1
                    # sampling without replacement
                    choice = torch.randperm(partial.size(0))[:N_partial_points]
                    partial = partial[choice]

                partials.append(partial)
                fine_gts.append(fine_gt)
                interm_gts.append(interm_gt)

            if i == 1 and epoch % opt.it_test == 0:
                # make some visualizations
                vis_folder = os.path.join(opt.vis_dir, "epoch_{}".format(epoch))
                safe_make_dirs([vis_folder])
                print(f"ep {epoch} - Saving visualizations into: {vis_folder}")
                for j in range(len(partials)):
                    np.savetxt(X=partials[j], fname=os.path.join(vis_folder, '{}_partial.txt'.format(j)),
                               fmt='%.5f', delimiter=';')
                    np.savetxt(X=fine_gts[j], fname=os.path.join(vis_folder, '{}_fine_gt.txt'.format(j)),
                               fmt='%.5f', delimiter=';')
                    np.savetxt(X=interm_gts[j], fname=os.path.join(vis_folder, '{}_interm_gt.txt'.format(j)),
                               fmt='%.5f', delimiter=';')

            partials = torch.stack(partials).to(device).permute(0, 2, 1)  # [B, 3, N-512]
            fine_gts = torch.stack(fine_gts).to(device)  # [B, 512, 3]
            interm_gts = torch.stack(interm_gts).to(device)  # [B, 1024, 3]

            gl_encoder.zero_grad()
            generator.zero_grad()
            feat = gl_encoder(partials)
            pred_fine, pred_raw = generator(feat)

            # pytorch 1.2 compiled Chamfer (C2C) dist.
            assert pred_fine.size() == fine_gts.size()
            pred_fine, pred_raw = pred_fine.contiguous(), pred_raw.contiguous()
            fine_gts, interm_gts = fine_gts.contiguous(), interm_gts.contiguous()

            dist1, dist2, _, _ = NND.nnd(pred_fine, fine_gts)  # missing part pred loss
            dist1_raw, dist2_raw, _, _ = NND.nnd(pred_raw, interm_gts)  # intermediate pred loss

            fine_loss = 50 * (torch.mean(dist1) + torch.mean(dist2))  # chamfer is weighted by 100
            interm_loss = 50 * (torch.mean(dist1_raw) + torch.mean(dist2_raw))
            loss = fine_loss + opt.raw_weight * interm_loss
            loss.backward()
            opt_E.step()
            opt_G.step()

            tot_loss += loss.item() * B
            tot_fine_loss += fine_loss.item() * B
            tot_interm_loss += interm_loss.item() * B

            if glob_it % 10 == 0:
                header = "[%d/%d][%d/%d]" % (epoch, opt.epochs, i, len(tr_loader))
                io.cprint('%s: loss: %.4f, fine CD: %.4f, interm. CD: %.4f' %
                          (header, loss.item(), fine_loss.item(), interm_loss.item()))

            # make visualizations
            if i == 1 and epoch % opt.it_test == 0:
                assert (vis_folder is not None and os.path.exists(vis_folder))
                pred_fine = pred_fine.cpu().detach().data.numpy()
                pred_raw = pred_raw.cpu().detach().data.numpy()
                for j in range(len(pred_fine)):
                    np.savetxt(X=pred_fine[j], fname=os.path.join(vis_folder, '{}_pred_fine.txt'.format(j)),
                               fmt='%.5f', delimiter=';')
                    np.savetxt(X=pred_raw[j], fname=os.path.join(vis_folder, '{}_pred_raw.txt'.format(j)),
                               fmt='%.5f', delimiter=';')

        sched_E.step()
        sched_G.step()

        io.cprint('[%d/%d] Ep Train - loss: %.5f, fine cd: %.5f, interm. cd: %.5f' %
                  (epoch, opt.epochs, tot_loss * 1.0 / count, tot_fine_loss * 1.0 / count,
                   tot_interm_loss * 1.0 / count))
        tb.add_scalar('Train/tot_loss', tot_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_fine', tot_fine_loss * 1.0 / count, epoch)
        tb.add_scalar('Train/cd_interm', tot_interm_loss * 1.0 / count, epoch)

        if epoch % opt.it_test == 0:
            torch.save(
                {
                    'epoch': epoch + 1,
                    'epoch_train_loss': tot_loss * 1.0 / count,
                    'epoch_train_loss_raw': tot_interm_loss * 1.0 / count,
                    'epoch_train_loss_fine': tot_fine_loss * 1.0 / count,
                    'gl_encoder_state_dict': gl_encoder.module.state_dict()
                    if isinstance(gl_encoder, nn.DataParallel) else gl_encoder.state_dict(),
                    'generator_state_dict': generator.module.state_dict()
                    if isinstance(generator, nn.DataParallel) else generator.state_dict(),
                    'optimizerE_state_dict': opt_E.state_dict(),
                    'optimizerG_state_dict': opt_G.state_dict(),
                    'schedulerE_state_dict': sched_E.state_dict(),
                    'schedulerG_state_dict': sched_G.state_dict(),
                },
                os.path.join(opt.models_dir, 'checkpoint_' + str(epoch) + '.pth'))

        if epoch % opt.it_test == 0:
            test_cd, count = 0.0, 0.0
            for i, data in enumerate(te_loader, 0):
                points, _ = data
                B, N, dim = points.size()
                count += B

                partials = []
                fine_gts = []
                N_partial_points = N - (crop_point_num * num_holes)

                for m in range(B):
                    partial, fine_gt, _ = crop_shape(
                        points[m], centroids=centroids,
                        scales=[crop_point_num, (crop_point_num + context_point_num)],
                        n_c=num_holes)

                    if partial.size(0) > N_partial_points:
                        assert num_holes > 1
                        # sampling without replacement
                        choice = torch.randperm(partial.size(0))[:N_partial_points]
                        partial = partial[choice]

                    partials.append(partial)
                    fine_gts.append(fine_gt)

                partials = torch.stack(partials).to(device).permute(0, 2, 1)  # [B, 3, N-512]
                fine_gts = torch.stack(fine_gts).to(device).contiguous()  # [B, 512, 3]

                # TEST FORWARD
                # considering only missing part prediction at test time
                gl_encoder.eval()
                generator.eval()
                with torch.no_grad():
                    feat = gl_encoder(partials)
                    pred_fine, _ = generator(feat)

                pred_fine = pred_fine.contiguous()
                assert pred_fine.size() == fine_gts.size()
                dist1, dist2, _, _ = NND.nnd(pred_fine, fine_gts)
                cd_loss = 50 * (torch.mean(dist1) + torch.mean(dist2))
                test_cd += cd_loss.item() * B

            test_cd = test_cd * 1.0 / count
            io.cprint('Ep Test [%d/%d] - cd loss: %.5f ' % (epoch, opt.epochs, test_cd), color="b")
            tb.add_scalar('Test/cd_loss', test_cd, epoch)

            is_best = test_cd < best_test
            best_test = min(best_test, test_cd)
            if is_best:
                # best model case
                best_ep = epoch
                io.cprint("New best test %.5f at epoch %d" % (best_test, best_ep))
                shutil.copyfile(src=os.path.join(opt.models_dir, 'checkpoint_' + str(epoch) + '.pth'),
                                dst=os.path.join(opt.models_dir, 'best_model.pth'))

        io.cprint('[%d/%d] Epoch time: %s' %
                  (epoch, opt.epochs,
                   time.strftime("%M:%S", time.gmtime(time.time() - start_ep_time))))

    # Script ends
    hours, rem = divmod(time.time() - start_time, 3600)
    minutes, seconds = divmod(rem, 60)
    io.cprint("### Training ended in {:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))
    io.cprint("### Best val %.6f at epoch %d" % (best_test, best_ep))
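# NOTE (editor's sketch): `crop_shape` comes from shape_utils.py and is not shown in this file.
# From its usage above it takes a complete shape [N, 3], a set of crop centroids and two scales
# [crop_point_num, crop_point_num + context_point_num], and returns (partial, fine_gt, raw_gt),
# where fine_gt is the removed region, raw_gt additionally includes the surrounding "frame"
# points, and partial is everything except fine_gt. A simplified single-hole sketch of that
# behaviour follows; the project's real implementation may differ, e.g. for n_c > 1.
def crop_shape_sketch(points, centroids, scales, n_c=1):
    """points: [N, 3] tensor; returns (partial, fine_gt, raw_gt) for a single random viewpoint."""
    assert n_c == 1, "sketch covers the single-hole case only"
    n_fine, n_raw = scales  # e.g. 512 and 512 + 512
    center = torch.as_tensor(centroids[torch.randint(len(centroids), (1,)).item()],
                             dtype=points.dtype)
    # order points by distance from the chosen crop viewpoint
    order = torch.argsort(((points - center) ** 2).sum(dim=1))
    fine_idx = order[:n_fine]      # removed (missing) region
    raw_idx = order[:n_raw]        # missing region + frame points
    partial_idx = order[n_fine:]   # everything that is kept in the partial cloud
    return points[partial_idx], points[fine_idx], points[raw_idx]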
def setup(rl_setting, device, _run, _log, log, seed, cuda):
    """
    Do everything required to set up the experiment:
    - Create working dir
    - Set cuda seed (numpy is set by sacred)
    - Set and configure logger
    - Create n_e environments
    - Create model
    - Create 'RolloutStorage': a helper class to save rewards and compute the advantage loss
    - Create and initialise current_memory, a dictionary holding, for each of the n_e environments:
        - past observation
        - past latent state
        - past action
        - past reward
      This is used as input to the model to compute the next action.

    Warning: It is assumed that visual environments have pixel values [0, 255].

    Args:
        All args are automatically provided by sacred by passing the equally named configuration
        variables that are either defined in the yaml files or the command line.

    Returns:
        id_temp_dir (str): The newly created working directory
        envs: Vector of environments
        actor_critic: The model
        rollouts: A helper class (RolloutStorage) to store rewards and compute TD errors
        current_memory: Dictionary to keep track of current obs, actions, latent states and rewards
    """
    # Create working dir
    id_tmp_dir = "{}/{}/".format(log['tmp_dir'], _run._id)
    utils.safe_make_dirs(id_tmp_dir)

    np.set_printoptions(precision=2)
    torch.manual_seed(seed)
    # if cuda:
    #     torch.cuda.manual_seed(seed)

    # Forgot why I need this?
    # os.environ['OMP_NUM_THREADS'] = '1'

    logger = logging.getLogger()
    if _run.debug or _run.pdb:
        logger.setLevel(logging.DEBUG)

    envs = register_and_create_Envs(id_tmp_dir)
    actor_critic = create_model(envs)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    rollouts = RolloutStorage(rl_setting['num_steps'], rl_setting['num_processes'], obs_shape,
                              envs.action_space)
    current_obs = torch.zeros(rl_setting['num_processes'], *obs_shape)

    obs = envs.reset()
    if not actor_critic.observation_type == 'fc':
        obs = obs / 255.
    current_obs = torch.from_numpy(obs).float()

    # init_states = Variable(torch.zeros(rl_setting['num_processes'], actor_critic.state_size))
    init_states = actor_critic.new_latent_state()
    init_rewards = torch.zeros([rl_setting['num_processes'], 1])

    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]
    init_actions = torch.zeros(rl_setting['num_processes'], action_shape)

    init_states = init_states.to(device)
    init_actions = init_actions.to(device)
    current_obs = current_obs.to(device)
    init_rewards = init_rewards.to(device)
    actor_critic.to(device)
    rollouts.to(device)

    current_memory = {
        'current_obs': current_obs,
        'states': init_states,
        'oneHotActions': utils.toOneHot(envs.action_space, init_actions),
        'rewards': init_rewards
    }

    return id_tmp_dir, envs, actor_critic, rollouts, current_memory
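# NOTE (editor's sketch): `utils.toOneHot` is used above to turn the initial (zero) actions into
# the one-hot encoding the model expects. A minimal version consistent with that call, assuming
# discrete action spaces are one-hot encoded and continuous actions are passed through unchanged,
# could look like this; the project's actual helper may differ.
import torch

def to_one_hot(action_space, actions):
    """actions: [num_processes, 1] (Discrete) or [num_processes, action_dim] (continuous)."""
    if action_space.__class__.__name__ == "Discrete":
        n = action_space.n
        one_hot = torch.zeros(actions.size(0), n, device=actions.device)
        one_hot.scatter_(1, actions.long().view(-1, 1), 1.0)
        return one_hot
    # For continuous action spaces the actions are already usable as-is.
    return actions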