Пример #1
0
def score(args, trainer, task, epoch_itr, subset):
    """Generate translations for *subset* and return the corpus BLEU score.

    Runs beam-search generation with the trainer's current model over the
    evaluation subset, detokenizes references and hypotheses, and computes
    BLEU. Also emits MLPerf EVAL_START/EVAL_STOP log events keyed on the
    current epoch number.

    Args:
        args: parsed command-line namespace (beam, lenpen, nbest, ...).
        trainer: trainer object providing ``get_model()``.
        task: fairseq task providing datasets and dictionaries.
        epoch_itr: epoch iterator; only ``epoch_itr.epoch`` is read here.
        subset: name of the dataset split to score.

    Returns:
        float: BLEU score of the top hypothesis against the references.
    """
    log_start(key=constants.EVAL_START,
              metadata={'epoch_num': epoch_itr.epoch},
              sync=False)
    begin = time.time()

    if subset not in task.datasets:
        task.load_dataset(subset)

    # Deep copies are necessary: generation of translations alters the
    # target dictionary, messing up the rest of training.
    src_dict = deepcopy(task.source_dictionary)
    tgt_dict = deepcopy(task.target_dictionary)

    model = trainer.get_model()

    # Initialize data iterator
    itr = data.EpochBatchIterator(
        dataset=task.dataset(subset),
        max_tokens=min(2560, args.max_tokens),
        max_sentences=max(
            8,
            min((math.ceil(1024 / args.distributed_world_size) // 4) * 4,
                128)),
        max_positions=(256, 256),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        seq_len_multiple=args.seq_len_multiple,
        # Use a large growth factor to get fewer buckets.
        # Fewer buckets yield faster eval since batches are filled from single bucket
        # and eval dataset is small.
        bucket_growth_factor=1.2,
        batching_scheme=args.batching_scheme,
        batch_multiple_strategy=args.batch_multiple_strategy,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    translator = SequenceGenerator(
        [model],
        tgt_dict,
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        sampling=args.sampling,
        sampling_topk=args.sampling_topk,
        minlen=args.min_len,
    )
    # Generate and compute BLEU
    ref_toks = []
    sys_toks = []
    num_sentences = 0
    has_target = True
    if args.log_translations:
        # One translation log per rank per epoch.
        log = open(
            os.path.join(
                args.save_dir,
                'translations_epoch{}_{}'.format(epoch_itr.epoch,
                                                 args.distributed_rank)), 'w+')
    with progress_bar.build_progress_bar(args, itr) as progress:
        translations = translator.generate_batched_itr(
            progress,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=True,
            timer=gen_timer,
            prefix_size=args.prefix_size,
        )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            src_str = src_dict.string(src_tokens, args.remove_bpe)
            if has_target:
                target_str = tgt_dict.string(target_tokens, args.remove_bpe)

            if args.log_translations:
                log.write('S-{}\t{}\n'.format(sample_id, src_str))
                if has_target:
                    log.write('T-{}\t{}\n'.format(sample_id, target_str))

            # Process top predictions (slicing clamps, no min() needed)
            for i, hypo in enumerate(hypos[:args.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu()
                    if hypo['alignment'] is not None else None,
                    align_dict=None,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe)
                if args.log_translations:
                    log.write('H-{}\t{}\t{}\n'.format(sample_id, hypo['score'],
                                                      hypo_str))
                    log.write('P-{}\t{}\n'.format(
                        sample_id, ' '.join(
                            map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))))

                # Score only the top hypothesis
                if has_target and i == 0:
                    src_str = detokenize_subtokenized_sentence(src_str)
                    target_str = detokenize_subtokenized_sentence(target_str)
                    hypo_str = detokenize_subtokenized_sentence(hypo_str)
                    sys_tok = bleu_tokenize(
                        (hypo_str.lower() if args.ignore_case else hypo_str))
                    ref_tok = bleu_tokenize((target_str.lower() if
                                             args.ignore_case else target_str))
                    sys_toks.append(sys_tok)
                    ref_toks.append(ref_tok)

            wps_meter.update(src_tokens.size(0))
            progress.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    bleu_score_reference = compute_bleu(ref_toks, sys_toks, args)
    bleu_score_reference_str = '{:.4f}'.format(bleu_score_reference)
    if args.log_translations:
        log.close()
    if gen_timer.sum != 0:
        print(
            '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
            .format(num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: bleu_score={}'.format(
            subset, args.beam, bleu_score_reference_str))
    print('| Eval completed in: {:.2f}s'.format(time.time() - begin))
    log_end(key=constants.EVAL_STOP,
            metadata={'epoch_num': epoch_itr.epoch},
            sync=False)

    return bleu_score_reference
Пример #2
0
def main(args):
    """Top-level training entry point for the MLPerf transformer benchmark.

    Sets up distributed communication, seeds, the task/model/criterion and
    trainer, then runs the train/eval loop until the learning rate decays
    below ``args.min_lr``, the update/epoch budget is exhausted, or the
    target BLEU is reached. Emits MLPerf compliance log events throughout.

    Args:
        args: fully-populated command-line namespace.

    Raises:
        NotImplementedError: if CUDA is unavailable (CPU training unsupported).
    """
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)

    mllog.config(filename=os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'transformer.log'))
    mllogger = mllog.get_mllogger()
    mllogger.logger.propagate = False

    log_start(key=constants.INIT_START, log_all_ranks=True)

    # preinit and warmup streams/groups for allreduce communicators
    allreduce_communicators = None
    if args.distributed_world_size > 1 and args.enable_parallel_backward_allred_opt:
        allreduce_groups = [
            torch.distributed.new_group()
            for _ in range(args.parallel_backward_allred_cuda_nstreams)
        ]
        allreduce_streams = [
            torch.cuda.Stream()
            for _ in range(args.parallel_backward_allred_cuda_nstreams)
        ]
        for group, stream in zip(allreduce_groups, allreduce_streams):
            with torch.cuda.stream(stream):
                torch.distributed.all_reduce(torch.cuda.FloatTensor(1),
                                             group=group)
        allreduce_communicators = (allreduce_groups, allreduce_streams)

    if args.max_tokens is None:
        args.max_tokens = 6000

    print(args)

    log_event(key=constants.GLOBAL_BATCH_SIZE,
              value=args.max_tokens * args.distributed_world_size)
    log_event(key=constants.OPT_NAME, value=args.optimizer)
    assert (len(args.lr) == 1)
    log_event(key=constants.OPT_BASE_LR,
              value=args.lr[0] if len(args.lr) == 1 else args.lr)
    log_event(key=constants.OPT_LR_WARMUP_STEPS, value=args.warmup_updates)
    assert (args.max_source_positions == args.max_target_positions)
    log_event(key=constants.MAX_SEQUENCE_LENGTH,
              value=args.max_target_positions,
              metadata={'method': 'discard'})
    # SECURITY NOTE: eval() on a CLI string; args.adam_betas is expected to be
    # a literal like "(0.9, 0.98)". Consider ast.literal_eval for safety.
    log_event(key=constants.OPT_ADAM_BETA_1, value=eval(args.adam_betas)[0])
    log_event(key=constants.OPT_ADAM_BETA_2, value=eval(args.adam_betas)[1])
    log_event(key=constants.OPT_ADAM_EPSILON, value=args.adam_eps)
    log_event(key=constants.SEED, value=args.seed)

    # L2 Sector Promotion
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    result = ctypes.CDLL('libcudart.so').cudaDeviceSetLimit(
        ctypes.c_int(0x05), ctypes.c_int(128))
    result = ctypes.CDLL('libcudart.so').cudaDeviceGetLimit(
        pValue, ctypes.c_int(0x05))

    worker_seeds, shuffling_seeds = setup_seeds(
        args.seed,
        args.max_epoch + 1,
        torch.device('cuda'),
        args.distributed_rank,
        args.distributed_world_size,
    )
    worker_seed = worker_seeds[args.distributed_rank]
    print(
        f'Worker {args.distributed_rank} is using worker seed: {worker_seed}')
    torch.manual_seed(worker_seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)

    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        if args.distributed_weight_update != 0:
            from fairseq.fp16_trainer import DistributedFP16Trainer
            trainer = DistributedFP16Trainer(
                args,
                task,
                model,
                criterion,
                allreduce_communicators=allreduce_communicators)
        else:
            from fairseq.fp16_trainer import FP16Trainer
            trainer = FP16Trainer(
                args,
                task,
                model,
                criterion,
                allreduce_communicators=allreduce_communicators)
    else:
        if torch.cuda.get_device_capability(0)[0] >= 7:
            print(
                '| NOTICE: your device may support faster training with --fp16'
            )

        trainer = Trainer(args,
                          task,
                          model,
                          criterion,
                          allreduce_communicators=None)

    #if (args.online_eval or args.target_bleu) and not args.remove_bpe:
    #    args.remove_bpe='@@ '

    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()

    # Send a dummy batch to warm the caching allocator
    dummy_batch = language_pair_dataset.get_dummy_batch_isolated(
        args.max_tokens, max_positions, 8)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small or model reaches target score
    max_epoch = args.max_epoch if args.max_epoch >= 0 else math.inf
    max_update = args.max_update or math.inf
    tgt_bleu = args.target_bleu or math.inf
    current_bleu = 0.0
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]

    # mlperf compliance synchronization
    if args.distributed_world_size > 1:
        assert (torch.distributed.is_initialized())
        torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
        torch.cuda.synchronize()

    log_end(key=constants.INIT_STOP, sync=False)

    log_start(key=constants.RUN_START, sync=True)
    # second sync after RUN_START tag is printed.
    # this ensures no rank touches data until after RUN_START tag is printed.
    barrier()

    # Load dataset splits
    load_dataset_splits(task, ['train', 'test'])

    log_event(key=constants.TRAIN_SAMPLES,
              value=len(task.dataset(args.train_subset)),
              sync=False)
    log_event(key=constants.EVAL_SAMPLES,
              value=len(task.dataset(args.gen_subset)),
              sync=False)

    ctr = 0

    start = time.time()
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_pin_memory=args.enable_dataloader_pin_memory,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seeds=shuffling_seeds,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        # ctr is always 0 here, so this resolves to epoch 0; use == (not
        # `is`) for int comparison — `is` on ints is implementation-defined.
        epoch=epoch_itr.epoch if ctr != 0 else 0,
        bucket_growth_factor=args.bucket_growth_factor,
        seq_len_multiple=args.seq_len_multiple,
        batching_scheme=args.batching_scheme,
        batch_multiple_strategy=args.batch_multiple_strategy,
    )
    print("got epoch iterator", time.time() - start)

    # Main training loop
    while lr >= args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates(
    ) < max_update and current_bleu < tgt_bleu:
        first_epoch = epoch_itr.epoch + 1
        log_start(key=constants.BLOCK_START,
                  metadata={
                      'first_epoch_num': first_epoch,
                      'epoch_count': 1
                  },
                  sync=False)
        log_start(key=constants.EPOCH_START,
                  metadata={'epoch_num': first_epoch},
                  sync=False)

        gc.disable()

        # Load the latest checkpoint if one is available (first epoch only)
        if ctr == 0:
            load_checkpoint(args, trainer, epoch_itr)

        # train for one epoch
        start = time.time()
        train(args, trainer, task, epoch_itr, shuffling_seeds)
        print("epoch time ", time.time() - start)

        start = time.time()
        log_end(key=constants.EPOCH_STOP,
                metadata={'epoch_num': first_epoch},
                sync=False)

        # Eval BLEU score
        if args.online_eval or (tgt_bleu is not math.inf):
            current_bleu = score(args, trainer, task, epoch_itr,
                                 args.gen_subset)
            log_event(key=constants.EVAL_ACCURACY,
                      value=float(current_bleu) / 100.0,
                      metadata={'epoch_num': first_epoch})

        gc.enable()

        # Only use first validation loss to update the learning rate
        #lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Save checkpoint
        #if epoch_itr.epoch % args.save_interval == 0:
        #    save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        ctr = ctr + 1
        print("validation and scoring ", time.time() - start)
        log_end(key=constants.BLOCK_STOP,
                metadata={'first_epoch_num': first_epoch},
                sync=False)

    train_meter.stop()
    status = 'success' if current_bleu >= tgt_bleu else 'aborted'
    log_end(key=constants.RUN_STOP, metadata={'status': status})
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Пример #3
0
 def log_callback(future):
     """Emit MLPerf EVAL_STOP and EVAL_ACCURACY events for this epoch.

     Invoked when the async evaluation *future* completes; the accuracy is
     the future's result rescaled from percent to a 0-1 fraction.
     """
     log_end(key=mlperf_constants.EVAL_STOP, metadata={'epoch_num': epoch})
     accuracy = future.result() / 100
     log_event(key=mlperf_constants.EVAL_ACCURACY, value=accuracy,
               metadata={'epoch_num': epoch})
Пример #4
0
    def validate(self,
                 val_iterator,
                 epoch=1,
                 annotation_file=None,
                 cocoapi_threads=1,
                 log_interval=None):
        """Test on validation dataset.

        Runs inference over *val_iterator*, gathers results across Horovod
        ranks, and scores them with the COCO API — either synchronously or
        via the async executor (in which case the score is logged from a
        completion callback and -1 is returned here).

        Args:
            val_iterator: validation data iterator passed to ``self.infer``.
            epoch: epoch number used in MLPerf log metadata.
            annotation_file: path to the COCO annotations JSON.
            cocoapi_threads: thread count for the COCO evaluation.
            log_interval: forwarded to ``self.infer`` for progress logging.

        Returns:
            mAP score (float), or -1 when scoring is deferred to the async
            executor.
        """
        log_start(key=mlperf_constants.EVAL_START,
                  metadata={'epoch_num': epoch})
        time_ticks = [time.time()]
        time_messages = []

        # save a copy of weights to temp dir
        if self.async_executor and self.save_prefix and hvd.rank() == 0:
            save_fname = os.path.join(tempfile.gettempdir(),
                                      f'temp_ssd_mxnet_epoch{epoch}.params')
            self.net.save_parameters(save_fname)
        time_ticks.append(time.time())
        time_messages.append('save_parameters')

        results = self.infer(data_iterator=val_iterator,
                             log_interval=log_interval)
        time_ticks.append(time.time())
        time_messages.append('inference')

        # all gather results from all ranks
        if hvd.size() > 1:
            results = self.allgather(results)
        time_ticks.append(time.time())
        time_messages.append('allgather')

        # convert to numpy (cocoapi doesn't take mxnet ndarray)
        results = results.asnumpy()
        time_ticks.append(time.time())
        time_messages.append('asnumpy')

        time_ticks = np.array(time_ticks)
        elapsed_time = time_ticks[1:] - time_ticks[:-1]  # fixed typo: elpased
        validation_log_msg = '[Validation] '
        for msg, t in zip(time_messages, elapsed_time):
            validation_log_msg += f'{msg}: {t*1000.0:.3f} [ms], '
        # TODO(ahmadki): val size is hard coded :(
        validation_log_msg += f'speed: {5000.0/(time_ticks[-1]-time_ticks[0]):.3f} [imgs/sec]'

        # TODO(ahmadki): remove time measurements
        logging.info(validation_log_msg)

        # Evaluate(score) results
        map_score = -1
        if self.async_executor:
            # Rank 0 submits the scoring job; the callback logs the result
            # once the future resolves. This method returns -1 meanwhile.
            if hvd.rank() == 0:
                self.async_executor.submit(tag=str(epoch),
                                           fn=coco_map_score,
                                           results=results,
                                           annotation_file=annotation_file,
                                           num_threads=cocoapi_threads)

                def log_callback(future):
                    log_end(key=mlperf_constants.EVAL_STOP,
                            metadata={'epoch_num': epoch})
                    log_event(key=mlperf_constants.EVAL_ACCURACY,
                              value=future.result() / 100,
                              metadata={'epoch_num': epoch})

                self.async_executor.add_done_callback(tag=str(epoch),
                                                      fn=log_callback)
        else:
            # Synchronous path: rank 0 scores, then broadcasts to all ranks.
            if hvd.rank() == 0:
                map_score = coco_map_score(results=results,
                                           annotation_file=annotation_file,
                                           num_threads=cocoapi_threads)
            map_score = comm.bcast(map_score, root=0)
            log_end(key=mlperf_constants.EVAL_STOP,
                    metadata={'epoch_num': epoch})
            log_event(key=mlperf_constants.EVAL_ACCURACY,
                      value=map_score / 100,
                      metadata={'epoch_num': epoch})
        return map_score
Пример #5
0
    def train_epoch(self,
                    data_iterator,
                    global_train_batch_size,
                    iterations_per_epoch,
                    epoch=1,
                    log_interval=None,
                    profile_start=None,
                    profile_stop=None):
        """Train the model for one epoch over *data_iterator*.

        Emits MLPerf EPOCH_START/EPOCH_STOP events, drives the learning-rate
        scheduler per iteration, and optionally brackets a CUDA profiling
        window between *profile_start* and *profile_stop* (global iteration
        numbers).

        Returns:
            0 on normal completion, 1 when stopping early because the
            profiling window ended, 2 when a NaN loss was detected.
        """
        # Global (across-epochs) iteration counter; epochs are 1-based.
        current_iter = (epoch - 1) * iterations_per_epoch + 1
        timing_iter_count = 0
        timing_iter_last_tick = time.time()
        epoch_start_time = time.time()
        log_start(key=mlperf_constants.EPOCH_START,
                  metadata={
                      'epoch_num': epoch,
                      'current_iter_num': 0
                  })  # FIXME(mfrank)
        self.metric.reset()  # Reset epoch metrics
        for i, (images, box_targets, cls_targets) in enumerate(data_iterator):
            if profile_start is not None and current_iter == profile_start:
                cu.cuda_profiler_start()

            if profile_stop is not None and current_iter >= profile_stop:
                if profile_start is not None and current_iter >= profile_start:
                    # we turned cuda profiling on, better turn it off too
                    cu.cuda_profiler_stop()
                return 1

            self.metric.reset_local()  # Reset iter metrics
            lr = self.lr_scheduler(current_epoch=epoch,
                                   current_iter=current_iter)
            self.trainer.set_learning_rate(lr)  # Set Learning rate
            sum_loss = self.train_step(images=images,
                                       cls_targets=cls_targets,
                                       box_targets=box_targets)
            self.trainer.step(1)
            # Move loss to CPU context before folding it into the metric.
            sum_loss = sum_loss.as_in_context(mx.cpu())
            self.metric.update(0, sum_loss)  # Update metric

            timing_iter_count = timing_iter_count + 1
            # Periodic logging: every log_interval iterations (when set).
            if log_interval and not current_iter % log_interval:
                name0, loss0 = self.metric.get()
                # waitall() so the timing below reflects completed async work.
                mx.nd.waitall()
                timing_tick = time.time()
                iter_time = timing_tick - timing_iter_last_tick
                iteration_prefix = (
                    f'[Training][Iteration {current_iter}][Epoch {epoch}, '
                    f'Batch {i+1}/{iterations_per_epoch}]')
                if hvd.rank() == 0:
                    logging.info((
                        f'{iteration_prefix} '
                        f'lr: {lr:.5f}, '
                        f'training time: {iter_time*1000.0/timing_iter_count:.3f} [ms], '
                        f'speed: {global_train_batch_size*timing_iter_count/iter_time:.3f} [imgs/sec], '
                        f'{name0}={loss0:.3f}'))
                # TODO(ahmadki): remove once NaN issues are solved
                # NOTE: NaN check runs on every rank, not just rank 0.
                if np.isnan(loss0):
                    logging.info(
                        f'{iteration_prefix} NaN detected in rank {hvd.rank()}. terminating.'
                    )
                    return 2
                timing_iter_count = 0
                timing_iter_last_tick = timing_tick
            current_iter = current_iter + 1

        name0, loss0 = self.metric.get_global()
        mx.nd.waitall()  # cpu has been launching async
        epoch_time = time.time() - epoch_start_time
        if log_interval and hvd.rank() == 0:
            logging.info((
                f'[Training][Epoch {epoch}] '
                f'training time: {epoch_time:.3f} [sec],'
                f'avg speed: {(i+1)*global_train_batch_size/epoch_time:.3f} [imgs/sec],'
                f'{name0}={loss0:.3f}'))

        log_end(key=mlperf_constants.EPOCH_STOP, metadata={'epoch_num': epoch})
        return 0
Пример #6
0
def main(async_executor=None):
    # Setup MLPerf logger
    mllog.config()
    mllogger = mllog.get_mllogger()
    mllogger.logger.propagate = False
    # Start MLPerf benchmark
    log_start(key=mlperf_constants.INIT_START, uniq=False)

    # Parse args
    args = parse_args()

    ############################################################################
    # Initialize various libraries (horovod, logger, amp ...)
    ############################################################################
    # Initialize async executor
    if args.async_val:
        assert async_executor is not None, 'Please use ssd_main_async.py to launch with async support'
    else:
        # (Force) disable async validation
        async_executor = None

    # Initialize horovod
    hvd.init()

    # Initialize AMP
    if args.precision == 'amp':
        amp.init(layout_optimization=True)

    # Set MXNET_SAFE_ACCUMULATION=1 if necessary
    if args.precision == 'fp16':
        os.environ["MXNET_SAFE_ACCUMULATION"] = "1"

    # Results folder
    network_name = f'ssd_{args.backbone}_{args.data_layout}_{args.dataset}_{args.data_shape}'
    save_prefix = None
    if args.results:
        save_prefix = os.path.join(args.results, network_name)
    else:
        logging.info(
            "No results folder was provided. The script will not write logs or save weight to disk"
        )

    # Initialize logger
    log_file = None
    if args.results:
        log_file = f'{save_prefix}_{args.mode}_{hvd.rank()}.log'
    setup_logger(level=args.log_level
                 if hvd.local_rank() in args.log_local_ranks else 'CRITICAL',
                 log_file=log_file)

    # Set seed
    args.seed = set_seed_distributed(args.seed)
    ############################################################################

    ############################################################################
    # Validate arguments and print some useful information
    ############################################################################
    logging.info(args)

    assert not (args.resume_from and args.pretrained_backbone), (
        "--resume-from and --pretrained_backbone are "
        "mutually exclusive.")
    assert args.data_shape == 300, "only data_shape=300 is supported at the moment."
    assert args.input_batch_multiplier >= 1, "input_batch_multiplier must be >= 1"
    assert not (hvd.size() == 1 and args.gradient_predivide_factor > 1), (
        "Gradient predivide factor is not supported "
        "with a single GPU")
    if args.data_layout == 'NCHW' or args.precision == 'fp32':
        assert args.bn_group == 1, "Group batch norm doesn't support FP32 data format or NCHW data layout."
        if not args.no_fuse_bn_relu:
            logging.warning((
                "WARNING: fused batch norm relu is only supported with NHWC layout. "
                "A non fused version will be forced."))
            args.no_fuse_bn_relu = True
        if not args.no_fuse_bn_add_relu:
            logging.warning((
                "WARNING: fused batch norm add relu is only supported with NHWC layout. "
                "A non fused version will be forced."))
            args.no_fuse_bn_add_relu = True
    if args.profile_no_horovod and hvd.size() > 1:
        logging.warning(
            "WARNING: hvd.size() > 1, so must IGNORE requested --profile-no-horovod"
        )
        args.profile_no_horovod = False

    logging.info(f'Seed: {args.seed}')
    logging.info(f'precision: {args.precision}')
    if args.precision == 'fp16':
        logging.info(f'loss scaling: {args.fp16_loss_scale}')
    logging.info(f'network name: {network_name}')
    logging.info(f'fuse bn relu: {not args.no_fuse_bn_relu}')
    logging.info(f'fuse bn add relu: {not args.no_fuse_bn_add_relu}')
    logging.info(f'bn group: {args.bn_group}')
    logging.info(f'bn all reduce fp16: {args.bn_fp16}')
    logging.info(f'MPI size: {hvd.size()}')
    logging.info(f'MPI global rank: {hvd.rank()}')
    logging.info(f'MPI local rank: {hvd.local_rank()}')
    logging.info(f'async validation: {args.async_val}')
    ############################################################################

    # TODO(ahmadki): load network and anchors based on args.backbone (JoC)
    # Load network
    net = ssd_300_resnet34_v1_mlperf_coco(
        pretrained_base=False,
        nms_overlap_thresh=args.nms_overlap_thresh,
        nms_topk=args.nms_topk,
        nms_valid_thresh=args.nms_valid_thresh,
        post_nms=args.post_nms,
        layout=args.data_layout,
        fuse_bn_add_relu=not args.no_fuse_bn_add_relu,
        fuse_bn_relu=not args.no_fuse_bn_relu,
        bn_fp16=args.bn_fp16,
        norm_kwargs={'bn_group': args.bn_group})

    # precomputed anchors
    anchors_np = mlperf_xywh_anchors(image_size=args.data_shape,
                                     clip=True,
                                     normalize=True)
    if args.test_anchors and hvd.rank() == 0:
        logging.info(f'Normalized anchors: {anchors_np}')

    # Training mode
    train_net = None
    train_pipeline = None
    trainer_fn = None
    lr_scheduler = None
    if args.mode in ['train', 'train_val']:
        # Training iterator
        num_cropping_iterations = 1
        if args.use_tfrecord:
            tfrecord_files = glob.glob(
                os.path.join(args.tfrecord_root, 'train.*.tfrecord'))
            index_files = glob.glob(
                os.path.join(args.tfrecord_root, 'train.*.idx'))
            tfrecords = [(tfrecod, index)
                         for tfrecod, index in zip(tfrecord_files, index_files)
                         ]
        train_pipeline = get_training_pipeline(
            coco_root=args.coco_root if not args.use_tfrecord else None,
            tfrecords=tfrecords if args.use_tfrecord else None,
            anchors=anchors_np,
            num_shards=hvd.size(),
            shard_id=hvd.rank(),
            device_id=hvd.local_rank(),
            batch_size=args.batch_size * args.input_batch_multiplier,
            dataset_size=args.dataset_size,
            data_layout=args.data_layout,
            data_shape=args.data_shape,
            num_cropping_iterations=num_cropping_iterations,
            num_workers=args.dali_workers,
            fp16=args.precision == 'fp16',
            input_jpg_decode=args.input_jpg_decode,
            hw_decoder_load=args.hw_decoder_load,
            decoder_cache_size=min(
                (100 * 1024 + hvd.size() - 1) // hvd.size(), 12 *
                1024) if args.input_jpg_decode == 'cache' else 0,
            seed=args.seed)
        log_event(key=mlperf_constants.TRAIN_SAMPLES,
                  value=train_pipeline.epoch_size)
        log_event(key=mlperf_constants.MAX_SAMPLES,
                  value=num_cropping_iterations)

        # Training network
        train_net = SSDMultiBoxLoss(net=net,
                                    local_batch_size=args.batch_size,
                                    bulk_last_wgrad=args.bulk_last_wgrad)

        # Trainer function. SSDModel expects a function that takes 1 parameter - HybridBlock
        trainer_fn = functools.partial(
            sgd_trainer,
            learning_rate=args.lr,
            weight_decay=args.weight_decay,
            momentum=args.momentum,
            precision=args.precision,
            fp16_loss_scale=args.fp16_loss_scale,
            gradient_predivide_factor=args.gradient_predivide_factor,
            num_groups=args.horovod_num_groups,
            profile_no_horovod=args.profile_no_horovod)

        # Learning rate scheduler
        lr_scheduler = MLPerfLearningRateScheduler(
            learning_rate=args.lr,
            decay_factor=args.lr_decay_factor,
            decay_epochs=args.lr_decay_epochs,
            warmup_factor=args.lr_warmup_factor,
            warmup_epochs=args.lr_warmup_epochs,
            epoch_size=train_pipeline.epoch_size,
            global_batch_size=args.batch_size * hvd.size())

    # Validation mode
    infer_net = None
    val_iterator = None
    if args.mode in ['infer', 'val', 'train_val']:
        # Validation iterator
        tfrecord_files = glob.glob(
            os.path.join(args.tfrecord_root, 'val.*.tfrecord'))
        index_files = glob.glob(os.path.join(args.tfrecord_root, 'val.*.idx'))
        tfrecords = [(tfrecod, index)
                     for tfrecod, index in zip(tfrecord_files, index_files)]
        val_pipeline = get_inference_pipeline(
            coco_root=args.coco_root if not args.use_tfrecord else None,
            tfrecords=tfrecords if args.use_tfrecord else None,
            num_shards=hvd.size(),
            shard_id=hvd.rank(),
            device_id=hvd.local_rank(),
            batch_size=args.eval_batch_size,
            dataset_size=args.eval_dataset_size,
            data_layout=args.data_layout,
            data_shape=args.data_shape,
            num_workers=args.dali_workers,
            fp16=args.precision == 'fp16')
        log_event(key=mlperf_constants.EVAL_SAMPLES,
                  value=val_pipeline.epoch_size)

        # Inference network
        infer_net = COCOInference(net=net,
                                  ltrb=False,
                                  scale_bboxes=True,
                                  score_threshold=0.0)

        # annotations file
        cocoapi_annotation_file = os.path.join(
            args.coco_root, 'annotations', 'bbox_only_instances_val2017.json')

    # Prepare model: bundle the network with its training and inference
    # machinery into a single wrapper pinned to this worker's local GPU.
    # infer_net stays None unless validation mode was configured above.
    # NOTE(review): train_net, trainer_fn, lr_scheduler, async_executor and
    # save_prefix are bound earlier in this function (not visible here) —
    # presumably None in modes that skip training; verify against that code.
    model = SSDModel(net=net,
                     anchors_np=anchors_np,
                     precision=args.precision,
                     fp16_loss_scale=args.fp16_loss_scale,
                     train_net=train_net,
                     trainer_fn=trainer_fn,
                     lr_scheduler=lr_scheduler,
                     metric=mx.metric.Loss(),
                     infer_net=infer_net,
                     async_executor=async_executor,
                     save_prefix=save_prefix,
                     ctx=mx.gpu(hvd.local_rank()))

    # Do training and validation runs on fake data.
    # This will set layers shape (needed before loading pre-trained backbone),
    # allocate tensors and cache the optimized graph.
    # Both dry runs use num_shards=1 / shard_id=0: each worker warms up
    # locally on its own GPU; no data sharding is involved.
    # Training dry run:
    logging.info('Running training dry runs')
    dummy_train_pipeline = get_training_pipeline(
        coco_root=None,
        # Placeholder shard names — coco_root/dataset_size are None, so this
        # pipeline presumably generates synthetic batches; confirm in
        # get_training_pipeline.
        tfrecords=[('dummy.tfrecord', 'dummy.idx')],
        anchors=anchors_np,
        num_shards=1,
        shard_id=0,
        device_id=hvd.local_rank(),
        batch_size=args.batch_size * args.input_batch_multiplier,
        dataset_size=None,
        data_layout=args.data_layout,
        data_shape=args.data_shape,
        num_workers=args.dali_workers,
        fp16=args.precision == 'fp16',
        seed=args.seed)
    dummy_train_iterator = get_training_iterator(pipeline=dummy_train_pipeline,
                                                 batch_size=args.batch_size)
    # One pass over the dummy iterator is enough to trace and cache the
    # training graph.
    for images, box_targets, cls_targets in dummy_train_iterator:
        model.train_step(images=images,
                         box_targets=box_targets,
                         cls_targets=cls_targets)
    # Freeing memory is disabled due to a bug in CUDA graphs
    # del dummy_train_pipeline
    # del dummy_train_iterator
    # Block until all asynchronously scheduled MXNet ops have completed.
    mx.ndarray.waitall()
    logging.info('Done')
    # Validation dry run:
    logging.info('Running inference dry runs')
    dummy_val_pipeline = get_inference_pipeline(
        coco_root=None,
        tfrecords=[('dummy.tfrecord', 'dummy.idx')],
        num_shards=1,
        shard_id=0,
        device_id=hvd.local_rank(),
        batch_size=args.eval_batch_size,
        dataset_size=None,
        data_layout=args.data_layout,
        data_shape=args.data_shape,
        num_workers=args.dali_workers,
        fp16=args.precision == 'fp16')
    dummy_val_iterator = get_inference_iterator(pipeline=dummy_val_pipeline)
    # log_interval=None: no progress logging during the warm-up inference.
    model.infer(data_iterator=dummy_val_iterator, log_interval=None)
    # Freeing memory is disabled due to a bug in CUDA graphs
    # del dummy_val_pipeline
    # del dummy_val_iterator
    mx.ndarray.waitall()
    logging.info('Done')

    # re-initialize the model as a precaution in case the dry runs changed
    # the parameters, and clear any gradients the dry runs accumulated.
    model.init_model(force_reinit=True)
    model.zero_grads()
    mx.ndarray.waitall()

    # load saved model or pretrained backbone;
    # a full checkpoint (--resume-from) takes precedence over the backbone.
    if args.resume_from:
        model.load_parameters(filename=args.resume_from)
    elif args.pretrained_backbone:
        model.load_pretrain_backbone(picklefile_name=args.pretrained_backbone)

    # broadcast parameters so every Horovod worker starts from identical
    # weights (presumably from rank 0 — confirm in model.broadcast_params).
    model.broadcast_params()
    mx.ndarray.waitall()

    # Optional debug dump of parameter statistics, printed on rank 0 only.
    if args.test_initialization and hvd.rank() == 0:
        model.print_params_stats(net)

    # MLPerf compliance: initialization phase ends here.
    log_end(key=mlperf_constants.INIT_STOP)

    # Main MLPerf loop (training+validation).
    # Barriers around RUN_START keep all workers aligned so the timed
    # benchmark region starts consistently across ranks.
    mpiwrapper.barrier()
    log_start(key=mlperf_constants.RUN_START)
    mpiwrapper.barrier()
    # Real data iterators (built only for the pipelines configured above;
    # either may be absent depending on args.mode).
    train_iterator = None
    val_iterator = None
    if train_pipeline:
        train_iterator = get_training_iterator(pipeline=train_pipeline,
                                               batch_size=args.batch_size,
                                               synthetic=args.synthetic)
    if val_pipeline:
        val_iterator = get_inference_iterator(pipeline=val_pipeline)
    # Returns the final mAP and the epoch at which training stopped.
    model_map, epoch = model.train_val(train_iterator=train_iterator,
                                       start_epoch=args.start_epoch,
                                       end_epoch=args.epochs,
                                       val_iterator=val_iterator,
                                       val_interval=args.val_interval,
                                       val_epochs=args.val_epochs,
                                       annotation_file=cocoapi_annotation_file,
                                       target_map=args.target_map,
                                       train_log_interval=args.log_interval,
                                       val_log_interval=args.log_interval,
                                       save_interval=args.save_interval,
                                       cocoapi_threads=args.cocoapi_threads,
                                       profile_start=args.profile_start,
                                       profile_stop=args.profile_stop)
    # MLPerf status: 'success' only if a mAP was produced and it reached
    # the target accuracy; anything else counts as 'aborted'.
    status = 'success' if (model_map
                           and model_map >= args.target_map) else 'aborted'
    mx.ndarray.waitall()
    log_end(key=mlperf_constants.RUN_STOP, metadata={"status": status})

    logging.info(f'Rank {hvd.rank()} done. map={model_map} @ epoch={epoch}')
    # mx.nd is an alias of mx.ndarray; this drains any remaining async ops
    # before tearing down Horovod.
    mx.nd.waitall()
    hvd.shutdown()