Esempio n. 1
0
def train300_mlperf_coco(args):
    """Train SSD300 on COCO with MLPerf logging and periodic evaluation.

    The function runs in three phases delimited by MLPerf log events:

    1. Init (ends at ``INIT_STOP``): build the training model, loss,
       apex ``FusedSGD`` optimizer and amp (``opt_level='O2'``, static loss
       scale), optionally wrap the model in DDP, build a separate eval copy
       of the model, optionally JIT-trace both, and run one dummy
       forward/backward pass on random input.
    2. Data touching (starts at ``RUN_START``): build the training and
       validation pipelines.
    3. Training: iterate ``args.epochs`` epochs. At every epoch listed in
       ``args.evaluation`` the current weights are copied into the eval model
       and handed to ``coco_eval`` together with the ``AsyncEvaluator``.

    Args:
        args: Parsed options namespace. ``setup_distributed`` is applied to it
            first and its return value is used from then on. Attributes read
            here include ``nhwc``, ``pad_input``, ``bn_group``,
            ``num_classes``, ``checkpoint``, ``N_gpu``, ``batch_size``,
            ``lr``, ``wd``, ``distributed``, ``delay_allreduce``,
            ``use_fp16``, ``lr_decay_epochs``, ``lr_decay_factor``,
            ``warmup``, ``warmup_factor``, ``iteration``, ``jit``, ``rank``,
            ``local_rank``, ``profile_gc_off``, ``evaluation``, ``epochs``,
            ``allreduce_running_stats``, ``save``, ``profile_start``,
            ``profile``, ``profile_nvtx``, ``profile_fake_optim``,
            ``print_interval`` and ``threshold``.

    Returns:
        ``True`` as soon as ``check_async_evals`` reports that it is finished
        (presumably: the target ``args.threshold`` was reached -- confirm
        against that helper), ``False`` when all epochs complete without
        that, and ``None`` when the run is cut short at iteration
        ``args.profile``.

    Note:
        The process is terminated through ``sys.exit()`` when the training
        loss becomes non-finite.
    """

    args = setup_distributed(args)

    # Build the model
    model_options = {
        'use_nhwc': args.nhwc,
        'pad_input': args.pad_input,
        'bn_group': args.bn_group,
    }

    ssd300 = SSD300(args, args.num_classes, **model_options)
    if args.checkpoint is not None:
        load_checkpoint(ssd300, args.checkpoint)

    ssd300.train()
    ssd300.cuda()
    # NOTE(review): dboxes is not used again in this function.
    dboxes = dboxes300_coco()
    # Note: No reason not to use optimised loss
    loss_func = OptLoss()
    loss_func.cuda()

    # Create optimizer.
    # NOTE(review): earlier wording referred to network_to_half; this version
    # does not call it and relies on apex.amp.initialize below instead.
    global_batch_size = (args.N_gpu * args.batch_size)
    log_event(key=constants.MODEL_BN_SPAN,
              value=args.bn_group * args.batch_size)
    log_event(key=constants.GLOBAL_BATCH_SIZE, value=global_batch_size)

    # mlperf only allows base_lr scaled by an integer
    # The requested lr is rescaled by global_batch_size / 32, rounded to the
    # nearest integer multiple of base_lr, and never allowed below 1x.
    base_lr = 2.5e-3
    requested_lr_multiplier = args.lr / base_lr
    adjusted_multiplier = max(
        1, round(requested_lr_multiplier * global_batch_size / 32))

    current_lr = base_lr * adjusted_multiplier
    current_momentum = 0.9
    current_weight_decay = args.wd
    static_loss_scale = 128.

    optim = apex.optimizers.FusedSGD(ssd300.parameters(),
                                     lr=current_lr,
                                     momentum=current_momentum,
                                     weight_decay=current_weight_decay)

    # Mixed precision is handled by amp with a fixed (static) loss scale.
    ssd300, optim = apex.amp.initialize(ssd300,
                                        optim,
                                        opt_level='O2',
                                        loss_scale=static_loss_scale)

    # Parallelize.  This happens after amp.initialize has processed the model.
    if args.distributed:
        if args.delay_allreduce:
            print_message(args.local_rank,
                          "Delaying allreduces to the end of backward()")
        ssd300 = DDP(ssd300,
                     gradient_predivide_factor=args.N_gpu / 8.0,
                     delay_allreduce=args.delay_allreduce,
                     retain_allreduce_buffers=args.use_fp16)

    # Record the optimizer hyper-parameters in the MLPerf log.
    log_event(key=constants.OPT_BASE_LR, value=current_lr)
    log_event(key=constants.OPT_LR_DECAY_BOUNDARY_EPOCHS,
              value=args.lr_decay_epochs)
    log_event(key=constants.OPT_LR_DECAY_STEPS, value=args.lr_decay_epochs)
    log_event(key=constants.OPT_WEIGHT_DECAY, value=current_weight_decay)
    if args.warmup is not None:
        log_event(key=constants.OPT_LR_WARMUP_STEPS, value=args.warmup)
        log_event(key=constants.OPT_LR_WARMUP_FACTOR, value=args.warmup_factor)

    # Model is completely finished -- need to create separate copies, preserve parameters across
    # them, and jit
    ssd300_eval = SSD300(args, args.num_classes, **model_options).cuda()

    if args.use_fp16:
        convert_network(ssd300_eval, torch.half)

    # Get the existing state from the train model
    # * if we use distributed, then we want .module
    train_model = ssd300.module if args.distributed else ssd300
    ssd300_eval.load_state_dict(train_model.state_dict())
    ssd300_eval.eval()

    print_message(args.local_rank, "epoch", "nbatch", "loss")

    # iter_num starts from args.iteration rather than from zero.
    iter_num = args.iteration
    # Exponential moving average of the loss, used for progress printing only.
    avg_loss = 0.0

    # Throughput bookkeeping; both are reset each time progress is printed.
    start_elapsed_time = time.time()
    last_printed_iter = args.iteration
    num_elapsed_samples = 0

    # Random input with the training batch shape, used for JIT tracing and for
    # the dummy forward/backward pass below.  Channels are last when
    # args.nhwc is set, and there are 4 channels when args.pad_input is set.
    input_c = 4 if args.pad_input else 3
    example_shape = [args.batch_size, 300, 300, input_c
                     ] if args.nhwc else [args.batch_size, input_c, 300, 300]
    example_input = torch.randn(*example_shape).cuda()

    if args.use_fp16:
        example_input = example_input.half()
    if args.jit:
        # DDP has some Python-side control flow.  If we JIT the entire DDP-wrapped module,
        # the resulting ScriptModule will elide this control flow, resulting in allreduce
        # hooks not being called.  If we're running distributed, we need to extract and JIT
        # the wrapped .module.
        # Replacing a DDP-ed ssd300 with a script_module might also cause the AccumulateGrad hooks
        # to go out of scope, and therefore silently disappear.
        module_to_jit = ssd300.module if args.distributed else ssd300
        if args.distributed:
            ssd300.module = torch.jit.trace(module_to_jit,
                                            example_input,
                                            check_trace=False)
        else:
            ssd300 = torch.jit.trace(module_to_jit,
                                     example_input,
                                     check_trace=False)
        # JIT the eval model too
        ssd300_eval = torch.jit.trace(ssd300_eval,
                                      example_input,
                                      check_trace=False)

    # do a dummy fprop & bprop to make sure cudnnFind etc. are timed here
    ploc, plabel = ssd300(example_input)

    # produce a single dummy "loss" to make things easier
    loss = ploc[0, 0, 0] + plabel[0, 0, 0]
    dloss = torch.randn_like(loss)
    # Cause cudnnFind for dgrad, wgrad to run
    loss.backward(dloss)

    # Necessary import in init
    # NOTE(review): COCO is not referenced again in this function; the import
    # appears to be done here only so that its cost falls inside init.
    from pycocotools.coco import COCO

    encoder = build_ssd300_coder()

    evaluator = AsyncEvaluator(num_threads=1)

    log_end(key=constants.INIT_STOP)

    ##### END INIT

    # This is the first place we touch anything related to data
    ##### START DATA TOUCHING
    barrier()
    log_start(key=constants.RUN_START)
    barrier()

    train_pipe = prebuild_pipeline(args)

    train_loader, epoch_size = build_pipeline(args,
                                              training=True,
                                              pipe=train_pipe)
    if args.rank == 0:
        print("epoch size is: ", epoch_size, " images")

    val_loader, inv_map, cocoGt = build_pipeline(args, training=False)
    if args.profile_gc_off:
        # Turn off automatic garbage collection and run one manual collection.
        gc.disable()
        gc.collect()

    ##### END DATA TOUCHING
    # i_eval indexes args.evaluation; block_start_epoch is the 1-based first
    # epoch of the MLPerf "block" that is currently open.
    i_eval = 0
    block_start_epoch = 1
    log_start(key=constants.BLOCK_START,
              metadata={
                  'first_epoch_num': block_start_epoch,
                  'epoch_count': args.evaluation[i_eval]
              })
    for epoch in range(args.epochs):
        # Drop any gradients left over from the dummy pass / previous epoch.
        for p in ssd300.parameters():
            p.grad = None

        # Evaluation happens at the start of the listed epochs, before any
        # training step of that epoch.
        if epoch in args.evaluation:
            # Get the existing state from the train model
            # * if we use distributed, then we want .module
            train_model = ssd300.module if args.distributed else ssd300

            if args.distributed and args.allreduce_running_stats:
                if args.rank == 0: print("averaging bn running means and vars")
                # make sure every node has the same running bn stats before
                # using them to evaluate, or saving the model for inference
                world_size = float(torch.distributed.get_world_size())
                for bn_name, bn_buf in train_model.named_buffers(recurse=True):
                    if ('running_mean' in bn_name) or ('running_var'
                                                       in bn_name):
                        torch.distributed.all_reduce(bn_buf,
                                                     op=dist.ReduceOp.SUM)
                        bn_buf /= world_size

            # Only rank 0 writes the checkpoint.
            if args.rank == 0:
                if args.save:
                    print("saving model...")
                    if not os.path.isdir('./models'):
                        os.mkdir('./models')
                    torch.save({"model": ssd300.state_dict()},
                               "./models/iter_{}.pt".format(iter_num))

            ssd300_eval.load_state_dict(train_model.state_dict())
            # Note: No longer returns, evaluation is abstracted away inside evaluator
            coco_eval(args,
                      ssd300_eval,
                      val_loader,
                      cocoGt,
                      encoder,
                      inv_map,
                      epoch,
                      iter_num,
                      evaluator=evaluator)
            log_end(key=constants.BLOCK_STOP,
                    metadata={'first_epoch_num': block_start_epoch})
            # Open the next block unless this was the last evaluation epoch.
            if epoch != max(args.evaluation):
                i_eval += 1
                block_start_epoch = epoch + 1
                log_start(key=constants.BLOCK_START,
                          metadata={
                              'first_epoch_num':
                              block_start_epoch,
                              'epoch_count': (args.evaluation[i_eval] -
                                              args.evaluation[i_eval - 1])
                          })

        # Step decay: multiply the lr by args.lr_decay_factor at each listed
        # epoch and push the new value into every param group.
        if epoch in args.lr_decay_epochs:
            current_lr *= args.lr_decay_factor
            print_message(
                args.rank,
                "lr decay step #" + str(bisect(args.lr_decay_epochs, epoch)))
            for param_group in optim.param_groups:
                param_group['lr'] = current_lr

        log_start(key=constants.EPOCH_START,
                  metadata={
                      'epoch_num': epoch + 1,
                      'current_iter_num': iter_num
                  })

        for i, (img, bbox, label) in enumerate(train_loader):

            # Start CUDA (and optionally NVTX) profiling at the requested
            # iteration.
            if args.profile_start is not None and iter_num == args.profile_start:
                torch.cuda.profiler.start()
                torch.cuda.synchronize()
                if args.profile_nvtx:
                    torch.autograd._enable_profiler(
                        torch.autograd.ProfilerState.NVTX)

            # Stop the whole run at iteration args.profile; this path returns
            # None rather than True/False.
            if args.profile is not None and iter_num == args.profile:
                if args.profile_start is not None and iter_num >= args.profile_start:
                    # we turned cuda and nvtx profiling on, better turn it off too
                    if args.profile_nvtx:
                        torch.autograd._disable_profiler()
                    torch.cuda.profiler.stop()
                return

            if args.warmup is not None:
                lr_warmup(optim, args.warmup, iter_num, epoch, current_lr,
                          args)

            # Skipped batches do not advance iter_num.
            if (img is None) or (bbox is None) or (label is None):
                print("No labels in batch")
                continue

            ploc, plabel = ssd300(img)
            # Cast the predictions to fp32 before computing the loss.
            ploc, plabel = ploc.float(), plabel.float()

            N = img.shape[0]
            bbox.requires_grad = False
            label.requires_grad = False
            # reshape (N*8732X4 -> Nx8732x4) and transpose (Nx8732x4 -> Nx4x8732)
            bbox = bbox.view(N, -1, 4).transpose(1, 2).contiguous()
            # reshape (N*8732 -> Nx8732) and cast to Long
            label = label.view(N, -1).long()
            loss = loss_func(ploc, plabel, bbox, label)

            # Abort the process on a diverged loss; otherwise update the
            # moving average used for progress printing.
            if np.isfinite(loss.item()):
                avg_loss = 0.999 * avg_loss + 0.001 * loss.item()
            else:
                print("model exploded (corrupted by Inf or Nan)")
                sys.exit()

            num_elapsed_samples += N
            if args.rank == 0 and iter_num % args.print_interval == 0:
                end_elapsed_time = time.time()
                elapsed_time = end_elapsed_time - start_elapsed_time

                # Local sample count scaled by the number of GPUs.
                avg_samples_per_sec = num_elapsed_samples * args.N_gpu / elapsed_time

                print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}, avg. samples / sec: {:.2f}"\
                            .format(iter_num, loss.item(), avg_loss, avg_samples_per_sec), end="\n")

                last_printed_iter = iter_num
                start_elapsed_time = time.time()
                num_elapsed_samples = 0

            # Backward pass on the amp-scaled loss.
            with apex.amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()

            # args.profile_fake_optim skips the optimizer step entirely.
            if not args.profile_fake_optim:
                optim.step()

            # Likely a decent skew here, let's take this opportunity to set the
            # gradients to None.  After DALI integration, playing with the
            # placement of this is worth trying.
            for p in ssd300.parameters():
                p.grad = None

            # Don't check every iteration due to cost of broadcast
            if iter_num % 20 == 0:
                finished = check_async_evals(args, evaluator, args.threshold)

                if finished:
                    return True

            iter_num += 1

        train_loader.reset()
        log_end(key=constants.EPOCH_STOP, metadata={'epoch_num': epoch + 1})

    return False
Esempio n. 2
0
def train300_mlperf_coco(args):
    """Train SSD300 on COCO with MLPerf logging and in-loop evaluation.

    The function sets up (optional) distributed training, builds the model,
    loss and SGD optimizer (with hand-managed flat fp32 master parameters
    when ``args.use_fp16`` is set), builds a separate eval copy of the model,
    optionally JIT-traces both, runs a dummy forward/backward pass, builds
    the data loaders and then trains for ``args.epochs`` epochs.  Evaluation
    runs synchronously through ``coco_eval`` whenever ``iter_num`` hits one
    of the evaluation points derived from ``args.evaluation``.

    Args:
        args: Parsed options namespace.  ``args.distributed`` is overwritten
            here from the ``WORLD_SIZE`` environment variable.  Attributes
            read include ``local_rank``, ``bn_group``, ``data``,
            ``backbone``, ``nhwc``, ``pad_input``, ``num_classes``,
            ``checkpoint``, ``opt_loss``, ``use_fp16``, ``delay_allreduce``,
            ``batch_size``, ``lr``, ``wd``, ``warmup``, ``warmup_factor``,
            ``evaluation``, ``iteration``, ``jit``, ``no_dali``,
            ``input_batch_multiplier``, ``num_workers``, ``use_nvjpeg``,
            ``use_roi_decode``, ``dali_cache``, ``dali_sync``,
            ``fake_input``, ``eval_batch_size``, ``epochs``,
            ``profile_start``, ``profile``, ``profile_nvtx``, ``decay1``,
            ``decay2``, ``print_interval``, ``allreduce_running_stats``,
            ``no_save``, ``threshold`` and ``max_iter``.

    Returns:
        ``True`` as soon as ``coco_eval`` returns a truthy value, ``False``
        when all epochs complete without that, and ``None`` when the run is
        cut short at iteration ``args.profile``.

    Note:
        The process is terminated through ``sys.exit()`` when the training
        loss becomes non-finite.
    """
    from pycocotools.coco import COCO

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    local_seed = set_seeds(args)
    # start timing here
    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    validate_group_bn(args.bn_group)
    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    input_size = 300
    val_trans = SSDTransformer(dboxes, (input_size, input_size), val=True)

    val_annotate = os.path.join(args.data,
                                "annotations/instances_val2017.json")
    val_coco_root = os.path.join(args.data, "val2017")
    train_annotate = os.path.join(args.data,
                                  "annotations/instances_train2017.json")
    train_coco_root = os.path.join(args.data, "train2017")

    # Build the model
    model_options = {
        'backbone': args.backbone,
        'use_nhwc': args.nhwc,
        'pad_input': args.pad_input,
        'bn_group': args.bn_group,
    }

    ssd300 = SSD300(args.num_classes, **model_options)
    if args.checkpoint is not None:
        load_checkpoint(ssd300, args.checkpoint)

    ssd300.train()
    ssd300.cuda()
    if args.opt_loss:
        loss_func = OptLoss(dboxes)
    else:
        loss_func = Loss(dboxes)
    loss_func.cuda()

    if args.use_fp16:
        ssd300 = network_to_half(ssd300)

    # Parallelize.  Need to do this after network_to_half.
    if args.distributed:
        if args.delay_allreduce:
            print_message(args.local_rank,
                          "Delaying allreduces to the end of backward()")
        ssd300 = DDP(ssd300,
                     gradient_predivide_factor=N_gpu / 8.0,
                     delay_allreduce=args.delay_allreduce,
                     retain_allreduce_buffers=args.use_fp16)

    # Create optimizer.  This must also be done after network_to_half.
    global_batch_size = (N_gpu * args.batch_size)
    mlperf_print(key=mlperf_compliance.constants.MODEL_BN_SPAN,
                 value=args.bn_group * args.batch_size)
    mlperf_print(key=mlperf_compliance.constants.GLOBAL_BATCH_SIZE,
                 value=global_batch_size)

    # mlperf only allows base_lr scaled by an integer
    # The requested lr is rescaled by global_batch_size / 32, rounded to the
    # nearest integer multiple of base_lr, and never allowed below 1x.
    base_lr = 2.5e-3
    requested_lr_multiplier = args.lr / base_lr
    adjusted_multiplier = max(
        1, round(requested_lr_multiplier * global_batch_size / 32))

    current_lr = base_lr * adjusted_multiplier
    current_momentum = 0.9
    current_weight_decay = args.wd
    static_loss_scale = 128.
    if args.use_fp16:
        if args.distributed and not args.delay_allreduce:
            # We can't create the flat master params yet, because we need to
            # imitate the flattened bucket structure that DDP produces.
            # The optimizer is created after the first backward pass instead.
            optimizer_created = False
        else:
            # One bucket per parameter dtype (half, then float).
            model_buckets = [
                [
                    p for p in ssd300.parameters()
                    if p.requires_grad and p.type() == "torch.cuda.HalfTensor"
                ],
                [
                    p for p in ssd300.parameters()
                    if p.requires_grad and p.type() == "torch.cuda.FloatTensor"
                ]
            ]
            flat_master_buckets = create_flat_master(model_buckets)
            optim = torch.optim.SGD(flat_master_buckets,
                                    lr=current_lr,
                                    momentum=current_momentum,
                                    weight_decay=current_weight_decay)
            optimizer_created = True
    else:
        optim = torch.optim.SGD(ssd300.parameters(),
                                lr=current_lr,
                                momentum=current_momentum,
                                weight_decay=current_weight_decay)
        optimizer_created = True

    mlperf_print(key=mlperf_compliance.constants.OPT_BASE_LR, value=current_lr)
    mlperf_print(key=mlperf_compliance.constants.OPT_WEIGHT_DECAY,
                 value=current_weight_decay)
    if args.warmup is not None:
        mlperf_print(key=mlperf_compliance.constants.OPT_LR_WARMUP_STEPS,
                     value=args.warmup)
        mlperf_print(key=mlperf_compliance.constants.OPT_LR_WARMUP_FACTOR,
                     value=args.warmup_factor)

    # Model is completely finished -- need to create separate copies, preserve parameters across
    # them, and jit
    ssd300_eval = SSD300(args.num_classes,
                         backbone=args.backbone,
                         use_nhwc=args.nhwc,
                         pad_input=args.pad_input).cuda()
    if args.use_fp16:
        ssd300_eval = network_to_half(ssd300_eval)

    # Get the existing state from the train model
    # * if we use distributed, then we want .module
    train_model = ssd300.module if args.distributed else ssd300

    ssd300_eval.load_state_dict(train_model.state_dict())

    ssd300_eval.eval()

    print_message(args.local_rank, "epoch", "nbatch", "loss")
    # args.evaluation is rescaled by 32 / global_batch_size and truncated to
    # integers to obtain the iteration numbers at which evaluation runs.
    eval_points = np.array(args.evaluation) * 32 / global_batch_size
    eval_points = list(map(int, list(eval_points)))

    # iter_num starts from args.iteration rather than from zero.
    iter_num = args.iteration
    # Exponential moving average of the loss, used for progress printing only.
    avg_loss = 0.0

    # Throughput bookkeeping; both are reset each time progress is printed.
    start_elapsed_time = time.time()
    num_elapsed_samples = 0

    # Generate normalization tensors
    # NOTE(review): mean/std are not used again in this function; the call is
    # kept in case generate_mean_std has side effects.
    mean, std = generate_mean_std(args)

    dummy_overflow_buf = torch.cuda.IntTensor([0])

    def step_maybe_fp16_maybe_distributed(optim):
        """Run one optimizer step, syncing fp32 masters with the model in fp16.

        In fp16 mode the (loss-scaled) model gradients are copied into the
        flat master parameters and unscaled, the optimizer steps on the
        masters, and the updated masters are copied back into the model
        parameters.  Without fp16 this is a plain ``optim.step()``.
        """
        if args.use_fp16:
            if args.distributed:
                # Gradients come from DDP's retained allreduce buffers.
                for flat_master, allreduce_buffer in zip(
                        flat_master_buckets, ssd300.allreduce_buffers):
                    if allreduce_buffer is None:
                        raise RuntimeError("allreduce_buffer is None")
                    flat_master.grad = allreduce_buffer.float()
                    flat_master.grad.data.mul_(1. / static_loss_scale)
            else:
                for flat_master, model_bucket in zip(flat_master_buckets,
                                                     model_buckets):
                    flat_grad = apex_C.flatten(
                        [m.grad.data for m in model_bucket])
                    flat_master.grad = flat_grad.float()
                    flat_master.grad.data.mul_(1. / static_loss_scale)
        optim.step()
        if args.use_fp16:
            # Use multi-tensor scale instead of loop & individual parameter copies
            for model_bucket, flat_master in zip(model_buckets,
                                                 flat_master_buckets):
                multi_tensor_applier(
                    amp_C.multi_tensor_scale, dummy_overflow_buf, [
                        apex_C.unflatten(flat_master.data, model_bucket),
                        model_bucket
                    ], 1.0)

    # Random input with the training batch shape, used for JIT tracing and for
    # the dummy forward/backward pass below.
    input_c = 4 if args.pad_input else 3
    example_shape = [args.batch_size, 300, 300, input_c
                     ] if args.nhwc else [args.batch_size, input_c, 300, 300]
    example_input = torch.randn(*example_shape).cuda()

    if args.use_fp16:
        example_input = example_input.half()
    if args.jit:
        # DDP has some Python-side control flow.  If we JIT the entire DDP-wrapped module,
        # the resulting ScriptModule will elide this control flow, resulting in allreduce
        # hooks not being called.  If we're running distributed, we need to extract and JIT
        # the wrapped .module.
        # Replacing a DDP-ed ssd300 with a script_module might also cause the AccumulateGrad hooks
        # to go out of scope, and therefore silently disappear.
        module_to_jit = ssd300.module if args.distributed else ssd300
        if args.distributed:
            ssd300.module = torch.jit.trace(module_to_jit, example_input)
        else:
            ssd300 = torch.jit.trace(module_to_jit, example_input)
        # JIT the eval model too
        ssd300_eval = torch.jit.trace(ssd300_eval, example_input)

    # do a dummy fprop & bprop to make sure cudnnFind etc. are timed here
    ploc, plabel = ssd300(example_input)

    # produce a single dummy "loss" to make things easier
    loss = ploc[0, 0, 0] + plabel[0, 0, 0]
    dloss = torch.randn_like(loss)
    # Cause cudnnFind for dgrad, wgrad to run
    loss.backward(dloss)

    mlperf_print(key=mlperf_compliance.constants.INIT_STOP, sync=True)
    ##### END INIT

    # This is the first place we touch anything related to data
    ##### START DATA TOUCHING
    mlperf_print(key=mlperf_compliance.constants.RUN_START, sync=True)
    barrier()
    cocoGt = COCO(annotation_file=val_annotate, use_ext=True)
    val_coco = COCODetection(val_coco_root, val_annotate, val_trans)

    if args.distributed:
        val_sampler = GeneralDistributedSampler(val_coco, pad=False)
    else:
        val_sampler = None

    if args.no_dali:
        train_trans = SSDTransformer(dboxes, (input_size, input_size),
                                     val=False)
        train_coco = COCODetection(train_coco_root, train_annotate,
                                   train_trans)

        if args.distributed:
            train_sampler = GeneralDistributedSampler(train_coco, pad=False)
        else:
            train_sampler = None

        train_loader = DataLoader(train_coco,
                                  batch_size=args.batch_size *
                                  args.input_batch_multiplier,
                                  shuffle=(train_sampler is None),
                                  sampler=train_sampler,
                                  num_workers=args.num_workers,
                                  collate_fn=partial(my_collate,
                                                     is_training=True))
        # There is no DALI pipeline on this path, so take the epoch size from
        # the dataset itself.  Previously the BLOCK_START logging below read
        # train_pipe unconditionally and failed with NameError here.
        # NOTE(review): assumes len(train_coco) matches what the DALI reader
        # would report for the same annotations -- confirm.
        train_epoch_size = len(train_coco)
    else:
        train_pipe = COCOPipeline(args.batch_size *
                                  args.input_batch_multiplier,
                                  args.local_rank,
                                  train_coco_root,
                                  train_annotate,
                                  N_gpu,
                                  num_threads=args.num_workers,
                                  output_fp16=args.use_fp16,
                                  output_nhwc=args.nhwc,
                                  pad_output=args.pad_input,
                                  seed=local_seed - 2**31,
                                  use_nvjpeg=args.use_nvjpeg,
                                  use_roi=args.use_roi_decode,
                                  dali_cache=args.dali_cache,
                                  dali_async=(not args.dali_sync))
        print_message(args.local_rank,
                      "time_check a: {secs:.9f}".format(secs=time.time()))
        train_pipe.build()
        print_message(args.local_rank,
                      "time_check b: {secs:.9f}".format(secs=time.time()))
        # Run the pipeline once before handing it to the iterator; the
        # outputs of this first run are discarded.
        train_pipe.run()
        train_epoch_size = train_pipe.epoch_size()['train_reader']
        train_loader = SingleDaliIterator(
            train_pipe, [
                'images',
                DALIOutput('bboxes', False, True),
                DALIOutput('labels', True, True)
            ],
            train_epoch_size,
            ngpu=N_gpu)

    train_loader = EncodingInputIterator(train_loader,
                                         dboxes=encoder.dboxes.cuda(),
                                         nhwc=args.nhwc,
                                         fake_input=args.fake_input,
                                         no_dali=args.no_dali)
    if args.input_batch_multiplier > 1:
        train_loader = RateMatcher(input_it=train_loader,
                                   output_size=args.batch_size)

    val_dataloader = DataLoader(
        val_coco,
        batch_size=args.eval_batch_size,
        shuffle=False,  # Note: distributed sampler is shuffled :(
        sampler=val_sampler,
        num_workers=args.num_workers)

    inv_map = {v: k for k, v in val_coco.label_map.items()}

    ##### END DATA TOUCHING
    # i_eval indexes args.evaluation; first_epoch is the 1-based first epoch
    # of the MLPerf "block" that is currently open.
    i_eval = 0
    first_epoch = 1
    mlperf_print(key=mlperf_compliance.constants.BLOCK_START,
                 metadata={
                     'first_epoch_num':
                     first_epoch,
                     'epoch_count':
                     args.evaluation[i_eval] * 32 / train_epoch_size
                 },
                 sync=True)
    for epoch in range(args.epochs):
        mlperf_print(key=mlperf_compliance.constants.EPOCH_START,
                     metadata={'epoch_num': epoch + 1},
                     sync=True)
        # Drop any gradients left over from the dummy pass / previous epoch.
        for p in ssd300.parameters():
            p.grad = None

        for img, bbox, label in train_loader:

            # Start CUDA (and optionally NVTX) profiling at the requested
            # iteration.
            if args.profile_start is not None and iter_num == args.profile_start:
                torch.cuda.profiler.start()
                torch.cuda.synchronize()
                if args.profile_nvtx:
                    torch.autograd._enable_profiler(
                        torch.autograd.ProfilerState.NVTX)

            # Stop the whole run at iteration args.profile; this path returns
            # None rather than True/False.
            if args.profile is not None and iter_num == args.profile:
                if args.profile_start is not None and iter_num >= args.profile_start:
                    # we turned cuda and nvtx profiling on, better turn it off too
                    if args.profile_nvtx:
                        torch.autograd._disable_profiler()
                    torch.cuda.profiler.stop()
                return

            if args.warmup is not None and optimizer_created:
                lr_warmup(optim, args.warmup, iter_num, epoch, current_lr,
                          args)
            # Two fixed 10x lr drops, at iteration counts derived from
            # args.decay1 / args.decay2 and the global batch size.
            if iter_num == ((args.decay1 * 1000 * 32) // global_batch_size):
                print_message(args.local_rank, "lr decay step #1")
                current_lr *= 0.1
                for param_group in optim.param_groups:
                    param_group['lr'] = current_lr

            if iter_num == ((args.decay2 * 1000 * 32) // global_batch_size):
                print_message(args.local_rank, "lr decay step #2")
                current_lr *= 0.1
                for param_group in optim.param_groups:
                    param_group['lr'] = current_lr

            # Skipped batches do not advance iter_num.
            if (img is None) or (bbox is None) or (label is None):
                print("No labels in batch")
                continue

            ploc, plabel = ssd300(img)
            # Cast the predictions to fp32 before computing the loss.
            ploc, plabel = ploc.float(), plabel.float()

            N = img.shape[0]
            gloc, glabel = Variable(bbox, requires_grad=False), \
                           Variable(label, requires_grad=False)
            loss = loss_func(ploc, plabel, gloc, glabel)

            # Abort the process on a diverged loss; otherwise update the
            # moving average used for progress printing.
            if np.isfinite(loss.item()):
                avg_loss = 0.999 * avg_loss + 0.001 * loss.item()
            else:
                print("model exploded (corrupted by Inf or Nan)")
                sys.exit()

            num_elapsed_samples += N
            if args.local_rank == 0 and iter_num % args.print_interval == 0:
                end_elapsed_time = time.time()
                elapsed_time = end_elapsed_time - start_elapsed_time

                # Local sample count scaled by the number of GPUs.
                avg_samples_per_sec = num_elapsed_samples * N_gpu / elapsed_time

                print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}, avg. samples / sec: {:.2f}"\
                            .format(iter_num, loss.item(), avg_loss, avg_samples_per_sec), end="\n")

                start_elapsed_time = time.time()
                num_elapsed_samples = 0

            # loss scaling
            if args.use_fp16:
                loss = loss * static_loss_scale
            loss.backward()

            if not optimizer_created:
                # Imitate the model bucket structure created by DDP.
                # These will already be split by type (float or half).
                model_buckets = []
                for bucket in ssd300.active_i_buckets:
                    model_buckets.append([])
                    for active_i in bucket:
                        model_buckets[-1].append(
                            ssd300.active_params[active_i])
                flat_master_buckets = create_flat_master(model_buckets)
                optim = torch.optim.SGD(flat_master_buckets,
                                        lr=current_lr,
                                        momentum=current_momentum,
                                        weight_decay=current_weight_decay)
                optimizer_created = True
                # Skip the optimizer step on this first iteration because the
                # flattened allreduce buffers are not yet created.
            else:
                step_maybe_fp16_maybe_distributed(optim)

            # Likely a decent skew here, let's take this opportunity to set the gradients to None.
            # After DALI integration, playing with the placement of this is worth trying.
            for p in ssd300.parameters():
                p.grad = None

            if iter_num in eval_points:
                # Get the existing state from the train model
                # * if we use distributed, then we want .module
                train_model = ssd300.module if args.distributed else ssd300

                if args.distributed and args.allreduce_running_stats:
                    if get_rank() == 0:
                        print("averaging bn running means and vars")
                    # make sure every node has the same running bn stats before
                    # using them to evaluate, or saving the model for inference
                    world_size = float(torch.distributed.get_world_size())
                    for bn_name, bn_buf in train_model.named_buffers(
                            recurse=True):
                        if ('running_mean' in bn_name) or ('running_var'
                                                           in bn_name):
                            torch.distributed.all_reduce(bn_buf,
                                                         op=dist.ReduceOp.SUM)
                            bn_buf /= world_size

                # Only rank 0 writes the checkpoint.
                if get_rank() == 0:
                    if not args.no_save:
                        print("saving model...")
                        # Make sure the target directory exists; torch.save
                        # does not create it.
                        os.makedirs('./models', exist_ok=True)
                        torch.save(
                            {
                                "model": ssd300.state_dict(),
                                "label_map": val_coco.label_info
                            }, "./models/iter_{}.pt".format(iter_num))

                ssd300_eval.load_state_dict(train_model.state_dict())
                succ = coco_eval(
                    ssd300_eval,
                    val_dataloader,
                    cocoGt,
                    encoder,
                    inv_map,
                    args.threshold,
                    epoch,
                    iter_num,
                    args.eval_batch_size,
                    use_fp16=args.use_fp16,
                    local_rank=args.local_rank if args.distributed else -1,
                    N_gpu=N_gpu,
                    use_nhwc=args.nhwc,
                    pad_input=args.pad_input)
                mlperf_print(key=mlperf_compliance.constants.BLOCK_STOP,
                             metadata={'first_epoch_num': first_epoch},
                             sync=True)
                if succ:
                    return True
                # Open the next block unless this was the last evaluation
                # point.
                if iter_num != max(eval_points):
                    i_eval += 1
                    first_epoch = epoch + 1
                    mlperf_print(key=mlperf_compliance.constants.BLOCK_START,
                                 metadata={
                                     'first_epoch_num':
                                     first_epoch,
                                     'epoch_count':
                                     (args.evaluation[i_eval] -
                                      args.evaluation[i_eval - 1]) * 32 /
                                     train_epoch_size
                                 },
                                 sync=True)
            iter_num += 1
            # NOTE(review): this only leaves the inner loop; the epoch loop
            # keeps going, so every remaining epoch still runs one iteration
            # before breaking again.  Confirm whether that is intended.
            if args.max_iter > 0:
                if iter_num > args.max_iter:
                    break

        train_loader.reset()
        mlperf_print(key=mlperf_compliance.constants.EPOCH_STOP,
                     metadata={'epoch_num': epoch + 1},
                     sync=True)
    return False
def train300_mlperf_coco(args):
    """Train SSD300 on COCO 2017, evaluating at the configured iterations.

    The function sets up (optionally distributed) execution, builds a DALI
    training pipeline and a PyTorch validation loader, constructs the model,
    loss and SGD optimizer (optionally fp16 with flat master parameters and
    optionally wrapped in LARC), and then runs the training loop.

    Args:
        args: Parsed command-line namespace.  Fields read here include
            ``no_cuda``, ``local_rank``, ``data``, ``batch_size``,
            ``eval_batch_size``, ``num_workers``, ``use_fp16``, ``nhwc``,
            ``pad_input``, ``backbone``, ``checkpoint``, ``delay_allreduce``,
            ``lr``, ``use_larc``, ``warmup``, ``jit``, ``evaluation``,
            ``iteration``, ``epochs``, ``profile``, ``decay1``, ``decay2``,
            ``print_interval``, ``no_save`` and ``threshold``.
            ``args.distributed`` is overwritten based on ``WORLD_SIZE``.

    Returns:
        ``True`` as soon as ``coco_eval`` returns a truthy value at one of
        the evaluation points (presumably meaning the target accuracy was
        reached -- confirm against ``coco_eval``), ``False`` if all epochs
        finish without that happening, and ``None`` if the profiling
        iteration ``args.profile`` is reached.
    """
    from coco import COCO

    # Check that GPUs are actually available
    use_cuda = not args.no_cuda

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    local_seed = set_seeds(args)
    # start timing here
    ssd_print(key=mlperf_log.RUN_START)
    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    input_size = 300
    val_trans = SSDTransformer(dboxes, (input_size, input_size), val=True)
    ssd_print(key=mlperf_log.INPUT_SIZE, value=input_size)

    val_annotate = os.path.join(args.data,
                                "annotations/instances_val2017.json")
    val_coco_root = os.path.join(args.data, "val2017")
    train_annotate = os.path.join(args.data,
                                  "annotations/instances_train2017.json")
    train_coco_root = os.path.join(args.data, "train2017")

    cocoGt = COCO(annotation_file=val_annotate)
    val_coco = COCODetection(val_coco_root, val_annotate, val_trans)

    if args.distributed:
        val_sampler = GeneralDistributedSampler(val_coco, pad=False)
    else:
        val_sampler = None

    train_pipe = COCOPipeline(args.batch_size,
                              args.local_rank,
                              train_coco_root,
                              train_annotate,
                              N_gpu,
                              num_threads=args.num_workers,
                              output_fp16=args.use_fp16,
                              output_nhwc=args.nhwc,
                              pad_output=args.pad_input,
                              seed=local_seed - 2**31)
    print_message(args.local_rank,
                  "time_check a: {secs:.9f}".format(secs=time.time()))
    train_pipe.build()
    print_message(args.local_rank,
                  "time_check b: {secs:.9f}".format(secs=time.time()))
    # NOTE(review): the result is never read; presumably this is a warm-up
    # run of the pipeline -- confirm before removing.
    test_run = train_pipe.run()
    # 118287 is passed as the total sample count, split across GPUs.
    # NOTE(review): this is a true division and yields a float -- confirm
    # that DALICOCOIterator accepts a non-integer size.
    train_loader = DALICOCOIterator(train_pipe, 118287 / N_gpu)

    val_dataloader = DataLoader(
        val_coco,
        batch_size=args.eval_batch_size,
        shuffle=False,  # Note: distributed sampler is shuffled :(
        sampler=val_sampler,
        num_workers=args.num_workers)
    ssd_print(key=mlperf_log.INPUT_ORDER)
    ssd_print(key=mlperf_log.INPUT_BATCH_SIZE, value=args.batch_size)
    # Build the model
    ssd300 = SSD300(val_coco.labelnum,
                    backbone=args.backbone,
                    use_nhwc=args.nhwc,
                    pad_input=args.pad_input)
    if args.checkpoint is not None:
        load_checkpoint(ssd300, args.checkpoint)

    ssd300.train()
    ssd300.cuda()
    loss_func = Loss(dboxes)
    loss_func.cuda()

    if args.use_fp16:
        ssd300 = network_to_half(ssd300)

    # Parallelize.  Need to do this after network_to_half.
    if args.distributed:
        if args.delay_allreduce:
            print_message(args.local_rank,
                          "Delaying allreduces to the end of backward()")
        ssd300 = DDP(ssd300,
                     delay_allreduce=args.delay_allreduce,
                     retain_allreduce_buffers=args.use_fp16)

    # Create optimizer.  This must also be done after network_to_half.
    global_batch_size = (N_gpu * args.batch_size)

    # mlperf only allows base_lr scaled by an integer
    base_lr = 1e-3
    requested_lr_multiplier = args.lr / base_lr
    adjusted_multiplier = max(
        1, round(requested_lr_multiplier * global_batch_size / 32))

    current_lr = base_lr * adjusted_multiplier
    current_momentum = 0.9
    current_weight_decay = 5e-4
    static_loss_scale = 128.
    if args.use_fp16:
        if args.distributed and not args.delay_allreduce:
            # We can't create the flat master params yet, because we need to
            # imitate the flattened bucket structure that DDP produces.
            optimizer_created = False
        else:
            model_buckets = [
                [
                    p for p in ssd300.parameters()
                    if p.requires_grad and p.type() == "torch.cuda.HalfTensor"
                ],
                [
                    p for p in ssd300.parameters()
                    if p.requires_grad and p.type() == "torch.cuda.FloatTensor"
                ]
            ]
            flat_master_buckets = create_flat_master(model_buckets)
            optim = torch.optim.SGD(flat_master_buckets,
                                    lr=current_lr,
                                    momentum=current_momentum,
                                    weight_decay=current_weight_decay)
            optimizer_created = True
    else:
        optim = torch.optim.SGD(ssd300.parameters(),
                                lr=current_lr,
                                momentum=current_momentum,
                                weight_decay=current_weight_decay)
        optimizer_created = True

    # Add LARC if desired.  When optimizer creation is deferred (fp16 +
    # distributed without delay_allreduce) there is nothing to wrap yet; the
    # wrap is applied where the optimizer is finally built in the loop below.
    if args.use_larc and optimizer_created:
        optim = LARC(optim)

    ssd_print(key=mlperf_log.OPT_NAME, value="SGD")
    ssd_print(key=mlperf_log.OPT_LR, value=current_lr)
    ssd_print(key=mlperf_log.OPT_MOMENTUM, value=current_momentum)
    ssd_print(key=mlperf_log.OPT_WEIGHT_DECAY, value=current_weight_decay)
    if args.warmup is not None:
        ssd_print(key=mlperf_log.OPT_LR_WARMUP_STEPS, value=args.warmup)

    # Model is completely finished -- need to create separate copies, preserve parameters across
    # them, and jit
    ssd300_eval = SSD300(val_coco.labelnum,
                         backbone=args.backbone,
                         use_nhwc=args.nhwc,
                         pad_input=args.pad_input).cuda()
    if args.use_fp16:
        ssd300_eval = network_to_half(ssd300_eval)

    # Get the existing state from the train model
    # * if we use distributed, then we want .module
    train_model = ssd300.module if args.distributed else ssd300

    ssd300_eval.load_state_dict(train_model.state_dict())

    ssd300_eval.eval()

    if args.jit:
        input_c = 4 if args.pad_input else 3
        example_shape = [
            args.batch_size, 300, 300, input_c
        ] if args.nhwc else [args.batch_size, input_c, 300, 300]
        example_input = torch.randn(*example_shape).cuda()
        if args.use_fp16:
            example_input = example_input.half()
        # DDP has some Python-side control flow.  If we JIT the entire DDP-wrapped module,
        # the resulting ScriptModule will elide this control flow, resulting in allreduce
        # hooks not being called.  If we're running distributed, we need to extract and JIT
        # the wrapped .module.
        # Replacing a DDP-ed ssd300 with a script_module might also cause the AccumulateGrad hooks
        # to go out of scope, and therefore silently disappear.
        module_to_jit = ssd300.module if args.distributed else ssd300
        if args.distributed:
            ssd300.module = torch.jit.trace(module_to_jit, example_input)
        else:
            ssd300 = torch.jit.trace(module_to_jit, example_input)

    print_message(args.local_rank, "epoch", "nbatch", "loss")
    # args.evaluation is rescaled by 32 / global_batch_size and truncated to
    # integer iteration numbers.
    eval_points = np.array(args.evaluation) * 32 / global_batch_size
    eval_points = list(map(int, list(eval_points)))

    iter_num = args.iteration
    avg_loss = 0.0
    inv_map = {v: k for k, v in val_coco.label_map.items()}

    start_elapsed_time = time.time()
    num_elapsed_samples = 0

    # Generate normalization tensors
    mean, std = generate_mean_std(args)

    def step_maybe_fp16_maybe_distributed(optim):
        """Run one optimizer step, handling the fp16 flat-master bookkeeping.

        In fp16 mode the gradients are first copied (as float, divided by
        ``static_loss_scale``) onto the flat master parameters -- from the
        DDP allreduce buffers when distributed, otherwise from the flattened
        model gradients.  After ``optim.step()`` the master values are copied
        back into the model parameters.  ``model_buckets`` and
        ``flat_master_buckets`` are read from the enclosing scope at call
        time, so they may be assigned after this function is defined.
        """
        if args.use_fp16:
            if args.distributed:
                for flat_master, allreduce_buffer in zip(
                        flat_master_buckets, ssd300.allreduce_buffers):
                    if allreduce_buffer is None:
                        raise RuntimeError("allreduce_buffer is None")
                    flat_master.grad = allreduce_buffer.float()
                    flat_master.grad.data.mul_(1. / static_loss_scale)
            else:
                for flat_master, model_bucket in zip(flat_master_buckets,
                                                     model_buckets):
                    flat_grad = apex_C.flatten(
                        [m.grad.data for m in model_bucket])
                    flat_master.grad = flat_grad.float()
                    flat_master.grad.data.mul_(1. / static_loss_scale)
        optim.step()
        if args.use_fp16:
            for model_bucket, flat_master in zip(model_buckets,
                                                 flat_master_buckets):
                for model, master in zip(
                        model_bucket,
                        apex_C.unflatten(flat_master.data, model_bucket)):
                    model.data.copy_(master.data)

    ssd_print(key=mlperf_log.TRAIN_LOOP)
    for epoch in range(args.epochs):
        ssd_print(key=mlperf_log.TRAIN_EPOCH, value=epoch)
        for p in ssd300.parameters():
            p.grad = None
        for i, data in enumerate(train_loader):
            img = data[0][0][0]
            bbox = data[0][1][0]
            label = data[0][2][0]
            label = label.type(torch.cuda.LongTensor)
            bbox_offsets = data[0][3][0]

            # handle random flipping outside of DALI for now
            bbox_offsets = bbox_offsets.cuda()
            img, bbox = C.random_horiz_flip(img, bbox, bbox_offsets, 0.5,
                                            args.nhwc)
            img.sub_(mean).div_(std)

            if args.profile is not None and iter_num == args.profile:
                return
            if args.warmup is not None and optimizer_created:
                lr_warmup(optim, args.warmup, iter_num, epoch, current_lr,
                          args)
            if iter_num == ((args.decay1 * 1000 * 32) // global_batch_size):
                print_message(args.local_rank, "lr decay step #1")
                current_lr *= 0.1
                # A deferred optimizer does not exist yet; it will be built
                # with the already-decayed current_lr.
                if optimizer_created:
                    for param_group in optim.param_groups:
                        param_group['lr'] = current_lr
                ssd_print(key=mlperf_log.OPT_LR, value=current_lr)

            if iter_num == ((args.decay2 * 1000 * 32) // global_batch_size):
                print_message(args.local_rank, "lr decay step #2")
                current_lr *= 0.1
                if optimizer_created:
                    for param_group in optim.param_groups:
                        param_group['lr'] = current_lr
                ssd_print(key=mlperf_log.OPT_LR, value=current_lr)

            if use_cuda:
                img = img.cuda()
                # NHWC direct from DALI now if necessary
                bbox = bbox.cuda()
                label = label.cuda()
                bbox_offsets = bbox_offsets.cuda()

            # Now run the batched box encoder
            N = img.shape[0]
            # A final offset of 0 is treated as "no boxes in this batch";
            # the batch is skipped without advancing iter_num.
            if bbox_offsets[-1].item() == 0:
                print("No labels in batch")
                continue
            bbox, label = C.box_encoder(N, bbox, bbox_offsets, label,
                                        encoder.dboxes.cuda(), 0.5)

            # output is ([N*8732, 4], [N*8732], need [N, 8732, 4], [N, 8732] respectively
            M = bbox.shape[0] // N
            bbox = bbox.view(N, M, 4)
            label = label.view(N, M)
            # print(img.shape, bbox.shape, label.shape)

            ploc, plabel = ssd300(img)
            ploc, plabel = ploc.float(), plabel.float()

            trans_bbox = bbox.transpose(1, 2).contiguous().cuda()
            label = label.cuda()
            gloc, glabel = Variable(trans_bbox, requires_grad=False), \
                           Variable(label, requires_grad=False)
            loss = loss_func(ploc, plabel, gloc, glabel)

            # Exponential moving average of the loss; infinite losses are
            # left out of the average.
            if not np.isinf(loss.item()):
                avg_loss = 0.999 * avg_loss + 0.001 * loss.item()

            num_elapsed_samples += N
            if args.local_rank == 0 and iter_num % args.print_interval == 0:
                end_elapsed_time = time.time()
                elapsed_time = end_elapsed_time - start_elapsed_time

                avg_samples_per_sec = num_elapsed_samples * N_gpu / elapsed_time

                print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}, avg. samples / sec: {:.2f}"\
                            .format(iter_num, loss.item(), avg_loss, avg_samples_per_sec), end="\n")

                start_elapsed_time = time.time()
                num_elapsed_samples = 0

            # loss scaling
            if args.use_fp16:
                loss = loss * static_loss_scale
            loss.backward()

            if not optimizer_created:
                # Imitate the model bucket structure created by DDP.
                # These will already be split by type (float or half).
                model_buckets = []
                for bucket in ssd300.active_i_buckets:
                    model_buckets.append([])
                    for active_i in bucket:
                        model_buckets[-1].append(
                            ssd300.active_params[active_i])
                flat_master_buckets = create_flat_master(model_buckets)
                optim = torch.optim.SGD(flat_master_buckets,
                                        lr=current_lr,
                                        momentum=current_momentum,
                                        weight_decay=current_weight_decay)
                # Apply the LARC wrapper that could not be applied before
                # the loop because the optimizer did not exist yet.
                if args.use_larc:
                    optim = LARC(optim)
                optimizer_created = True
                # Skip this first iteration because flattened allreduce buffers are not yet created.
                # step_maybe_fp16_maybe_distributed(optim)
            else:
                step_maybe_fp16_maybe_distributed(optim)

            # Likely a decent skew here, let's take this opportunity to set the gradients to None.
            # After DALI integration, playing with the placement of this is worth trying.
            for p in ssd300.parameters():
                p.grad = None

            if iter_num in eval_points:
                if args.local_rank == 0:
                    if not args.no_save:
                        print("saving model...")
                        torch.save(
                            {
                                "model": ssd300.state_dict(),
                                "label_map": val_coco.label_info
                            }, "./models/iter_{}.pt".format(iter_num))

                # Get the existing state from the train model
                # * if we use distributed, then we want .module
                train_model = ssd300.module if args.distributed else ssd300

                ssd300_eval.load_state_dict(train_model.state_dict())
                if coco_eval(
                        ssd300_eval,
                        val_dataloader,
                        cocoGt,
                        encoder,
                        inv_map,
                        args.threshold,
                        epoch,
                        iter_num,
                        args.eval_batch_size,
                        use_fp16=args.use_fp16,
                        local_rank=args.local_rank if args.distributed else -1,
                        N_gpu=N_gpu,
                        use_nhwc=args.nhwc,
                        pad_input=args.pad_input):
                    return True

            iter_num += 1

        train_loader.reset()
    return False