def train(conf):
    # load network model
    models = utils.get_model_module(conf.model_version)

    # check if training run already exists. If so, abort.
    if os.path.exists(os.path.join(conf.log_path, conf.exp_name)) or \
       os.path.exists(os.path.join(conf.model_path, conf.exp_name)):
        print(f'A training run named {conf.exp_name} already exists!')
        sys.exit()
        # response = input('A training run named "%s" already exists, overwrite? (y/n) ' % (conf.exp_name))
        # if response != 'y':
        #     sys.exit()

    if os.path.exists(os.path.join(conf.log_path, conf.exp_name)):
        shutil.rmtree(os.path.join(conf.log_path, conf.exp_name))
    if os.path.exists(os.path.join(conf.model_path, conf.exp_name)):
        shutil.rmtree(os.path.join(conf.model_path, conf.exp_name))

    # create directories for this run
    os.makedirs(os.path.join(conf.model_path, conf.exp_name))
    os.makedirs(os.path.join(conf.log_path, conf.exp_name))

    # file log
    flog = open(os.path.join(conf.log_path, conf.exp_name, 'train.log'), 'w')

    # set training device
    device = torch.device(conf.device)
    print(f'Using device: {conf.device}')
    flog.write(f'Using device: {conf.device}\n')

    # log the object category information
    print(f'Object Category: {conf.category}')
    flog.write(f'Object Category: {conf.category}\n')

    # control randomness
    if conf.seed < 0:
        conf.seed = random.randint(1, 10000)
    print("Random Seed: %d" % (conf.seed))
    flog.write(f'Random Seed: {conf.seed}\n')
    random.seed(conf.seed)
    np.random.seed(conf.seed)
    torch.manual_seed(conf.seed)

    # save config
    torch.save(conf, os.path.join(conf.model_path, conf.exp_name, 'conf.pth'))
    print('\n\nConfiguration:\n', vars(conf), '\n\n')

    # create models
    encoder = models.RecursiveEncoder(conf, variational=not conf.non_variational,
                                      probabilistic=not conf.non_probabilistic)
    decoder = models.RecursiveDecoder(conf)
    models = [encoder, decoder]
    model_names = ['encoder', 'decoder']

    # create optimizers
    encoder_opt = torch.optim.Adam(encoder.parameters(), lr=conf.lr)
    decoder_opt = torch.optim.Adam(decoder.parameters(), lr=conf.lr)
    optimizers = [encoder_opt, decoder_opt]
    optimizer_names = ['encoder', 'decoder']

    # learning rate scheduler
    if conf.scheduler == 'StepLR':
        encoder_scheduler = torch.optim.lr_scheduler.StepLR(
            encoder_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)
        decoder_scheduler = torch.optim.lr_scheduler.StepLR(
            decoder_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)
    elif conf.scheduler == 'ReduceLROnPlateau':
        encoder_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            encoder_opt, mode='min', factor=conf.lr_decay_factor, patience=2)
        decoder_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            decoder_opt, mode='min', factor=conf.lr_decay_factor, patience=2)

    # create training and validation datasets and data loaders
    data_features = ['uxid', 'object']
    # train_dataset = PartNetDataset(conf.data_path, conf.train_dataset, data_features, \
    #     load_geo=conf.load_geo)
    # valdt_dataset = PartNetDataset(conf.data_path, conf.val_dataset, data_features, \
    #     load_geo=conf.load_geo)
    DatasetClass = globals()[config.DatasetClass]
    print('Using dataset:', DatasetClass)
    train_dataset = DatasetClass(conf.data_path, conf.train_dataset, ['uxid', 'object'],
                                 is_train=True, permute=(conf.permutations > 1),
                                 n_permutes=conf.permutations)
    valdt_dataset = DatasetClass(conf.data_path, config.val_dataset, ['uxid', 'object'],
                                 is_train=False, permute=False, n_permutes=1)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=conf.batch_size,
                                                   shuffle=True, collate_fn=lambda x: list(zip(*x)))
    valdt_dataloader = torch.utils.data.DataLoader(valdt_dataset, batch_size=conf.batch_size,
                                                   shuffle=True, collate_fn=lambda x: list(zip(*x)))

    # create logs
    if not conf.no_console_log:
        # header = ' Time Epoch Dataset Iteration Progress(%) LR BoxLoss StructLoss EdgeExists KLDivLoss SymLoss AdjLoss AnchorLoss TotalLoss'
        header = ' Time Epoch Dataset Iteration Progress(%) LR BoxLoss LeafLoss NodeExists SemLoss ChildCountLoss KLDivLoss TotalLoss'
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        from tensorboardX import SummaryWriter
        train_writer = SummaryWriter(os.path.join(conf.log_path, conf.exp_name, 'train'))
        valdt_writer = SummaryWriter(os.path.join(conf.log_path, conf.exp_name, 'val'))

    # send parameters to device
    for m in models:
        m.to(device)
    for o in optimizers:
        utils.optimizer_to_device(o, device)

    # start training
    print("Starting training ......")
    flog.write('Starting training ......\n')

    start_time = time.time()

    last_checkpoint_step = None
    last_train_console_log_step, last_valdt_console_log_step = None, None
    train_num_batch, valdt_num_batch = len(train_dataloader), len(valdt_dataloader)

    # train for every epoch
    for epoch in range(conf.epochs):
        global train_stats
        global valdt_stats
        train_stats = Statistics()
        valdt_stats = Statistics()

        if not conf.no_console_log:
            print(f'training run {conf.exp_name}')
            flog.write(f'training run {conf.exp_name}\n')
            print(header)
            flog.write(header + '\n')

        train_batches = enumerate(train_dataloader, 0)
        valdt_batches = enumerate(valdt_dataloader, 0)

        train_fraction_done, valdt_fraction_done = 0.0, 0.0
        valdt_batch_ind = -1

        # train for every batch
        for train_batch_ind, batch in train_batches:
            # tic_compl_batch = time.time()
            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # set models to training mode
            for m in models:
                m.train()

            # forward pass (including logging)
            total_loss = forward(
                batch=batch, data_features=data_features, encoder=encoder, decoder=decoder,
                device=device, conf=conf, is_valdt=False, step=train_step, epoch=epoch,
                batch_ind=train_batch_ind, num_batch=train_num_batch, start_time=start_time,
                log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=train_writer,
                lr=encoder_opt.param_groups[0]['lr'], flog=flog)

            # optimize one step
            encoder_opt.zero_grad()
            decoder_opt.zero_grad()
            total_loss.backward()
            encoder_opt.step()
            decoder_opt.step()

            train_stats.add('total_loss', float(total_loss.item()), len(batch[0]))
            del total_loss

            # save checkpoint
            with torch.no_grad():
                if last_checkpoint_step is None or \
                        train_step - last_checkpoint_step >= conf.checkpoint_interval:
                    print("Saving checkpoint ...... ", end='', flush=True)
                    flog.write("Saving checkpoint ...... ")
                    utils.save_checkpoint(
                        models=models, model_names=model_names,
                        dirname=os.path.join(conf.model_path, conf.exp_name),
                        epoch=epoch, prepend_epoch=True,
                        optimizers=optimizers, optimizer_names=model_names)
                    print("DONE")
                    flog.write("DONE\n")
                    last_checkpoint_step = train_step

            # validate one batch
            # while valdt_fraction_done <= train_fraction_done and valdt_batch_ind+1 < valdt_num_batch:
            #     valdt_batch_ind, batch = next(valdt_batches)
            #     valdt_fraction_done = (valdt_batch_ind + 1) / valdt_num_batch
            #     valdt_step = (epoch + valdt_fraction_done) * train_num_batch - 1
            #     log_console = not conf.no_console_log and (last_valdt_console_log_step is None or \
            #         valdt_step - last_valdt_console_log_step >= conf.console_log_interval)
            #     if log_console:
            #         last_valdt_console_log_step = valdt_step
            #     # set models to evaluation mode
            #     for m in models:
            #         m.eval()
            #     with torch.no_grad():
            #         # forward pass (including logging)
            #         __ = forward(
            #             batch=batch, data_features=data_features, encoder=encoder, decoder=decoder, device=device, conf=conf,
            #             is_valdt=True, step=valdt_step, epoch=epoch, batch_ind=valdt_batch_ind, num_batch=valdt_num_batch, start_time=start_time,
            #             log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=valdt_writer,
            #             lr=encoder_opt.param_groups[0]['lr'], flog=flog)

        # validate over every batch of the validation set
        with torch.no_grad():
            for valdt_batch_ind, batch in valdt_batches:
                valdt_fraction_done = (valdt_batch_ind + 1) / valdt_num_batch
                valdt_step = epoch * valdt_num_batch + valdt_batch_ind

                log_console = not conf.no_console_log and (last_valdt_console_log_step is None or \
                    valdt_step - last_valdt_console_log_step >= conf.console_log_interval)
                if log_console:
                    last_valdt_console_log_step = valdt_step

                # set models to evaluation mode
                for m in models:
                    m.eval()

                # forward pass (including logging)
                total_loss = forward(
                    batch=batch, data_features=data_features, encoder=encoder, decoder=decoder,
                    device=device, conf=conf, is_valdt=True, step=valdt_step, epoch=epoch,
                    batch_ind=valdt_batch_ind, num_batch=valdt_num_batch, start_time=start_time,
                    log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=valdt_writer,
                    lr=encoder_opt.param_groups[0]['lr'], flog=flog)

                valdt_stats.add('total_loss', float(total_loss.item()), len(batch[0]))

        # step the learning rate schedulers
        valid_loss = valdt_stats.mean(conf.metric)
        if conf.scheduler == 'StepLR':
            encoder_scheduler.step()
            decoder_scheduler.step()
        elif conf.scheduler == 'ReduceLROnPlateau':
            encoder_scheduler.step(valid_loss)
            decoder_scheduler.step(valid_loss)

        # print(f'1 complete batch update, Elapsed time: {time.time()-tic_compl_batch:.2f}')

    # save the final models
    print("Saving final checkpoint ...... ", end='', flush=True)
    flog.write("Saving final checkpoint ...... ")
    utils.save_checkpoint(
        models=models, model_names=model_names,
        dirname=os.path.join(conf.model_path, conf.exp_name),
        epoch=epoch, prepend_epoch=False,
        optimizers=optimizers, optimizer_names=optimizer_names)
    print("DONE")
    flog.write("DONE\n")

    flog.close()
def train(conf):
    # load network model
    models = utils.get_model_module(conf.model_version)

    # check if training run already exists. If so, ask whether to overwrite it.
    if os.path.exists(os.path.join(conf.log_path, conf.exp_name)) or \
       os.path.exists(os.path.join(conf.model_path, conf.exp_name)):
        response = input('A training run named "%s" already exists, overwrite? (y/n) ' % (conf.exp_name))
        if response != 'y':
            sys.exit()
    if os.path.exists(os.path.join(conf.log_path, conf.exp_name)):
        shutil.rmtree(os.path.join(conf.log_path, conf.exp_name))
    if os.path.exists(os.path.join(conf.model_path, conf.exp_name)):
        shutil.rmtree(os.path.join(conf.model_path, conf.exp_name))

    # create directories for this run
    os.makedirs(os.path.join(conf.model_path, conf.exp_name))
    os.makedirs(os.path.join(conf.log_path, conf.exp_name))

    # file log
    flog = open(os.path.join(conf.log_path, conf.exp_name, 'train.log'), 'w')

    # set training device
    device = torch.device(conf.device)
    print(f'Using device: {conf.device}')
    flog.write(f'Using device: {conf.device}\n')

    # log the object category information
    print(f'Object Category: {conf.category}')
    flog.write(f'Object Category: {conf.category}\n')

    # control randomness
    if conf.seed < 0:
        conf.seed = random.randint(1, 10000)
    print("Random Seed: %d" % (conf.seed))
    flog.write(f'Random Seed: {conf.seed}\n')
    random.seed(conf.seed)
    np.random.seed(conf.seed)
    torch.manual_seed(conf.seed)

    # create models
    encoder = models.PartEncoder(feat_len=conf.geo_feat_size, probabilistic=not conf.non_variational)
    decoder = models.PartDecoder(feat_len=conf.geo_feat_size, num_point=conf.num_point)
    models = [encoder, decoder]
    model_names = ['part_pc_encoder', 'part_pc_decoder']

    # create optimizers
    encoder_opt = torch.optim.Adam(encoder.parameters(), lr=conf.lr, weight_decay=conf.weight_decay)
    decoder_opt = torch.optim.Adam(decoder.parameters(), lr=conf.lr, weight_decay=conf.weight_decay)
    optimizers = [encoder_opt, decoder_opt]
    optimizer_names = ['part_pc_encoder', 'part_pc_decoder']

    # learning rate scheduler
    encoder_scheduler = torch.optim.lr_scheduler.StepLR(
        encoder_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)
    decoder_scheduler = torch.optim.lr_scheduler.StepLR(
        decoder_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)

    # create training and validation datasets and data loaders
    train_dataset = PartNetGeoDataset(conf.data_path, conf.train_dataset, use_local_frame=conf.use_local_frame)
    valdt_dataset = PartNetGeoDataset(conf.data_path, conf.val_dataset, use_local_frame=conf.use_local_frame)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=conf.batch_size,
                                                   shuffle=True, collate_fn=utils.collate_feats)
    valdt_dataloader = torch.utils.data.DataLoader(valdt_dataset, batch_size=conf.batch_size,
                                                   shuffle=True, collate_fn=utils.collate_feats)

    # create logs
    if not conf.no_console_log:
        header = ' Time Epoch Dataset Iteration Progress(%) LR ReconLoss KLDivLoss TotalLoss'
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        from tensorboardX import SummaryWriter
        train_writer = SummaryWriter(os.path.join(conf.log_path, conf.exp_name, 'train'))
        valdt_writer = SummaryWriter(os.path.join(conf.log_path, conf.exp_name, 'val'))

    # save config
    torch.save(conf, os.path.join(conf.model_path, conf.exp_name, 'conf.pth'))

    # send parameters to device
    for m in models:
        m.to(device)
    for o in optimizers:
        utils.optimizer_to_device(o, device)

    # start training
    print("Starting training ......")
    flog.write('Starting training ......\n')

    start_time = time.time()

    last_checkpoint_step = None
    last_train_console_log_step, last_valdt_console_log_step = None, None
    train_num_batch, valdt_num_batch = len(train_dataloader), len(valdt_dataloader)

    # train for every epoch
    for epoch in range(conf.epochs):
        if not conf.no_console_log:
            print(f'training run {conf.exp_name}')
            flog.write(f'training run {conf.exp_name}\n')
            print(header)
            flog.write(header + '\n')

        train_batches = enumerate(train_dataloader, 0)
        valdt_batches = enumerate(valdt_dataloader, 0)

        train_fraction_done, valdt_fraction_done = 0.0, 0.0
        valdt_batch_ind = -1

        # train for every batch
        for train_batch_ind, batch in train_batches:
            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # set models to training mode
            for m in models:
                m.train()

            # forward pass (including logging)
            total_loss = forward(
                batch=batch, encoder=encoder, decoder=decoder, device=device, conf=conf,
                is_valdt=False, step=train_step, epoch=epoch, batch_ind=train_batch_ind,
                num_batch=train_num_batch, start_time=start_time,
                log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=train_writer,
                lr=encoder_opt.param_groups[0]['lr'], flog=flog)

            # optimize one step
            encoder_scheduler.step()
            decoder_scheduler.step()
            encoder_opt.zero_grad()
            decoder_opt.zero_grad()
            total_loss.backward()
            encoder_opt.step()
            decoder_opt.step()

            # save checkpoint
            with torch.no_grad():
                if last_checkpoint_step is None or \
                        train_step - last_checkpoint_step >= conf.checkpoint_interval:
                    print("Saving checkpoint ...... ", end='', flush=True)
                    flog.write("Saving checkpoint ...... ")
                    utils.save_checkpoint(
                        models=models, model_names=model_names,
                        dirname=os.path.join(conf.model_path, conf.exp_name),
                        epoch=epoch, prepend_epoch=True,
                        optimizers=optimizers, optimizer_names=model_names)
                    print("DONE")
                    flog.write("DONE\n")
                    last_checkpoint_step = train_step

            # validate one batch
            while valdt_fraction_done <= train_fraction_done and valdt_batch_ind + 1 < valdt_num_batch:
                valdt_batch_ind, batch = next(valdt_batches)
                valdt_fraction_done = (valdt_batch_ind + 1) / valdt_num_batch
                valdt_step = (epoch + valdt_fraction_done) * train_num_batch - 1

                log_console = not conf.no_console_log and (last_valdt_console_log_step is None or \
                    valdt_step - last_valdt_console_log_step >= conf.console_log_interval)
                if log_console:
                    last_valdt_console_log_step = valdt_step

                # set models to evaluation mode
                for m in models:
                    m.eval()

                with torch.no_grad():
                    # forward pass (including logging)
                    __ = forward(
                        batch=batch, encoder=encoder, decoder=decoder, device=device, conf=conf,
                        is_valdt=True, step=valdt_step, epoch=epoch, batch_ind=valdt_batch_ind,
                        num_batch=valdt_num_batch, start_time=start_time,
                        log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=valdt_writer,
                        lr=encoder_opt.param_groups[0]['lr'], flog=flog)

    # save the final models
    print("Saving final checkpoint ...... ", end='', flush=True)
    flog.write('Saving final checkpoint ...... ')
    utils.save_checkpoint(
        models=models, model_names=model_names,
        dirname=os.path.join(conf.model_path, conf.exp_name),
        epoch=epoch, prepend_epoch=False,
        optimizers=optimizers, optimizer_names=optimizer_names)
    print("DONE")
    flog.write("DONE\n")

    flog.close()
def train(conf):
    # create training and validation datasets and data loaders
    data_features = ['img', 'pts', 'ins_one_hot', 'box_size', 'total_parts_cnt',
                     'similar_parts_cnt', 'mask', 'shape_id', 'view_id']

    train_dataset = PartNetShapeDataset(conf.category, conf.data_dir, data_features, data_split="train",
                                        max_num_mask=conf.max_num_parts, max_num_similar_parts=conf.max_num_similar_parts,
                                        img_size=conf.img_size, on_kaichun_machine=conf.on_kaichun_machine)
    utils.printout(conf.flog, str(train_dataset))
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=True, pin_memory=True,
                                                   num_workers=conf.num_workers, drop_last=True,
                                                   collate_fn=utils.collate_feats_with_none, worker_init_fn=utils.worker_init_fn)

    val_dataset = PartNetShapeDataset(conf.category, conf.data_dir, data_features, data_split="val",
                                      max_num_mask=conf.max_num_parts, max_num_similar_parts=conf.max_num_similar_parts,
                                      img_size=conf.img_size, on_kaichun_machine=conf.on_kaichun_machine)
    utils.printout(conf.flog, str(val_dataset))
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=conf.batch_size, shuffle=False, pin_memory=True,
                                                 num_workers=0, drop_last=True,
                                                 collate_fn=utils.collate_feats_with_none, worker_init_fn=utils.worker_init_fn)

    # load network model
    model_def = utils.get_model_module(conf.model_version)

    # create models
    network = model_def.Network(conf, train_dataset.get_part_count())
    utils.printout(conf.flog, '\n' + str(network) + '\n')

    models = [network]
    model_names = ['network']

    # create optimizers
    network_opt = torch.optim.Adam(network.parameters(), lr=conf.lr, weight_decay=conf.weight_decay)
    optimizers = [network_opt]
    optimizer_names = ['network_opt']

    # learning rate scheduler
    network_lr_scheduler = torch.optim.lr_scheduler.StepLR(network_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)

    # create logs
    if not conf.no_console_log:
        header = ' Time Epoch Dataset Iteration Progress(%) LR CenterLoss QuatLoss TotalLoss'
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        from tensorboardX import SummaryWriter
        train_writer = SummaryWriter(os.path.join(conf.exp_dir, 'train'))
        val_writer = SummaryWriter(os.path.join(conf.exp_dir, 'val'))

    # send parameters to device
    for m in models:
        m.to(conf.device)
    for o in optimizers:
        utils.optimizer_to_device(o, conf.device)

    # start training
    start_time = time.time()

    last_checkpoint_step = None
    last_train_console_log_step, last_val_console_log_step = None, None
    train_num_batch = len(train_dataloader)
    val_num_batch = len(val_dataloader)

    # train for every epoch
    for epoch in range(conf.epochs):
        if not conf.no_console_log:
            utils.printout(conf.flog, f'training run {conf.exp_name}')
            utils.printout(conf.flog, header)

        train_batches = enumerate(train_dataloader, 0)
        val_batches = enumerate(val_dataloader, 0)

        train_fraction_done = 0.0
        val_fraction_done = 0.0
        val_batch_ind = -1

        # train for every batch
        for train_batch_ind, batch in train_batches:
            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # set models to training mode
            for m in models:
                m.train()

            # forward pass (including logging)
            total_loss = forward(batch=batch, data_features=data_features, network=network, conf=conf, is_val=False,
                                 step=train_step, epoch=epoch, batch_ind=train_batch_ind, num_batch=train_num_batch,
                                 start_time=start_time, log_console=log_console, log_tb=not conf.no_tb_log,
                                 tb_writer=train_writer, lr=network_opt.param_groups[0]['lr'])

            if total_loss is not None:
                # optimize one step
                network_lr_scheduler.step()
                network_opt.zero_grad()
                total_loss.backward()
                network_opt.step()

            # save checkpoint
            with torch.no_grad():
                if last_checkpoint_step is None or train_step - last_checkpoint_step >= conf.checkpoint_interval:
                    utils.printout(conf.flog, 'Saving checkpoint ...... ')
                    utils.save_checkpoint(models=models, model_names=model_names, dirname=os.path.join(conf.exp_dir, 'ckpts'),
                                          epoch=epoch, prepend_epoch=True, optimizers=optimizers, optimizer_names=model_names)
                    utils.printout(conf.flog, 'DONE')
                    last_checkpoint_step = train_step

            # validate one batch
            while val_fraction_done <= train_fraction_done and val_batch_ind + 1 < val_num_batch:
                val_batch_ind, val_batch = next(val_batches)

                val_fraction_done = (val_batch_ind + 1) / val_num_batch
                val_step = (epoch + val_fraction_done) * train_num_batch - 1

                log_console = not conf.no_console_log and (last_val_console_log_step is None or \
                    val_step - last_val_console_log_step >= conf.console_log_interval)
                if log_console:
                    last_val_console_log_step = val_step

                # set models to evaluation mode
                for m in models:
                    m.eval()

                with torch.no_grad():
                    # forward pass (including logging)
                    __ = forward(batch=val_batch, data_features=data_features, network=network, conf=conf, is_val=True,
                                 step=val_step, epoch=epoch, batch_ind=val_batch_ind, num_batch=val_num_batch,
                                 start_time=start_time, log_console=log_console, log_tb=not conf.no_tb_log,
                                 tb_writer=val_writer, lr=network_opt.param_groups[0]['lr'])

    # save the final models
    utils.printout(conf.flog, 'Saving final checkpoint ...... ')
    utils.save_checkpoint(models=models, model_names=model_names, dirname=os.path.join(conf.exp_dir, 'ckpts'),
                          epoch=epoch, prepend_epoch=False, optimizers=optimizers, optimizer_names=optimizer_names)
    utils.printout(conf.flog, 'DONE')
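
# NOTE: `utils.collate_feats`, `utils.collate_feats_with_none`, and `utils.worker_init_fn`
# are project helpers that the data loaders above depend on but that are not shown in
# this section. Minimal sketches, assuming a "batch" is the per-feature transpose of the
# sampled items (the same behaviour as the `lambda x: list(zip(*x))` collate used by the
# first trainer), with the `_with_none` variant dropping samples the dataset returned as
# None, and the worker init function reseeding numpy per dataloader worker.
import numpy as np
import torch


def collate_feats(samples):
    # [(f1, f2, ...), (f1, f2, ...)] -> [(f1, f1, ...), (f2, f2, ...), ...]
    return list(zip(*samples))


def collate_feats_with_none(samples):
    # drop failed/None samples before transposing into per-feature lists
    samples = [s for s in samples if s is not None]
    return list(zip(*samples))


def worker_init_fn(worker_id):
    # give each dataloader worker its own numpy seed so random augmentation differs
    np.random.seed((torch.initial_seed() + worker_id) % (2 ** 32))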
def train(conf):
    # load network model
    models = utils.get_model_module(conf.model_version)

    # check if training run already exists. If so, ask whether to overwrite it.
    if os.path.exists(os.path.join(conf.log_path, conf.exp_name)) or \
       os.path.exists(os.path.join(conf.ckpt_path, conf.exp_name)):
        response = input('A training run named "%s" already exists, overwrite? (y/n) ' % (conf.exp_name))
        if response != 'y':
            sys.exit()
    if os.path.exists(os.path.join(conf.log_path, conf.exp_name)):
        shutil.rmtree(os.path.join(conf.log_path, conf.exp_name))
    if os.path.exists(os.path.join(conf.ckpt_path, conf.exp_name)):
        shutil.rmtree(os.path.join(conf.ckpt_path, conf.exp_name))

    # create directories for this run
    os.makedirs(os.path.join(conf.ckpt_path, conf.exp_name))
    os.makedirs(os.path.join(conf.log_path, conf.exp_name))

    # file log
    flog = open(os.path.join(conf.log_path, conf.exp_name, 'train.log'), 'w')

    # backup python files used for this training
    os.system('cp config.py data.py %s.py %s %s' %
              (conf.model_version, __file__, os.path.join(conf.log_path, conf.exp_name)))

    # set training device
    device = torch.device(conf.device)
    print(f'Using device: {conf.device}')
    flog.write(f'Using device: {conf.device}\n')

    # log the object category information
    print(f'Object Category: {conf.category}')
    flog.write(f'Object Category: {conf.category}\n')

    # control randomness
    if conf.seed < 0:
        conf.seed = random.randint(1, 10000)
    print("Random Seed: %d" % (conf.seed))
    flog.write(f'Random Seed: {conf.seed}\n')
    random.seed(conf.seed)
    np.random.seed(conf.seed)
    torch.manual_seed(conf.seed)

    # save config
    torch.save(conf, os.path.join(conf.ckpt_path, conf.exp_name, 'conf.pth'))

    # create models
    encoder = models.RecursiveEncoder(conf, variational=True, probabilistic=not conf.non_variational)
    decoder = models.RecursiveDecoder(conf)
    models = [encoder, decoder]
    model_names = ['encoder', 'decoder']

    # create optimizers
    encoder_opt = torch.optim.Adam(encoder.parameters(), lr=conf.lr)
    decoder_opt = torch.optim.Adam(decoder.parameters(), lr=conf.lr)
    optimizers = [encoder_opt, decoder_opt]
    optimizer_names = ['encoder', 'decoder']

    # learning rate scheduler
    encoder_scheduler = torch.optim.lr_scheduler.StepLR(
        encoder_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)
    decoder_scheduler = torch.optim.lr_scheduler.StepLR(
        decoder_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)

    # create training and validation datasets and data loaders
    data_features = ['object', 'diff']
    train_dataset = SynShapesDiffDataset(conf.train_dataset, data_features, conf.shapediff_topk)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=conf.batch_size, shuffle=True, collate_fn=utils.collate_feats)

    # create logs
    if not conf.no_console_log:
        header = ' Time Epoch Dataset Iteration Progress(%) LR BoxLoss StructLoss DNTypeLoss DNBoxLoss KLDivLoss L1Loss TotalLoss'
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        # from tensorboardX import SummaryWriter
        from torch.utils.tensorboard import SummaryWriter
        tb_writer = SummaryWriter(log_dir=os.path.join(conf.log_path, conf.exp_name, 'train'))

    # send parameters to device
    for m in models:
        m.to(device)
    for o in optimizers:
        utils.optimizer_to_device(o, device)

    # start training
    print("Starting training ......")
    flog.write('Starting training ......\n')

    start_time = time.time()

    last_checkpoint_step = None
    last_train_console_log_step, last_valdt_console_log_step = None, None
    train_num_batch = len(train_dataloader)

    # train for every epoch
    for epoch in range(conf.epochs):
        if not conf.no_console_log:
            print(f'training run {conf.exp_name}')
            flog.write(f'training run {conf.exp_name}\n')
            print(header)
            flog.write(header + '\n')

        train_batches = enumerate(train_dataloader, 0)
        train_fraction_done = 0.0

        # train for every batch
        for train_batch_ind, batch in train_batches:
            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # set models to training mode
            for m in models:
                m.train()

            # forward pass (including logging)
            total_loss = forward(
                batch=batch, data_features=data_features, encoder=encoder, decoder=decoder,
                device=device, conf=conf, is_valdt=False, step=train_step, epoch=epoch,
                batch_ind=train_batch_ind, num_batch=train_num_batch, start_time=start_time,
                log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=tb_writer,
                lr=encoder_opt.param_groups[0]['lr'], flog=flog)

            # optimize one step
            encoder_opt.zero_grad()
            decoder_opt.zero_grad()
            total_loss.backward()

            # (debug) dump per-parameter gradient norms and abort; disabled by default
            if False:
                for name, param in encoder.named_parameters():
                    if param.requires_grad and param.grad is not None:
                        print(name, torch.norm(param.grad.data))
                for name, param in decoder.named_parameters():
                    if param.requires_grad and param.grad is not None:
                        print(name, torch.norm(param.grad.data))
                exit(1)

            encoder_opt.step()
            decoder_opt.step()
            encoder_scheduler.step()
            decoder_scheduler.step()

            # save checkpoint
            with torch.no_grad():
                if last_checkpoint_step is None or \
                        train_step - last_checkpoint_step >= conf.checkpoint_interval:
                    print("Saving checkpoint ...... ", end='', flush=True)
                    flog.write("Saving checkpoint ...... ")
                    utils.save_checkpoint(
                        models=models, model_names=model_names,
                        dirname=os.path.join(conf.ckpt_path, conf.exp_name),
                        epoch=epoch, prepend_epoch=True,
                        optimizers=optimizers, optimizer_names=model_names)
                    print("DONE")
                    flog.write("DONE\n")
                    last_checkpoint_step = train_step

    # save the final models
    print("Saving final checkpoint ...... ", end='', flush=True)
    flog.write("Saving final checkpoint ...... ")
    utils.save_checkpoint(
        models=models, model_names=model_names,
        dirname=os.path.join(conf.ckpt_path, conf.exp_name),
        epoch=epoch, prepend_epoch=False,
        optimizers=optimizers, optimizer_names=optimizer_names)
    print("DONE")
    flog.write("DONE\n")

    flog.close()
def train(conf, train_shape_list, train_data_list, val_data_list, all_train_data_list):
    # create training and validation datasets and data loaders
    data_features = ['pcs', 'pc_pxids', 'pc_movables', 'gripper_img_target', 'gripper_direction_camera',
                     'gripper_forward_direction_camera', 'result', 'cur_dir', 'shape_id', 'trial_id', 'is_original']

    # load network model
    model_def = utils.get_model_module(conf.model_version)

    # create models
    network = model_def.Network(conf.feat_dim)
    utils.printout(conf.flog, '\n' + str(network) + '\n')

    # create optimizers
    network_opt = torch.optim.Adam(network.parameters(), lr=conf.lr, weight_decay=conf.weight_decay)

    # learning rate scheduler
    network_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        network_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)

    # create logs
    if not conf.no_console_log:
        header = ' Time Epoch Dataset Iteration Progress(%) LR TotalLoss'
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        from tensorboardX import SummaryWriter
        train_writer = SummaryWriter(os.path.join(conf.exp_dir, 'train'))
        val_writer = SummaryWriter(os.path.join(conf.exp_dir, 'val'))

    # send parameters to device
    network.to(conf.device)
    utils.optimizer_to_device(network_opt, conf.device)

    # load dataset
    train_dataset = SAPIENVisionDataset([conf.primact_type], conf.category_types, data_features, conf.buffer_max_num,
                                        abs_thres=conf.abs_thres, rel_thres=conf.rel_thres, dp_thres=conf.dp_thres,
                                        img_size=conf.img_size, no_true_false_equal=conf.no_true_false_equal)

    val_dataset = SAPIENVisionDataset([conf.primact_type], conf.category_types, data_features, conf.buffer_max_num,
                                      abs_thres=conf.abs_thres, rel_thres=conf.rel_thres, dp_thres=conf.dp_thres,
                                      img_size=conf.img_size, no_true_false_equal=conf.no_true_false_equal)
    val_dataset.load_data(val_data_list)
    utils.printout(conf.flog, str(val_dataset))
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=conf.batch_size, shuffle=False, pin_memory=True,
                                                 num_workers=0, drop_last=True, collate_fn=utils.collate_feats,
                                                 worker_init_fn=utils.worker_init_fn)
    val_num_batch = len(val_dataloader)

    # create a data generator
    datagen = DataGen(conf.num_processes_for_datagen, conf.flog)

    # sample succ
    if conf.sample_succ:
        sample_succ_list = []
        sample_succ_dirs = []

    # start training
    start_time = time.time()

    last_train_console_log_step, last_val_console_log_step = None, None

    # if resume
    start_epoch = 0
    if conf.resume:
        # figure out the latest epoch to resume
        for item in os.listdir(os.path.join(conf.exp_dir, 'ckpts')):
            if item.endswith('-train_dataset.pth'):
                start_epoch = int(item.split('-')[0])

        # load states for network, optimizer, lr_scheduler, sample_succ_list
        data_to_restore = torch.load(os.path.join(conf.exp_dir, 'ckpts', '%d-network.pth' % start_epoch))
        network.load_state_dict(data_to_restore)
        data_to_restore = torch.load(os.path.join(conf.exp_dir, 'ckpts', '%d-optimizer.pth' % start_epoch))
        network_opt.load_state_dict(data_to_restore)
        data_to_restore = torch.load(os.path.join(conf.exp_dir, 'ckpts', '%d-lr_scheduler.pth' % start_epoch))
        network_lr_scheduler.load_state_dict(data_to_restore)

        # rmdir and make a new dir for the current sample-succ directory
        old_sample_succ_dir = os.path.join(conf.data_dir, 'epoch-%04d_sample-succ' % (start_epoch - 1))
        utils.force_mkdir(old_sample_succ_dir)

    # train for every epoch
    for epoch in range(start_epoch, conf.epochs):
        ### collect data for the current epoch
        if epoch > start_epoch:
            utils.printout(conf.flog,
                           f' [{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} Waiting epoch-{epoch} data ]')
            train_data_list = datagen.join_all()
            utils.printout(conf.flog,
                           f' [{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} Gathered epoch-{epoch} data ]')

            cur_data_folders = []
            for item in train_data_list:
                item = '/'.join(item.split('/')[:-1])
                if item not in cur_data_folders:
                    cur_data_folders.append(item)
            for cur_data_folder in cur_data_folders:
                with open(os.path.join(cur_data_folder, 'data_tuple_list.txt'), 'w') as fout:
                    for item in train_data_list:
                        if cur_data_folder == '/'.join(item.split('/')[:-1]):
                            fout.write(item.split('/')[-1] + '\n')

            # load offline-generated sample-random data
            for item in all_train_data_list:
                valid_id_l = conf.num_interaction_data_offline + conf.num_interaction_data * (epoch - 1)
                valid_id_r = conf.num_interaction_data_offline + conf.num_interaction_data * epoch
                if valid_id_l <= int(item.split('_')[-1]) < valid_id_r:
                    train_data_list.append(item)

        ### start generating data for the next epoch
        # sample succ
        if conf.sample_succ:
            if conf.resume and epoch == start_epoch:
                sample_succ_list = torch.load(os.path.join(conf.exp_dir, 'ckpts', '%d-sample_succ_list.pth' % start_epoch))
            else:
                torch.save(sample_succ_list, os.path.join(conf.exp_dir, 'ckpts', '%d-sample_succ_list.pth' % epoch))

            for item in sample_succ_list:
                datagen.add_one_recollect_job(item[0], item[1], item[2], item[3], item[4], item[5], item[6])

            sample_succ_list = []
            sample_succ_dirs = []
            cur_sample_succ_dir = os.path.join(conf.data_dir, 'epoch-%04d_sample-succ' % epoch)
            utils.force_mkdir(cur_sample_succ_dir)

        # start all jobs
        datagen.start_all()
        utils.printout(conf.flog,
                       f' [ {strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} Started generating epoch-{epoch+1} data ]')

        ### load data for the current epoch
        if conf.resume and epoch == start_epoch:
            train_dataset = torch.load(os.path.join(conf.exp_dir, 'ckpts', '%d-train_dataset.pth' % start_epoch))
        else:
            train_dataset.load_data(train_data_list)
        utils.printout(conf.flog, str(train_dataset))
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=True, pin_memory=True,
                                                       num_workers=0, drop_last=True, collate_fn=utils.collate_feats,
                                                       worker_init_fn=utils.worker_init_fn)
        train_num_batch = len(train_dataloader)

        ### print log
        if not conf.no_console_log:
            utils.printout(conf.flog, f'training run {conf.exp_name}')
            utils.printout(conf.flog, header)

        train_batches = enumerate(train_dataloader, 0)
        val_batches = enumerate(val_dataloader, 0)

        train_fraction_done = 0.0
        val_fraction_done = 0.0
        val_batch_ind = -1

        ### train for every batch
        for train_batch_ind, batch in train_batches:
            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # save checkpoint
            if train_batch_ind == 0:
                with torch.no_grad():
                    utils.printout(conf.flog, 'Saving checkpoint ...... ')
                    torch.save(network.state_dict(), os.path.join(conf.exp_dir, 'ckpts', '%d-network.pth' % epoch))
                    torch.save(network_opt.state_dict(), os.path.join(conf.exp_dir, 'ckpts', '%d-optimizer.pth' % epoch))
                    torch.save(network_lr_scheduler.state_dict(), os.path.join(conf.exp_dir, 'ckpts', '%d-lr_scheduler.pth' % epoch))
                    torch.save(train_dataset, os.path.join(conf.exp_dir, 'ckpts', '%d-train_dataset.pth' % epoch))
                    utils.printout(conf.flog, 'DONE')

            # set models to training mode
            network.train()

            # forward pass (including logging)
            total_loss, whole_feats, whole_pcs, whole_pxids, whole_movables = \
                forward(batch=batch, data_features=data_features, network=network, conf=conf, is_val=False,
                        step=train_step, epoch=epoch, batch_ind=train_batch_ind, num_batch=train_num_batch,
                        start_time=start_time, log_console=log_console, log_tb=not conf.no_tb_log,
                        tb_writer=train_writer, lr=network_opt.param_groups[0]['lr'])

            # optimize one step
            network_opt.zero_grad()
            total_loss.backward()
            network_opt.step()
            network_lr_scheduler.step()

            # sample succ
            if conf.sample_succ:
                network.eval()

                with torch.no_grad():
                    # sample a random EE orientation
                    random_up = torch.randn(conf.batch_size, 3).float().to(conf.device)
                    random_forward = torch.randn(conf.batch_size, 3).float().to(conf.device)
                    random_left = torch.cross(random_up, random_forward)
                    random_forward = torch.cross(random_left, random_up)
                    random_dirs1 = F.normalize(random_up, dim=1).float()
                    random_dirs2 = F.normalize(random_forward, dim=1).float()

                    # test over the entire image
                    whole_pc_scores1 = network.inference_whole_pc(whole_feats, random_dirs1, random_dirs2)   # B x N
                    whole_pc_scores2 = network.inference_whole_pc(whole_feats, -random_dirs1, random_dirs2)  # B x N

                    # add to the sample_succ_list if wanted
                    ss_cur_dir = batch[data_features.index('cur_dir')]
                    ss_shape_id = batch[data_features.index('shape_id')]
                    ss_trial_id = batch[data_features.index('trial_id')]
                    ss_is_original = batch[data_features.index('is_original')]

                    for i in range(conf.batch_size):
                        valid_id_l = conf.num_interaction_data_offline + conf.num_interaction_data * (epoch - 1)
                        valid_id_r = conf.num_interaction_data_offline + conf.num_interaction_data * epoch
                        if ('sample-succ' not in ss_cur_dir[i]) and (ss_is_original[i]) and (ss_cur_dir[i] not in sample_succ_dirs) \
                                and (valid_id_l <= int(ss_trial_id[i]) < valid_id_r):
                            sample_succ_dirs.append(ss_cur_dir[i])

                            # choose one from the two options
                            gt_movable = whole_movables[i].cpu().numpy()

                            whole_pc_score1 = whole_pc_scores1[i].cpu().numpy() * gt_movable
                            whole_pc_score1[whole_pc_score1 < 0.5] = 0
                            whole_pc_score_sum1 = np.sum(whole_pc_score1) + 1e-12

                            whole_pc_score2 = whole_pc_scores2[i].cpu().numpy() * gt_movable
                            whole_pc_score2[whole_pc_score2 < 0.5] = 0
                            whole_pc_score_sum2 = np.sum(whole_pc_score2) + 1e-12

                            choose1or2_ratio = whole_pc_score_sum1 / (whole_pc_score_sum1 + whole_pc_score_sum2)
                            random_dir1 = random_dirs1[i].cpu().numpy()
                            random_dir2 = random_dirs2[i].cpu().numpy()
                            if np.random.random() < choose1or2_ratio:
                                whole_pc_score = whole_pc_score1
                            else:
                                whole_pc_score = whole_pc_score2
                                random_dir1 = -random_dir1

                            # sample <X, Y> on each img
                            pp = whole_pc_score + 1e-12
                            ptid = np.random.choice(len(whole_pc_score), 1, p=pp / pp.sum())
                            X = whole_pxids[i, ptid, 0].item()
                            Y = whole_pxids[i, ptid, 1].item()

                            # add job to the queue
                            str_cur_dir1 = ',' + ','.join(['%f' % elem for elem in random_dir1])
                            str_cur_dir2 = ',' + ','.join(['%f' % elem for elem in random_dir2])
                            sample_succ_list.append((conf.offline_data_dir, str_cur_dir1, str_cur_dir2,
                                                     ss_cur_dir[i].split('/')[-1], cur_sample_succ_dir, X, Y))

            # validate one batch
            while val_fraction_done <= train_fraction_done and val_batch_ind + 1 < val_num_batch:
                val_batch_ind, val_batch = next(val_batches)

                val_fraction_done = (val_batch_ind + 1) / val_num_batch
                val_step = (epoch + val_fraction_done) * train_num_batch - 1

                log_console = not conf.no_console_log and (last_val_console_log_step is None or \
                    val_step - last_val_console_log_step >= conf.console_log_interval)
                if log_console:
                    last_val_console_log_step = val_step

                # set models to evaluation mode
                network.eval()

                with torch.no_grad():
                    # forward pass (including logging)
                    __ = forward(batch=val_batch, data_features=data_features, network=network, conf=conf, is_val=True,
                                 step=val_step, epoch=epoch, batch_ind=val_batch_ind, num_batch=val_num_batch,
                                 start_time=start_time, log_console=log_console, log_tb=not conf.no_tb_log,
                                 tb_writer=val_writer, lr=network_opt.param_groups[0]['lr'])
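
# NOTE: `utils.printout` and `utils.force_mkdir` are small project helpers used by the
# last two trainers but not shown in this section. Plausible minimal versions, assuming
# `printout` mirrors each message to the console and to the open log file, and
# `force_mkdir` recreates a directory from scratch; the repo's own versions may differ.
import os
import shutil


def printout(flog, strout):
    """Echo a message to stdout and append it to the log file, if one is open."""
    print(strout)
    if flog is not None:
        flog.write(strout + '\n')
        flog.flush()


def force_mkdir(folder):
    """Remove any existing directory of the same name, then create it fresh."""
    if os.path.exists(folder):
        shutil.rmtree(folder)
    os.makedirs(folder)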