def test_from_name_file_model(self):
    # test that loading works even if they differ by a prefix
    for trained_model, fresh_model in [
        (self.create_model(), self.create_model()),
        (nn.DataParallel(self.create_model()), self.create_model()),
        (self.create_model(), nn.DataParallel(self.create_model())),
        (
            nn.DataParallel(self.create_model()),
            nn.DataParallel(self.create_model()),
        ),
    ]:
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(
                trained_model, save_dir=f, save_to_disk=True
            )
            checkpointer.save("checkpoint_file")

            # on different folders
            with TemporaryDirectory() as g:
                fresh_checkpointer = Checkpointer(fresh_model, save_dir=g)
                self.assertFalse(fresh_checkpointer.has_checkpoint())
                self.assertEqual(fresh_checkpointer.get_checkpoint_file(), "")
                _ = fresh_checkpointer.load(os.path.join(f, "checkpoint_file.pth"))

        for trained_p, loaded_p in zip(
            trained_model.parameters(), fresh_model.parameters()
        ):
            # different tensor references
            self.assertFalse(id(trained_p) == id(loaded_p))
            # same content
            self.assertTrue(trained_p.equal(loaded_p))
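# ---------------------------------------------------------------------------
# Added illustration (not part of the original test): the "module." prefix
# mismatch that test_from_name_file_model exercises. This is plain PyTorch
# behavior, nothing project-specific: nn.DataParallel registers the wrapped
# network as `self.module`, so every state_dict key gains a "module." prefix,
# and the Checkpointer has to reconcile that prefix when the saving and the
# loading models differ by the wrapper. The helper name below is illustrative.
# ---------------------------------------------------------------------------
def _demo_dataparallel_prefix():
    import torch.nn as nn  # local import so the sketch stays self-contained

    base = nn.Linear(4, 2)
    wrapped = nn.DataParallel(base)
    # The underlying parameters are identical; only the key names differ.
    assert list(base.state_dict().keys()) == ['weight', 'bias']
    assert list(wrapped.state_dict().keys()) == ['module.weight', 'module.bias']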
def train(cfg, output_dir='', output_dir_merge='', output_dir_refine=''):
    logger = logging.getLogger('shaper.train')

    # build model
    set_random_seed(cfg.RNG_SEED)
    model_merge = nn.DataParallel(PointNetCls(in_channels=3, out_channels=128)).cuda()

    # build optimizer
    # hard-coded scheduler overrides for the merge network
    cfg['SCHEDULER']['StepLR']['step_size'] = 150
    cfg['SCHEDULER']['MAX_EPOCH'] = 20000
    optimizer_embed = build_optimizer(cfg, model_merge)

    # build lr scheduler
    scheduler_embed = build_scheduler(cfg, optimizer_embed)

    # build checkpointer
    checkpointer_embed = Checkpointer(model_merge,
                                      optimizer=optimizer_embed,
                                      scheduler=scheduler_embed,
                                      save_dir=output_dir_merge,
                                      logger=logger)
    checkpoint_data_embed = checkpointer_embed.load(
        cfg.MODEL.WEIGHT, resume=cfg.AUTO_RESUME, resume_states=cfg.RESUME_STATES)
    ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD

    # build data loader
    # Reset the random seed again in case the initialization of models changes the random state.
    set_random_seed(cfg.RNG_SEED)

    # build tensorboard logger (optionally by comment)
    tensorboard_logger = TensorboardLogger(output_dir_merge)

    # train
    max_epoch = cfg.SCHEDULER.MAX_EPOCH
    start_epoch = checkpoint_data_embed.get('epoch', 0)
    best_metric_name = 'best_{}'.format(cfg.TRAIN.VAL_METRIC)
    best_metric = checkpoint_data_embed.get(best_metric_name, None)
    logger.info('Start training from epoch {}'.format(start_epoch))
    for epoch in range(start_epoch, max_epoch):
        cur_epoch = epoch + 1
        scheduler_embed.step()
        start_time = time.time()
        train_meters = train_one_epoch(
            model_merge,
            cur_epoch,
            optimizer_embed=optimizer_embed,
            output_dir_merge=output_dir_merge,
            max_grad_norm=cfg.OPTIMIZER.MAX_GRAD_NORM,
            freezer=None,
            log_period=cfg.TRAIN.LOG_PERIOD,
        )
        epoch_time = time.time() - start_time
        logger.info('Epoch[{}]-Train {} total_time: {:.2f}s'.format(
            cur_epoch, train_meters.summary_str, epoch_time))

        tensorboard_logger.add_scalars(train_meters.meters, cur_epoch, prefix='train')

        # checkpoint
        if (ckpt_period > 0 and cur_epoch % ckpt_period == 0) or cur_epoch == max_epoch:
            checkpoint_data_embed['epoch'] = cur_epoch
            checkpoint_data_embed[best_metric_name] = best_metric
            checkpointer_embed.save('model_{:03d}'.format(cur_epoch), **checkpoint_data_embed)

    return model_merge
def train(cfg, output_dir=''):
    logger = logging.getLogger('shaper.train')

    # build model
    set_random_seed(cfg.RNG_SEED)
    model, loss_fn, metric = build_model(cfg)
    logger.info('Build model:\n{}'.format(str(model)))
    model = nn.DataParallel(model).cuda()
    # model = model.cuda()

    # build optimizer
    optimizer = build_optimizer(cfg, model)

    # build lr scheduler
    scheduler = build_scheduler(cfg, optimizer)

    # build checkpointer
    # Note that checkpointer will load state_dict of model, optimizer and scheduler.
    checkpointer = Checkpointer(model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                save_dir=output_dir,
                                logger=logger)
    checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT,
                                        resume=cfg.AUTO_RESUME,
                                        resume_states=cfg.RESUME_STATES)
    ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD

    # build freezer
    if cfg.TRAIN.FROZEN_PATTERNS:
        freezer = Freezer(model, cfg.TRAIN.FROZEN_PATTERNS)
        freezer.freeze(verbose=True)  # sanity check
    else:
        freezer = None

    # build data loader
    # Reset the random seed again in case the initialization of models changes the random state.
    set_random_seed(cfg.RNG_SEED)
    train_dataloader = build_dataloader(cfg, mode='train')
    val_period = cfg.TRAIN.VAL_PERIOD
    val_dataloader = build_dataloader(cfg, mode='val') if val_period > 0 else None

    # build tensorboard logger (optionally by comment)
    tensorboard_logger = TensorboardLogger(output_dir)

    # train
    max_epoch = cfg.SCHEDULER.MAX_EPOCH
    start_epoch = checkpoint_data.get('epoch', 0)
    best_metric_name = 'best_{}'.format(cfg.TRAIN.VAL_METRIC)
    best_metric = checkpoint_data.get(best_metric_name, None)
    logger.info('Start training from epoch {}'.format(start_epoch))
    for epoch in range(start_epoch, max_epoch):
        cur_epoch = epoch + 1
        scheduler.step()
        start_time = time.time()
        train_meters = train_one_epoch(
            model,
            loss_fn,
            metric,
            train_dataloader,
            optimizer=optimizer,
            max_grad_norm=cfg.OPTIMIZER.MAX_GRAD_NORM,
            freezer=freezer,
            log_period=cfg.TRAIN.LOG_PERIOD,
        )
        epoch_time = time.time() - start_time
        logger.info('Epoch[{}]-Train {} total_time: {:.2f}s'.format(
            cur_epoch, train_meters.summary_str, epoch_time))

        tensorboard_logger.add_scalars(train_meters.meters, cur_epoch, prefix='train')

        # checkpoint
        if (ckpt_period > 0 and cur_epoch % ckpt_period == 0) or cur_epoch == max_epoch:
            checkpoint_data['epoch'] = cur_epoch
            checkpoint_data[best_metric_name] = best_metric
            checkpointer.save('model_{:03d}'.format(cur_epoch), **checkpoint_data)

        # validate
        if val_period > 0 and (cur_epoch % val_period == 0 or cur_epoch == max_epoch):
            start_time = time.time()
            val_meters = validate(model,
                                  loss_fn,
                                  metric,
                                  val_dataloader,
                                  log_period=cfg.TEST.LOG_PERIOD,
                                  )
            epoch_time = time.time() - start_time
            logger.info('Epoch[{}]-Val {} total_time: {:.2f}s'.format(
                cur_epoch, val_meters.summary_str, epoch_time))

            tensorboard_logger.add_scalars(val_meters.meters, cur_epoch, prefix='val')

            # best validation
            if cfg.TRAIN.VAL_METRIC in val_meters.meters:
                cur_metric = val_meters.meters[cfg.TRAIN.VAL_METRIC].global_avg
                if best_metric is None or cur_metric > best_metric:
                    best_metric = cur_metric
                    checkpoint_data['epoch'] = cur_epoch
                    checkpoint_data[best_metric_name] = best_metric
                    checkpointer.save('model_best', **checkpoint_data)

    logger.info('Best val-{} = {}'.format(cfg.TRAIN.VAL_METRIC, best_metric))
    return model
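# ---------------------------------------------------------------------------
# Added sketch (hypothetical helper, not part of the original training code):
# reload the 'model_best' checkpoint written by the loop above and run a single
# validation pass. It reuses only calls that already appear in this file
# (build_model, build_dataloader, Checkpointer, validate); the '.pth' suffix
# matches the file name produced by Checkpointer.save in the test above, and
# the function name itself is an illustration, not project API.
# ---------------------------------------------------------------------------
def evaluate_best(cfg, output_dir=''):
    logger = logging.getLogger('shaper.test')

    # rebuild the same model/loss/metric used for training
    model, loss_fn, metric = build_model(cfg)
    model = nn.DataParallel(model).cuda()

    # load the best checkpoint saved by train()
    checkpointer = Checkpointer(model, save_dir=output_dir, logger=logger)
    checkpointer.load(os.path.join(output_dir, 'model_best.pth'))

    # single validation pass with the same validate() helper as above
    val_dataloader = build_dataloader(cfg, mode='val')
    val_meters = validate(model, loss_fn, metric, val_dataloader,
                          log_period=cfg.TEST.LOG_PERIOD)
    logger.info('Val {}'.format(val_meters.summary_str))
    return val_meters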