예제 #1
0
def main(args, myargs):
    """Build the config, prepare ``myargs`` for multi-process runs, start the
    build, and finish with a final results sync.

    Args:
        args: arguments forwarded to ``setup`` and ``build_start``.
        myargs: run state; ``myargs.config`` is passed to ``setup``. It is then
            replaced by the object returned from
            ``D2Utils.setup_myargs_for_multiple_processing``, whose ``args``
            attribute is used for the final sync.
    """
    config = setup(args, myargs.config)
    # From here on, work with the multi-processing-ready copy of myargs.
    myargs = D2Utils.setup_myargs_for_multiple_processing(myargs)

    build_start(cfg=config, args=args, myargs=myargs)

    # Last sync of the run: join=True and end=True.
    sync_kwargs = dict(args=myargs.args, myargs=myargs, join=True, end=True)
    modelarts_utils.modelarts_sync_results(**sync_kwargs)
예제 #2
0
def run(config, args, myargs):
  """Train a GAN (generator G, discriminator D) as configured by ``config``.

  Steps: add derived entries to the config, seed the RNG, build G and D
  (plus an EMA copy of G when ``config['ema']``), optionally cast them to
  fp16, optionally resume saved weights, wrap them in ``model.G_D``
  (optionally ``nn.DataParallel``), then loop over epochs and batches calling
  the selected training function. The loop logs metrics, periodically
  saves/samples and runs tests. Results are synced with
  ``modelarts_sync_results`` before training, after every save and once at
  the end.

  Args:
    config: dict of options. It is mutated in place (derived keys are added)
      and then rebound to the return value of ``utils.update_config_roots``.
    args: namespace. ``args.outdir`` becomes ``config['base_root']``, and
      ``args`` is forwarded to ``modelarts_sync_results``.
    myargs: run state. Its ``logger``, ``stdout``, ``stderr``, ``writer`` and
      ``textlogger`` attributes are used. It is also forwarded to the
      wgan_gpreal training function and to ``modelarts_sync_results``.
  """

  # Update the config dict as necessary
  # This is for convenience, to add settings derived from the user-specified
  # configuration into the config-dict (e.g. inferring the number of classes
  # and size of the images from the dataset, passing in a pytorch object
  # for the activation specified as a string)
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config['base_root'] = args.outdir
  config = utils.update_config_roots(config)
  device = 'cuda'

  # Seed RNG (before any model construction, so initialization is seeded too)
  utils.seed_rng(config['seed'])

  # Prepare root folders if necessary
  utils.prepare_root(config)

  # Setup cudnn.benchmark for free speed
  torch.backends.cudnn.benchmark = True

  # Import the model--this line allows us to dynamically select different files.
  model = __import__(config['model'])
  experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  G = model.Generator(**config).to(device)
  D = model.Discriminator(**config).to(device)

  # If using EMA, prepare it
  if config['ema']:
    print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    # The EMA copy of G is built from the same config, but with skip_init and
    # no_optim forced to True.
    G_ema = model.Generator(**{**config, 'skip_init':True, 
                               'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
  else:
    G_ema, ema = None, None

  # FP16?
  if config['G_fp16']:
    print('Casting G to float16...')
    G = G.half()
    if config['ema']:
      G_ema = G_ema.half()
  if config['D_fp16']:
    print('Casting D to fp16...')
    D = D.half()
    # Consider automatically reducing SN_eps?
  GD = model.G_D(G, D)
  print(G)
  print(D)
  print('Number of params in G: {} D: {}'.format(
    *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]]))
  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  # (state_dict is passed in as well, presumably so the counters are restored)
  if config['resume']:
    print('Loading weights...')
    utils.load_weights(G, D, state_dict,
                       config['weights_root'], experiment_name, 
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None)

  # If parallel, parallelize the GD module
  if config['parallel']:
    GD = nn.DataParallel(GD)
    if config['cross_replica']:
      patch_replication_callback(GD)

  # Prepare loggers for stats; metrics holds test metrics,
  # lmetrics holds any desired training metrics.
  test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                            experiment_name)
  train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
  print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
  # Loggers are only reinitialized when starting fresh (not resuming).
  test_log = utils.MetricsLogger(test_metrics_fname, 
                                 reinitialize=(not config['resume']))
  print('Training Metrics will be saved to {}'.format(train_metrics_fname))
  train_log = utils.MyLogger(train_metrics_fname, 
                             reinitialize=(not config['resume']),
                             logstyle=config['logstyle'])
  # Write metadata
  utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
  # Prepare data; the Discriminator's batch size is all that needs to be passed
  # to the dataloader, as G doesn't require dataloading.
  # Note that at every loader iteration we pass in enough data to complete
  # a full D iteration (regardless of number of D steps and accumulations)
  D_batch_size = (config['batch_size'] * config['num_D_steps']
                  * config['num_D_accumulations'])
  loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                      'start_itr': state_dict['itr'],
                                      'config': config})

  # Prepare inception metrics: FID and IS
  get_inception_metrics = inception_utils.prepare_inception_metrics(
    config['inception_file'], config['parallel'], config['no_fid'])

  # Prepare noise and randomly sampled label arrays
  # Allow for different batch sizes in G (the larger of the two sizes is used)
  G_batch_size = max(config['G_batch_size'], config['batch_size'])
  z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                             device=device, fp16=config['G_fp16'])
  # Prepare a fixed z & y to see individual sample evolution throughout training
  fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                       config['n_classes'], device=device,
                                       fp16=config['G_fp16'])  
  fixed_z.sample_()
  fixed_y.sample_()
  # Loaders are loaded, prepare the training function
  if config['which_train_fn'] == 'GAN':
    train = train_fns.GAN_training_function(G, D, GD, z_, y_, 
                                            ema, state_dict, config)
  # The wgan_gpreal / wbgan_gpreal variants share one training function,
  # which additionally receives myargs.
  elif config['which_train_fn'] in ['wgan_gpreal', 'wbgan_gpreal']:
    train = train_fns.wgan_gpreal_training_function(
      G, D, GD, z_, y_, ema, state_dict, config, myargs)
  # Else, assume debugging and use the dummy train fn
  else:
    train = train_fns.dummy_training_function()
  # Prepare Sample function for use with inception metrics
  # (samples from G_ema only when both 'ema' and 'use_ema' are set)
  sample = functools.partial(utils.sample,
                              G=(G_ema if config['ema'] and config['use_ema']
                                 else G),
                              z_=z_, y_=y_, config=config)

  print('Beginning training at epoch %d...' % state_dict['epoch'])
  modelarts_sync_results(args=args, myargs=myargs, join=False)
  # Train for specified number of epochs, although we mostly track G iterations.
  for epoch in range(state_dict['epoch'], config['num_epochs']):    
    myargs.logger.info('Epoch: %d/%d'%(epoch, config['num_epochs']))
    # Which progressbar to use? TQDM or my own?
    if config['pbar'] == 'mine':
      pbar = utils.progress(
        loaders[0],
        displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta',
        file=myargs.stdout)
    else:
      pbar = tqdm(loaders[0])
    for i, (x, y) in enumerate(pbar):
      # Increment the iteration counter
      state_dict['itr'] += 1
      # Make sure G and D are in training mode, just in case they got set to eval
      # For D, which typically doesn't have BN, this shouldn't matter much.
      G.train()
      D.train()
      if config['ema']:
        G_ema.train()
      # Only x is cast to half when D runs in fp16; labels y keep their dtype.
      if config['D_fp16']:
        x, y = x.to(device).half(), y.to(device)
      else:
        x, y = x.to(device), y.to(device)
      metrics = train(x, y)
      # The same metrics dict goes to three sinks: train_log, the text logger
      # and the writer (one scalar per metric under 'metrics/<tag>').
      train_log.log(itr=int(state_dict['itr']), **metrics)
      myargs.textlogger.log(state_dict['itr'], **metrics)
      for tag, v in metrics.items():
        myargs.writer.add_scalar('metrics/%s' % tag, v, state_dict['itr'])

      # Every sv_log_interval, log singular values
      if (config['sv_log_interval'] > 0) and \
              (not (state_dict['itr'] % config['sv_log_interval'])):
        train_log.log(itr=int(state_dict['itr']), 
                      **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
      # Every 100 iterations, and on the first batch of each epoch, write the
      # GPU memory map to myargs.stderr.
      if not (state_dict['itr'] % 100) or i == 0:
        gpu_str = gpu_usage.get_gpu_memory_map()
        myargs.stderr.write(gpu_str)

      # If using my progbar, print metrics.
      if config['pbar'] == 'mine':
        print(', '.join(['itr: %d' % state_dict['itr']]
                        + ['%s : %+4.3f' % (key, metrics[key])
                           for key in metrics]), end=' ', file=myargs.stdout)
        myargs.stdout.flush()

      # Save weights and copies as configured at specified interval.
      # The test is on (itr - 1), so saves happen at itr 1, 1 + save_every,
      # 1 + 2 * save_every, ... (i.e. also right after the first iteration).
      if not ((state_dict['itr'] - 1) % config['save_every']):
        if config['G_eval_mode']:
          print('Switchin G to eval mode...')
          G.eval()
          if config['ema']:
            G_ema.eval()
        train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, 
                                  state_dict, config, experiment_name)
        modelarts_sync_results(args=args, myargs=myargs, join=False)

      # Test every specified interval
      if not (state_dict['itr'] % config['test_every']):
        # NOTE(review): unlike the save branch above, only G (not G_ema) is
        # switched to eval mode here, although `sample` may use G_ema.
        if config['G_eval_mode']:
          print('Switchin G to eval mode...')
          G.eval()
        train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log,
                       writer=myargs.writer)
    # Increment epoch counter at end of epoch
    state_dict['epoch'] += 1
  # End training
  modelarts_sync_results(args=args, myargs=myargs, join=True, end=True)
예제 #3
0
    def test_bash(self):
        """Poll a YAML config file and run the ``command`` it contains.

        Usage:
            export CUDA_VISIBLE_DEVICES=2,3,4,5
            export PORT=6006
            export TIME_STR=1
            export PYTHONPATH=../submodule:..
            python -c "import test_bash; \
            test_bash.TestingUnit().test_bash()"

        After setting up ``args``/``myargs`` this loops forever. Every ~3 s it
        pulls ``args.outdir_obs`` into ``args.outdir`` (when moxing is
        available), copies ``bash_command.sh`` into the cwd, re-reads
        ``args.configfile``, and whenever its ``command`` entry changes:

        * a list whose first item is a string starting with ``bash`` is
          recorded and handed to a ``Worker`` (its ``start()`` is called);
        * any other list is stringified and run via
          ``subprocess.check_output(..., shell=True)``; the output (or the
          error output and return code) is written to ``<outdir>/err.txt``.

        :return: never returns.
        """
        os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0, 1, 2, 3, 4, 5, 6, 7')
        os.environ.setdefault('PORT', '6106')
        os.environ.setdefault('TIME_STR', '0')

        # Results go under results/<name of this function>.
        outdir = os.path.join('results', sys._getframe().f_code.co_name)
        myargs = argparse.Namespace()

        def build_args():
            # Default argv, used only when no CLI arguments were given.
            argv_str = """
            --config ../configs/virtual_terminal.yaml \
            --resume False --resume_path None
            --resume_root None
            """
            parser = utils.args_parser.build_parser()
            if len(sys.argv) == 1:
                args = parser.parse_args(args=argv_str.split())
            else:
                args = parser.parse_args()
            args.CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES']
            args = utils.config_utils.DotDict(vars(args))
            return args, argv_str

        args, argv_str = build_args()

        # Best effort: when moxing and DLS_TRAIN_URL are available, start from
        # an empty remote log directory. Otherwise silently skip.
        try:
            import moxing as mox
            log_obs = os.environ['DLS_TRAIN_URL']
            if log_obs:
                if mox.file.exists(log_obs):
                    mox.file.remove(log_obs, recursive=True)
                mox.file.make_dirs(log_obs)
        except Exception:
            pass
        args.outdir = outdir
        args, myargs = utils.config.setup_args_and_myargs(args=args,
                                                          myargs=myargs)
        modelarts_record_bash_command(args, myargs)

        old_command = ''
        myargs.logger.info('Begin loop.')
        # Create (or truncate to empty) bash_command.sh in outdir.
        bash_file = os.path.join(args.outdir, 'bash_command.sh')
        with open(bash_file, 'w'):
            pass
        cwd = os.getcwd()
        # copy outdir to outdir_obs
        modelarts_utils.modelarts_sync_results(args, myargs, join=True)
        while True:
            # Throttle the loop unconditionally. Previously the sleep sat after
            # the moxing import inside a bare try/except, so without moxing the
            # loop busy-spun (and Ctrl-C during the sleep was swallowed).
            time.sleep(3)
            try:
                import moxing as mox
                # copy outdir_obs to outdir
                mox.file.copy_parallel(args.outdir_obs, args.outdir)
            except Exception:
                pass
            shutil.copy(bash_file, cwd)
            try:
                with open(args.configfile, 'rt') as handle:
                    # safe_load: yaml.load() without a Loader raises TypeError
                    # on PyYAML >= 6, which used to be swallowed below.
                    config = EasyDict(yaml.safe_load(handle))
                command = config.command
            except Exception:
                print('Parse config.yaml error!')
                command = None
                old_command = ''
            if command != old_command:
                old_command = command
                first = command[0] if isinstance(command, list) and command else None
                if isinstance(first, str) and first.startswith('bash'):
                    modelarts_record_bash_command(args, myargs, first)
                    p = Worker(name='Command worker', args=(first, ))
                    p.start()
                elif isinstance(command, list):
                    command = [str(part) for part in command]
                    print('===Execute: %s' % command)
                    # NOTE(review): with shell=True and a list, POSIX runs only
                    # command[0] as the shell script; the remaining items
                    # become the shell's positional parameters.
                    with open(os.path.join(args.outdir, 'err.txt'), 'w') as err_f:
                        try:
                            cwd = os.getcwd()
                            return_str = subprocess.check_output(command,
                                                                 encoding='utf-8',
                                                                 cwd=cwd,
                                                                 shell=True)
                            print(return_str, file=err_f, flush=True)
                        except subprocess.CalledProcessError as e:
                            print("Oops!\n",
                                  e.output,
                                  "\noccured.",
                                  file=err_f,
                                  flush=True)
                            print(e.returncode, file=err_f, flush=True)
                modelarts_utils.modelarts_sync_results(args, myargs, join=True)
            if hasattr(args, 'outdir_obs'):
                # ModelArts only: fetch jobs.txt from OBS and push outdir back.
                log_obs = os.environ['DLS_TRAIN_URL']
                jobs_file_obs = os.path.join(log_obs, 'jobs.txt')
                jobs_file = os.path.join(args.outdir, 'jobs.txt')
                if mox.file.exists(jobs_file_obs):
                    mox.file.copy(jobs_file_obs, jobs_file)
                mox.file.copy_parallel(args.outdir, args.outdir_obs)
예제 #4
0
def do_train(model, data_loader, optimizer, scheduler, checkpointer, device,
             checkpoint_period, arguments, cfg, myargs, distributed):
    """Run the detection training loop over ``data_loader``.

    For every batch: step the LR scheduler, compute the model's loss dict,
    record the cross-GPU reduced losses in a ``MetricLogger``, and
    back-propagate the summed (unreduced) loss. Periodically it also:

    * every 20 (scaled) iterations and at ``max_iter``: writes a text-figure
      summary of the losses and a progress log line;
    * every ``checkpoint_period``: saves ``model_<iteration>`` and syncs;
    * at ``max_iter``: saves ``model_final``;
    * every ``myargs.config.EVAL_PERIOD`` and at ``max_iter``: runs
      ``tools.train_net.test``, logs the results, syncs, and puts the model
      back into training mode.

    The enumerate counter is multiplied by ``cfg.SOLVER.IMS_PER_BATCH``
    before use (and is stored in ``arguments["iteration"]``), while
    ``max_iter`` is ``len(data_loader)``.
    NOTE(review): confirm these are meant to be in the same unit.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    tick = time.time()

    # The progress-line template never changes, so build it once.
    log_template = meters.delimiter.join([
        "eta: {eta}",
        "iter: {iter}",
        "{meters}",
        "lr: {lr:.6f}",
        "max mem: {memory:.0f}",
    ])

    def sync_results():
        # Intermediate ModelArts sync (join=False, end=False).
        modelarts_sync_results(args=myargs.args,
                               myargs=myargs,
                               join=False,
                               end=False)

    for batch_idx, (images, targets, _) in enumerate(data_loader, start_iter):
        data_time = time.time() - tick
        iteration = batch_idx * cfg.SOLVER.IMS_PER_BATCH
        arguments["iteration"] = iteration

        scheduler.step()

        images = images.to(device)
        targets = [t.to(device) for t in targets]

        loss_dict = model(images, targets)
        total_loss = sum(loss_dict.values())

        # Reduce losses over all GPUs; used for logging only.
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        meters.update(loss=sum(loss_dict_reduced.values()), **loss_dict_reduced)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        batch_time = time.time() - tick
        tick = time.time()
        meters.update(time=batch_time, data=data_time)

        remaining_iters = max_iter - iteration
        eta_seconds = meters.time.global_avg * remaining_iters
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        is_last = iteration == max_iter

        if iteration % 20 == 0 or is_last:
            Trainer.summary_dict2txtfig(dict_data=loss_dict_reduced,
                                        prefix='do_da_train',
                                        step=iteration,
                                        textlogger=myargs.textlogger,
                                        in_one_axe=False)
            logger.info(log_template.format(
                eta=eta_string,
                iter=iteration,
                meters=str(meters),
                lr=optimizer.param_groups[0]["lr"],
                memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
            ))

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
            sync_results()

        if is_last:
            checkpointer.save("model_final", **arguments)

        if iteration % myargs.config.EVAL_PERIOD == 0 or is_last:
            # Imported lazily, at evaluation time.
            from tools.train_net import test
            eval_rets = test(cfg=cfg, model=model, distributed=distributed)
            Trainer.summary_defaultdict2txtfig(
                default_dict=Trainer.dict_of_dicts2defaultdict(eval_rets),
                prefix='eval',
                step=iteration,
                textlogger=myargs.textlogger)
            sync_results()
            # Evaluation may have switched the model to eval mode.
            model.train()

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    seconds_per_iter = total_training_time / max_iter
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, seconds_per_iter))
예제 #5
0
 def modelarts(self, join=False, end=False):
     """Sync results for this object's ``args``/``myargs`` via ModelArts.

     ``join`` and ``end`` are forwarded unchanged to
     ``modelarts_utils.modelarts_sync_results``.
     """
     run_args, run_myargs = self.args, self.myargs
     modelarts_utils.modelarts_sync_results(run_args, run_myargs,
                                            join=join, end=end)
def _intra_fid_for_class(model, fid_metric, fid_stats_dir, idx,
                         num_inception_images, results, metric_name, prefix,
                         myargs):
    """Compute and record the intra-FID of class ``idx`` for one FID backend.

    Args:
        model: trainer providing ``G``, ``z_test`` and ``fixed_arc``.
        fid_metric: ``model.FID_IS_tf`` or ``model.FID_IS_pytorch``. It is
            called with ``return_fid_stat=True`` to obtain ``(mu, sigma)`` of
            the generated samples, and its ``_calculate_frechet_distance``
            computes the distance.
        fid_stats_dir: directory containing ``<idx>.npz`` files that hold the
            reference ``mu`` and ``sigma`` arrays.
        idx: class index, used as sampling label, stats file name and log step.
        num_inception_images: number of images to sample on this process.
        results: list the distance is appended to (main process only).
        metric_name: key under which the distance is logged.
        prefix: text-figure prefix for the log entry.
        myargs: run state providing ``stdout``, ``textlogger`` and ``args``.
    """
    # Use the NpzFile as a context manager so its file handle is closed.
    with np.load(os.path.join(fid_stats_dir, f'{idx}.npz')) as mu_sigma:
        class_mu, class_sigma = mu_sigma['mu'], mu_sigma['sigma']
    sample_func = functools.partial(_sample_func_with_arcs,
                                    G=model.G,
                                    z=model.z_test,
                                    y=idx,
                                    arcs=model.fixed_arc)

    mu, sigma = fid_metric(sample_func,
                           return_fid_stat=True,
                           num_inception_images=num_inception_images,
                           stdout=myargs.stdout)
    # Only the main process computes, records and syncs the distance.
    if comm.is_main_process():
        intra_fid = fid_metric._calculate_frechet_distance(
            mu, sigma, class_mu, class_sigma)
        results.append(intra_fid)
        Trainer.summary_dict2txtfig(dict_data={metric_name: intra_fid},
                                    prefix=prefix,
                                    step=idx,
                                    textlogger=myargs.textlogger)
        modelarts_utils.modelarts_sync_results(args=myargs.args,
                                               myargs=myargs,
                                               join=False)


def compute_intra_FID(cfg, args, myargs):
    """Compute the per-class (intra) FID of the generator ``model.G``.

    Options are read from ``cfg.start``. The trainer is built and a
    checkpoint is loaded: the last one in ``model_path`` when
    ``use_last_checkpoint``, otherwise ``model_path/model_file``.
    ``model.evaluate_model`` is run once. Then, for every class found under
    ``imagenet_root_dir``, statistics of samples generated for that class are
    compared with the stored per-class statistics, using the TF backend
    (``eval_tf``) and/or the PyTorch backend (``eval_torch``). Each process
    samples ``num_inception_images // world_size`` images. On the main
    process the collected distances are saved as ``.npz`` files under
    ``myargs.args.outdir``.
    """
    # register all class of ImageNet for dataloader
    # from template_lib.d2.data.build_ImageNet_per_class import ImageNetDatasetPerClassMapper

    imagenet_root_dir = cfg.start.imagenet_root_dir
    model_path = cfg.start.model_path
    model_file = cfg.start.model_file
    use_last_checkpoint = cfg.start.use_last_checkpoint
    eval_tf = cfg.start.eval_tf
    tf_fid_stats_dir = cfg.start.tf_fid_stats_dir
    num_inception_images = (cfg.start.num_inception_images
                            // comm.get_world_size())
    intra_FID_tfs_file = cfg.start.intra_FID_tfs_file
    eval_torch = cfg.start.eval_torch
    torch_fid_stats_dir = cfg.start.torch_fid_stats_dir
    intra_FID_torchs_file = cfg.start.intra_FID_torchs_file

    intra_FID_tfs_file = os.path.join(myargs.args.outdir, intra_FID_tfs_file)
    intra_FID_torchs_file = os.path.join(myargs.args.outdir,
                                         intra_FID_torchs_file)

    model = build_trainer(cfg, myargs=myargs)

    checkpointer = DetectionCheckpointer(model, cfg.OUTPUT_DIR)
    if use_last_checkpoint:
        model_path = _get_last_checkpoint_file(model_dir=model_path)
    else:
        model_path = os.path.join(model_path, model_file)

    # Called for its side effect of loading weights into `model`. The returned
    # extra state (e.g. "iteration") is not needed here.
    checkpointer.resume_or_load(model_path, resume=args.resume)

    model.evaluate_model(iteration=0, fixed_arc=model.fixed_arc)

    _, class_to_idx = find_classes(imagenet_root_dir)
    intra_FID_tfs = []
    intra_FID_torchs = []
    for _class_dir, idx in tqdm.tqdm(
            class_to_idx.items(),
            desc=f"compute intra FID {myargs.args.time_str_suffix}",
            file=myargs.stdout):
        if eval_tf:
            _intra_fid_for_class(model=model,
                                 fid_metric=model.FID_IS_tf,
                                 fid_stats_dir=tf_fid_stats_dir,
                                 idx=idx,
                                 num_inception_images=num_inception_images,
                                 results=intra_FID_tfs,
                                 metric_name='intra_FID_tf',
                                 prefix='intraFIDtf',
                                 myargs=myargs)
        if eval_torch:
            _intra_fid_for_class(model=model,
                                 fid_metric=model.FID_IS_pytorch,
                                 fid_stats_dir=torch_fid_stats_dir,
                                 idx=idx,
                                 num_inception_images=num_inception_images,
                                 results=intra_FID_torchs,
                                 metric_name='intra_FID_torch',
                                 prefix='intraFIDtorch',
                                 myargs=myargs)

    # Lists are only non-empty on the main process, so only it saves files.
    if intra_FID_tfs:
        np.savez(intra_FID_tfs_file, intra_FID_tfs=np.array(intra_FID_tfs))
    if intra_FID_torchs:
        np.savez(intra_FID_torchs_file,
                 intra_FID_torchs=np.array(intra_FID_torchs))
    comm.synchronize()
예제 #7
0
def main(myargs):
    """Configure and run maskrcnn_benchmark training, then optionally test.

    The argument parser is never fed the real command line: it is parsed with
    an empty argv and then overridden from ``myargs.config`` via
    ``config2args``. ``cfg.OUTPUT_DIR`` is set to
    ``<myargs.args.outdir>/maskrcnn_benchmark`` before the config file and
    ``opts`` (plus ``opts_private`` if present) are merged. Distributed mode
    is enabled when the ``WORLD_SIZE`` environment variable is greater than 1.
    A final ModelArts sync (join=True, end=True) runs at the end.
    """
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument("--config-file", default="", metavar="FILE",
                        help="path to config file", type=str)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--skip-test", dest="skip_test",
                        help="Do not test the final model",
                        action="store_true")
    parser.add_argument("opts",
                        help="Modify config options using the command-line",
                        default=None, nargs=argparse.REMAINDER)

    args = config2args(myargs.config, parser.parse_args([]))
    cfg.OUTPUT_DIR = os.path.join(myargs.args.outdir, 'maskrcnn_benchmark')

    world_size = os.environ.get("WORLD_SIZE")
    num_gpus = 1 if world_size is None else int(world_size)
    args.distributed = num_gpus > 1

    if args.distributed:
        myargs = setup_myargs_for_multiple_processing(myargs)
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    if 'opts_private' in args:
        cfg.merge_from_list(args.opts_private)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    if comm.is_main_process():
        # Some scripts may expect config.yaml to exist in the output directory.
        dumped_cfg_path = os.path.join(output_dir, "config.yaml")
        with open(dumped_cfg_path, "w") as out_fh:
            out_fh.write(cfg.dump())
        logger.info("Full config saved to {}".format(
            os.path.abspath(dumped_cfg_path)))

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as config_fh:
        logger.info("\n" + config_fh.read())
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed, myargs=myargs)
    if not args.skip_test:
        test(cfg, model, args.distributed)

    modelarts_sync_results(args=myargs.args, myargs=myargs, join=True, end=True)
예제 #8
0
def train(cfg, args, myargs):
    """Train the model built by ``build_trainer`` on ``cfg.start.dataset_name``.

    Options are read from ``cfg.start`` and copied into ``cfg`` (training
    dataset, batch size, aspect-ratio grouping, worker count). The run lasts
    ``(num_images // IMS_PER_BATCH) * max_epoch`` iterations, where
    ``num_images`` comes from the dataset's metadata. Checkpoints are written
    by a ``PeriodicCheckpointer``. Results are synced to ModelArts before and
    after training.

    ``cfg.start.checkpoint_period`` may be a number, or a string expression
    evaluated with ``iter_every_epoch`` in scope (e.g.
    ``"iter_every_epoch * 5"``).
    """
    dataset_name = cfg.start.dataset_name
    IMS_PER_BATCH = cfg.start.IMS_PER_BATCH
    max_epoch = cfg.start.max_epoch
    ASPECT_RATIO_GROUPING = cfg.start.ASPECT_RATIO_GROUPING
    NUM_WORKERS = cfg.start.NUM_WORKERS
    checkpoint_period = cfg.start.checkpoint_period

    cfg.defrost()
    cfg.DATASETS.TRAIN = (dataset_name, )
    cfg.SOLVER.IMS_PER_BATCH = IMS_PER_BATCH
    cfg.DATALOADER.ASPECT_RATIO_GROUPING = ASPECT_RATIO_GROUPING
    cfg.DATALOADER.NUM_WORKERS = NUM_WORKERS
    cfg.freeze()

    # build dataset
    mapper = build_dataset_mapper(cfg)
    data_loader = build_detection_train_loader(cfg, mapper=mapper)
    metadata = MetadataCatalog.get(dataset_name)
    num_images = metadata.get('num_images')
    iter_every_epoch = num_images // IMS_PER_BATCH
    max_iter = iter_every_epoch * max_epoch

    model = build_trainer(cfg,
                          myargs=myargs,
                          iter_every_epoch=iter_every_epoch)
    model.train()

    logger.info("Model:\n{}".format(model))

    optims_dict = model.build_optimizer()

    checkpointer = DetectionCheckpointer(model.get_saved_model(),
                                         cfg.OUTPUT_DIR, **optims_dict)
    start_iter = (checkpointer.resume_or_load(
        cfg.MODEL.WEIGHTS, resume=args.resume).get("iteration", -1) + 1)

    # eval() only accepts str/bytes, so a plain numeric period from the config
    # used to raise TypeError. Numbers are now used as-is.
    if isinstance(checkpoint_period, (str, bytes)):
        # NOTE(security): eval() executes arbitrary code from the config file;
        # only run trusted configs.
        checkpoint_period = eval(checkpoint_period,
                                 dict(iter_every_epoch=iter_every_epoch))
    periodic_checkpointer = PeriodicCheckpointer(checkpointer,
                                                 checkpoint_period,
                                                 max_iter=max_iter)

    logger.info("Starting training from iteration {}".format(start_iter))
    modelarts_utils.modelarts_sync_results(args=myargs.args,
                                           myargs=myargs,
                                           join=True,
                                           end=False)
    with EventStorage(start_iter) as storage:
        # zip() stops after max_iter - start_iter batches.
        pbar = zip(data_loader, range(start_iter, max_iter))
        if comm.is_main_process():
            pbar = tqdm.tqdm(
                pbar,
                desc=f'train, {myargs.args.time_str_suffix}, '
                f'iters {iter_every_epoch} * bs {IMS_PER_BATCH} = imgs {iter_every_epoch*IMS_PER_BATCH}',
                file=myargs.stdout,
                initial=start_iter,
                total=max_iter)

        for data, step in pbar:
            comm.synchronize()
            storage.step()

            # train_func receives the 0-based step; the periodic checkpointer
            # is stepped with the 1-based count.
            model.train_func(data, step, pbar=pbar)
            periodic_checkpointer.step(step + 1)
    modelarts_utils.modelarts_sync_results(args=myargs.args,
                                           myargs=myargs,
                                           join=True,
                                           end=True)
    comm.synchronize()