def train_loop(model, loss_func, epoch, optim, train_dataloader, val_dataloader, encoder, iteration, logger, args, mean, std):
#     for nbatch, (img, _, img_size, bbox, label) in enumerate(train_dataloader):
    for nbatch, data in enumerate(train_dataloader):
        img = data[0][0][0]
        bbox = data[0][1][0]
        label = data[0][2][0]
        label = label.type(torch.cuda.LongTensor)
        bbox_offsets = data[0][3][0]
        # handle random flipping outside of DALI for now
        bbox_offsets = bbox_offsets.cuda()
        img, bbox = C.random_horiz_flip(img, bbox, bbox_offsets, 0.5, False)
        img.sub_(mean).div_(std)
        if not args.no_cuda:
            img = img.cuda()
            bbox = bbox.cuda()
            label = label.cuda()
            bbox_offsets = bbox_offsets.cuda()

        N = img.shape[0]
        if bbox_offsets[-1].item() == 0:
            print("No labels in batch")
            continue
        bbox, label = C.box_encoder(N, bbox, bbox_offsets, label, encoder.dboxes.cuda(), 0.5)
        # output is ([N*8732, 4], [N*8732]); need ([N, 8732, 4], [N, 8732]) respectively
        M = bbox.shape[0] // N
        bbox = bbox.view(N, M, 4)
        label = label.view(N, M)

        ploc, plabel = model(img)
        ploc, plabel = ploc.float(), plabel.float()

        trans_bbox = bbox.transpose(1, 2).contiguous().cuda()

        if not args.no_cuda:
            label = label.cuda()
        gloc = Variable(trans_bbox, requires_grad=False)
        glabel = Variable(label, requires_grad=False)

        loss = loss_func(ploc, plabel, gloc, glabel)

        if args.local_rank == 0:
            logger.update_iter(epoch, iteration, loss.item())

        if args.fp16:
            if args.amp:
                with optim.scale_loss(loss) as scale_loss:
                    scale_loss.backward()
            else:
                optim.backward(loss)
        else:
            loss.backward()

        if args.warmup is not None:
            warmup(optim, args.warmup, iteration, args.learning_rate)

        optim.step()
        optim.zero_grad()
        iteration += 1

    return iteration
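The warmup helper called just before optim.step() is not defined in this snippet. Below is a minimal sketch of a linear learning-rate warmup with the same signature, assuming that is what warmup(optim, args.warmup, iteration, args.learning_rate) is meant to do; the body is an illustration, not the repository's implementation.

def warmup(optim, warmup_iters, iteration, base_lr):
    # Hypothetical linear warmup: ramp the learning rate up to base_lr over
    # the first warmup_iters iterations, then leave the schedule untouched.
    if iteration < warmup_iters:
        new_lr = base_lr * (iteration + 1) / warmup_iters
        for param_group in optim.param_groups:
            param_group['lr'] = new_lr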
Example #2
    def encode(self, batch_size, bboxes, bboxes_offset, labels_in):
        bbox, label = C.box_encoder(batch_size, bboxes, bboxes_offset, labels_in, self.dboxes.cuda(), 0.5)
        # output is ([N*8732, 4], [N*8732]); need ([N, 8732, 4], [N, 8732]) respectively
        M = bbox.shape[0] // batch_size
        bboxes_out = bbox.view(batch_size, M, 4)
        labels_out = label.view(batch_size, M)

        x, y, w, h = bboxes_out[:, :, 0], bboxes_out[:, :, 1], bboxes_out[:, :, 2], bboxes_out[:, :, 3]

        bboxes_out[:, :, 0] = (x - self.batch_dboxes[:, :, 0]) / (self.scale_xy * self.batch_dboxes[:, :, 2])
        bboxes_out[:, :, 1] = (y - self.batch_dboxes[:, :, 1]) / (self.scale_xy * self.batch_dboxes[:, :, 3])
        bboxes_out[:, :, 2] = torch.log(w / self.batch_dboxes[:, :, 2]) / self.scale_wh
        bboxes_out[:, :, 3] = torch.log(h / self.batch_dboxes[:, :, 3]) / self.scale_wh

        return bboxes_out, labels_out
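encode above rewrites the matched ground-truth boxes as SSD regression targets relative to the default boxes; in SSD300 the scaling factors are typically scale_xy = 0.1 and scale_wh = 0.2. For reference, here is a minimal sketch of the inverse transform, under the assumption that batch_dboxes holds the same [N, M, 4] center-form default boxes (the function name is illustrative only).

import torch

def decode_offsets(offsets, batch_dboxes, scale_xy=0.1, scale_wh=0.2):
    # Invert the target encoding: regression offsets [N, M, 4] back to
    # center-form boxes (cx, cy, w, h) of the same shape.
    boxes = torch.empty_like(offsets)
    boxes[:, :, 0] = offsets[:, :, 0] * scale_xy * batch_dboxes[:, :, 2] + batch_dboxes[:, :, 0]
    boxes[:, :, 1] = offsets[:, :, 1] * scale_xy * batch_dboxes[:, :, 3] + batch_dboxes[:, :, 1]
    boxes[:, :, 2] = torch.exp(offsets[:, :, 2] * scale_wh) * batch_dboxes[:, :, 2]
    boxes[:, :, 3] = torch.exp(offsets[:, :, 3] * scale_wh) * batch_dboxes[:, :, 3]
    return boxes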
Example #3
    def __next__(self):
        # special case: with fake_input, replay the saved batch on all iterations after the first
        if self._saved_batch is not None:
            return self._saved_batch

        (img, bbox, label, bbox_offsets) = self._input_it.__next__()

        # non-dali path is cpu, move tensors to gpu
        if self._no_dali:
            img = img.cuda()
            bbox = bbox.cuda()
            label = label.cuda()
            bbox_offsets = bbox_offsets.cuda()

        if bbox_offsets[-1].item() == 0:
            img = None
            bbox = None
            label = None
        else:
            # dali path doesn't do horizontal flip, do it here
            if not self._no_dali:
                img, bbox = C.random_horiz_flip(img, bbox, bbox_offsets, 0.5,
                                                self._nhwc)

            # massage raw ground truth into form used by loss function
            bbox, label = C.box_encoder(
                img.shape[0],  # <- batch size
                bbox,
                bbox_offsets,
                label.type(torch.cuda.LongTensor),
                self._dboxes,
                0.5)
            bbox = bbox.transpose(1, 2).contiguous().cuda()
            label = label.cuda()

        if self._fake_input:
            self._saved_batch = (img, bbox, label)

        return (img, bbox, label)
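The _saved_batch branch at the top of __next__ caches the first real batch and replays it on every later call when _fake_input is set, so timing runs exclude data loading. The same caching pattern as a small stand-alone wrapper (class and attribute names here are hypothetical):

class ReplayFirstBatch:
    # Wrap any batch iterator; with fake_input enabled, keep returning the
    # first batch instead of pulling new data.
    def __init__(self, batch_iter, fake_input=False):
        self._it = iter(batch_iter)
        self._fake_input = fake_input
        self._saved_batch = None

    def __iter__(self):
        return self

    def __next__(self):
        if self._saved_batch is not None:
            return self._saved_batch
        batch = next(self._it)
        if self._fake_input:
            self._saved_batch = batch
        return batch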
Example #4
def benchmark_train_loop(model, loss_func, epoch, optim, train_dataloader,
                         val_dataloader, encoder, iteration, logger, args,
                         mean, std):
    start_time = None
    # tensor for results
    result = torch.zeros((1, )).cuda()
    for i, data in enumerate(loop(train_dataloader)):
        if i >= args.benchmark_warmup:
            torch.cuda.synchronize()
            start_time = time.time()

        img = data[0][0][0]
        bbox = data[0][1][0]
        label = data[0][2][0]
        label = label.type(torch.cuda.LongTensor)
        bbox_offsets = data[0][3][0]
        # handle random flipping outside of DALI for now
        bbox_offsets = bbox_offsets.cuda()
        img, bbox = C.random_horiz_flip(img, bbox, bbox_offsets, 0.5, False)

        if not args.no_cuda:
            img = img.cuda()
            bbox = bbox.cuda()
            label = label.cuda()
            bbox_offsets = bbox_offsets.cuda()
        img.sub_(mean).div_(std)

        N = img.shape[0]
        if bbox_offsets[-1].item() == 0:
            print("No labels in batch")
            continue
        bbox, label = C.box_encoder(N, bbox, bbox_offsets, label,
                                    encoder.dboxes.cuda(), 0.5)

        M = bbox.shape[0] // N
        bbox = bbox.view(N, M, 4)
        label = label.view(N, M)

        ploc, plabel = model(img)
        ploc, plabel = ploc.float(), plabel.float()

        trans_bbox = bbox.transpose(1, 2).contiguous().cuda()

        if not args.no_cuda:
            label = label.cuda()
        gloc = Variable(trans_bbox, requires_grad=False)
        glabel = Variable(label, requires_grad=False)

        loss = loss_func(ploc, plabel, gloc, glabel)

        # loss scaling
        if args.amp:
            with amp.scale_loss(loss, optim) as scale_loss:
                scale_loss.backward()
        else:
            loss.backward()

        optim.step()
        optim.zero_grad()

        if i >= args.benchmark_warmup + args.benchmark_iterations:
            break

        if i >= args.benchmark_warmup:
            torch.cuda.synchronize()
            logger.update(args.batch_size, time.time() - start_time)

    result.data[0] = logger.print_result()
    if args.N_gpu > 1:
        # torch.distributed.reduce(result, 0)
        herring.all_reduce(result)
    if args.local_rank == 0:
        print('Training performance = {} FPS'.format(float(result.data[0])))
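The loop helper that feeds this benchmark is not shown here. A minimal sketch, assuming it simply cycles the dataloader forever so the run is bounded by benchmark_warmup + benchmark_iterations rather than by dataset size (a real DALI iterator would additionally need reset() between passes):

def loop(dataloader):
    # Hypothetical helper: yield batches indefinitely by restarting the
    # dataloader whenever it is exhausted.
    while True:
        for data in dataloader:
            yield data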
Example #5
def train_loop(model, loss_func, da_loss, epoch, optim, train_dataloader,
               target_dataloader, encoder, iteration, logger, args, mean, std,
               meters, vis):

    #     for nbatch, (img, _, img_size, bbox, label) in enumerate(train_dataloader):
    for nbatch, (source_data, target_data) in enumerate(
            zip(train_dataloader, target_dataloader)):
        target_img = target_data[0][0]
        img = source_data[0][0][0]
        bbox = source_data[0][1][0]
        label = source_data[0][2][0]
        # decode2(img, bbox)
        label = label.type(torch.cuda.LongTensor)
        batch_size = img.shape[0]
        source_domain_labels = torch.zeros(batch_size)
        # target_domain_labels = torch.ones(batch_size // 2)
        target_domain_labels = torch.ones(batch_size)
        bbox_offsets = source_data[0][3][0]
        # handle random flipping outside of DALI for now
        bbox_offsets = bbox_offsets.cuda()
        img, bbox = C.random_horiz_flip(img, bbox, bbox_offsets, 0.5, False)
        img.sub_(mean).div_(std)
        if not args.no_cuda:
            img = img.cuda()
            bbox = bbox.cuda()
            label = label.cuda()
            source_domain_labels = source_domain_labels.cuda()
            target_domain_labels = target_domain_labels.cuda()
            target_img = target_img.cuda()
            # bbox_offsets = bbox_offsets.cuda()

        domain_label = torch.cat([source_domain_labels, target_domain_labels],
                                 dim=0)
        images = torch.cat([img, target_img], dim=0)

        N = img.shape[0]
        if bbox_offsets[-1].item() == 0:
            print("No labels in batch")
            continue
        bbox, label = C.box_encoder(N, bbox, bbox_offsets, label,
                                    encoder.dboxes.cuda(), 0.5)
        # label = label * (label != OTHERS).long()
        # output is ([N*8732, 4], [N*8732]); need ([N, 8732, 4], [N, 8732]) respectively
        M = bbox.shape[0] // N
        bbox = bbox.view(N, M, 4)
        label = label.view(N, M)

        # bbox = torch.cat([bbox, bbox.new_ones((N // 2,) + bbox.shape[1:])], dim=0)
        # label = torch.cat([label, label.new_ones((N // 2,) + label.shape[1:])], dim=0)

        bbox = torch.cat([bbox, bbox.new_ones((N, ) + bbox.shape[1:])], dim=0)
        label = torch.cat(
            [label, label.new_ones((N, ) + label.shape[1:])], dim=0)

        ploc, plabel, domain_classifier_features = model(images)
        ploc, plabel = ploc.float(), plabel.float()

        # decode(encoder.dboxes_xywh, images, bbox, label)

        trans_bbox = bbox.transpose(1, 2).contiguous().cuda()

        if not args.no_cuda:
            label = label.cuda()
        gloc = Variable(trans_bbox, requires_grad=False)
        glabel = Variable(label, requires_grad=False)

        ssd_loss = loss_func(ploc, plabel, gloc, glabel, domain_label)
        adaptation_loss = da_loss(domain_classifier_features, domain_label)

        # loss = ssd_loss + 10 * adaptation_loss
        loss = ssd_loss + adaptation_loss

        if args.amp:
            with amp.scale_loss(loss, optim) as scale_loss:
                scale_loss.backward()
        else:
            loss.backward()

        if args.warmup is not None:
            warmup(optim, args.warmup, iteration, args.learning_rate)

        optim.step()
        optim.zero_grad()

        if args.local_rank == 0:
            logger.update_iter(epoch, iteration, loss.item())
            meters['total'].add(loss.cpu().detach().numpy())
            meters['ssd'].add(ssd_loss.cpu().detach().numpy())
            meters['da'].add(adaptation_loss.cpu().detach().numpy())

            if (nbatch + 1) % args.plot_every == 0:
                # plot loss
                vis.plot_many({k: v.value()[0] for k, v in meters.items()})

        iteration += 1

    return iteration
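The da_loss used above is not part of this snippet; the call site only shows that it takes the domain-classifier features for the concatenated source/target batch plus the 0/1 domain labels. Below is a DANN-style sketch under that assumption, with a gradient-reversal layer and a binary domain classifier; names and layer sizes are illustrative, not the repository's code.

import torch
import torch.nn as nn

class GradReverse(torch.autograd.Function):
    # Identity in the forward pass, negated (and scaled) gradient in backward.
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambd * grad_output, None

class DomainAdaptationLoss(nn.Module):
    # Classify pooled backbone features as source (0) vs. target (1) domain;
    # the reversed gradient pushes the backbone toward domain-invariant features.
    def __init__(self, in_features, lambd=1.0):
        super().__init__()
        self.classifier = nn.Linear(in_features, 1)
        self.lambd = lambd
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, features, domain_label):
        feats = GradReverse.apply(features, self.lambd)
        logits = self.classifier(feats).squeeze(1)
        return self.criterion(logits, domain_label.float())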
Example #6
def test_coco(args):
    # For testing purposes we have to use CUDA
    use_cuda = True

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)

        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)

    if args.use_train_dataset:
        annotate = os.path.join(args.data,
                                "annotations/instances_train2017.json")
        coco_root = os.path.join(args.data, "train2017")
        img_number = 118287
    else:
        annotate = os.path.join(args.data,
                                "annotations/instances_val2017.json")
        coco_root = os.path.join(args.data, "val2017")
        img_number = 5000

    pipe = COCOPipeline(args.batch_size,
                        args.local_rank,
                        coco_root,
                        annotate,
                        N_gpu,
                        num_threads=args.num_workers)
    pipe.build()
    test_run = pipe.run()
    dataloader = DALICOCOIterator(pipe, img_number / N_gpu)

    # Build the model
    ssd300 = SSD300(81, backbone=args.backbone, model_path='', dilation=False)
    """
    # Note: args.checkpoint is required, so this can never be false
    if args.checkpoint is not None:
        print("loading model checkpoint", args.checkpoint)
        od = torch.load(args.checkpoint)

        # remove preceding 'module.' prefix from checkpoint keys
        model = od["model"]
        for k in list(model.keys()):
            if k.startswith('module.'):
                model[k[7:]] = model.pop(k)
        ssd300.load_state_dict(model)
    """

    ssd300.cuda()
    ssd300.eval()
    loss_func = Loss(dboxes)
    loss_func.cuda()

    # parallelize
    if args.distributed:
        ssd300 = DDP(ssd300)

    if args.use_fp16:
        ssd300 = network_to_half(ssd300)

    if args.use_train_dataset and args.local_rank == 0:
        print(
            'Image 000000320612.jpg is in fact a PNG and will cause a failure '
            + 'if used with nvJPEGDecoder in coco_pipeline')

    for epoch in range(2):
        if epoch == 1 and args.local_rank == 0:
            print("Performance computation starts")
            s = time.time()
        for i, data in enumerate(dataloader):

            with torch.no_grad():
                # Get data from pipeline
                img = data[0][0][0]
                bbox = data[0][1][0]
                label = data[0][2][0]
                label = label.type(torch.cuda.LongTensor)
                bbox_offsets = data[0][3][0]
                bbox_offsets = bbox_offsets.cuda()

                # Encode labels
                N = img.shape[0]
                if bbox_offsets[-1].item() == 0:
                    print("No labels in batch")
                    continue
                bbox, label = C.box_encoder(N, bbox, bbox_offsets, label,
                                            encoder.dboxes.cuda(), 0.5)

                # Prepare tensors for computing loss
                M = bbox.shape[0] // N
                bbox = bbox.view(N, M, 4)
                label = label.view(N, M)
                trans_bbox = bbox.transpose(1, 2).contiguous()
                gloc, glabel = Variable(trans_bbox, requires_grad=False), \
                               Variable(label, requires_grad=False)

                if args.use_fp16:
                    img = img.half()

                for _ in range(args.fbu):
                    ploc, plabel = ssd300(img)
                    ploc, plabel = ploc.float(), plabel.float()
                    loss = loss_func(ploc, plabel, gloc, glabel)

        if epoch == 1 and args.local_rank == 0:
            e = time.time()
            print("Performance achieved: {:.2f} img/sec".format(img_number /
                                                                (e - s)))

        dataloader.reset()
def test_box_encoder():
    torch.cuda.set_device(0)

    # np.random.seed(0)

    # source boxes
    box_list = []
    for _ in range(128):
        box_list.append(b1)
    N, bboxes_cat, offsets, bboxes = load_bboxes(box_list, True)
    # N, bboxes_cat, offsets, bboxes = load_bboxes([b1[:2,:], b1[:2,:]])

    print(N, bboxes_cat, offsets)

    label_numpy = np.random.randn(offsets[-1]) * 10
    labels = torch.tensor(label_numpy.astype(np.int64)).cuda()

    # target boxes are default boxes from SSD
    dboxes = dboxes300_coco()
    dboxes = torch.tensor(np.array(dboxes(order='ltrb')).astype(np.float32))

    # print(dboxes[:10, :])

    start = time.time()
    bbox_out, label_out = C.box_encoder(N, bboxes_cat, offsets, labels,
                                        dboxes.cuda(), 0.5)
    torch.cuda.synchronize()
    end = time.time()

    cuda_time = end - start

    # print('bbox_out: {}'.format(bbox_out.shape))
    # print(bbox_out.cpu())

    # print('label_out: {}'.format(label_out.shape))
    # print(label_out.cpu())

    # reference
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)

    labels_ref = torch.tensor(label_numpy.astype(np.int64))
    start = time.time()

    ref_boxes = []
    ref_labels = []
    for i, bbox in enumerate(bboxes):
        label_slice = labels_ref[offsets[i]:offsets[i + 1]]
        bbox_ref_out, label_ref_out = encoder.encode(bbox.cpu(),
                                                     label_slice.cpu(),
                                                     criteria=0.5)
        ref_boxes.append(bbox_ref_out)
        ref_labels.append(label_ref_out)
    end = time.time()
    ref_time = end - start

    ref_boxes = torch.cat(ref_boxes)
    ref_labels = torch.cat(ref_labels)

    # print('ref bbox: {}'.format(ref_boxes.shape))
    # print(bbox_ref_out)

    r = np.isclose(ref_boxes.numpy(), bbox_out.cpu().numpy())
    # r = np.isclose(bbox_ref_out.numpy(), bbox_out.cpu().numpy())

    num_fail = 0
    for i, res in enumerate(r):
        if not res.any():
            num_fail += 1
            print(i, res, ref_boxes[i, :], bbox_out[i, :])

    print('{} bboxes failed'.format(num_fail))

    label_out = label_out.cpu().numpy()
    torch.cuda.synchronize()
    # r2 = np.isclose(label_out, label_ref_out.cpu().numpy())
    r2 = np.isclose(label_out, ref_labels.cpu().numpy())
    num_fail = 0
    for i, res in enumerate(r2):
        if not res:
            num_fail += 1
            print('label: ', i, res, label_out[i], ref_labels.numpy()[i])

    print('{} labels failed'.format(num_fail))

    print('CUDA took {}, numpy took: {}'.format(cuda_time, ref_time))
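# The b1 sample boxes and the load_bboxes helper used above are defined elsewhere
# in this file.  A hypothetical sketch of load_bboxes, assuming it concatenates
# per-image [K_i, 4] box tensors and builds prefix-sum offsets for the batched
# box encoder (illustration only, not the repository's implementation):
def load_bboxes(box_list, use_cuda=False):
    boxes = [torch.as_tensor(b, dtype=torch.float32) for b in box_list]
    counts = torch.tensor([b.shape[0] for b in boxes], dtype=torch.long)
    # offsets[i]:offsets[i + 1] selects image i's boxes in the concatenated tensor
    offsets = torch.zeros(len(boxes) + 1, dtype=torch.long)
    offsets[1:] = torch.cumsum(counts, dim=0)
    bboxes_cat = torch.cat(boxes, dim=0)
    if use_cuda:
        bboxes_cat, offsets = bboxes_cat.cuda(), offsets.cuda()
    return len(boxes), bboxes_cat, offsets, boxes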
def train300_mlperf_coco(args):
    from coco import COCO

    # Check that GPUs are actually available
    use_cuda = not args.no_cuda

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    local_seed = set_seeds(args)
    # start timing here
    ssd_print(key=mlperf_log.RUN_START)
    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    input_size = 300
    val_trans = SSDTransformer(dboxes, (input_size, input_size), val=True)
    ssd_print(key=mlperf_log.INPUT_SIZE, value=input_size)

    val_annotate = os.path.join(args.data,
                                "annotations/instances_val2017.json")
    val_coco_root = os.path.join(args.data, "val2017")
    train_annotate = os.path.join(args.data,
                                  "annotations/instances_train2017.json")
    train_coco_root = os.path.join(args.data, "train2017")

    cocoGt = COCO(annotation_file=val_annotate)
    val_coco = COCODetection(val_coco_root, val_annotate, val_trans)

    if args.distributed:
        val_sampler = GeneralDistributedSampler(val_coco, pad=False)
    else:
        val_sampler = None

    train_pipe = COCOPipeline(args.batch_size,
                              args.local_rank,
                              train_coco_root,
                              train_annotate,
                              N_gpu,
                              num_threads=args.num_workers,
                              output_fp16=args.use_fp16,
                              output_nhwc=args.nhwc,
                              pad_output=args.pad_input,
                              seed=local_seed - 2**31)
    print_message(args.local_rank,
                  "time_check a: {secs:.9f}".format(secs=time.time()))
    train_pipe.build()
    print_message(args.local_rank,
                  "time_check b: {secs:.9f}".format(secs=time.time()))
    test_run = train_pipe.run()
    train_loader = DALICOCOIterator(train_pipe, 118287 / N_gpu)

    val_dataloader = DataLoader(
        val_coco,
        batch_size=args.eval_batch_size,
        shuffle=False,  # Note: distributed sampler is shuffled :(
        sampler=val_sampler,
        num_workers=args.num_workers)
    ssd_print(key=mlperf_log.INPUT_ORDER)
    ssd_print(key=mlperf_log.INPUT_BATCH_SIZE, value=args.batch_size)
    # Build the model
    ssd300 = SSD300(val_coco.labelnum,
                    backbone=args.backbone,
                    use_nhwc=args.nhwc,
                    pad_input=args.pad_input)
    if args.checkpoint is not None:
        load_checkpoint(ssd300, args.checkpoint)

    ssd300.train()
    ssd300.cuda()
    loss_func = Loss(dboxes)
    loss_func.cuda()

    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    if args.use_fp16:
        ssd300 = network_to_half(ssd300)

    # Parallelize.  Need to do this after network_to_half.
    if args.distributed:
        if args.delay_allreduce:
            print_message(args.local_rank,
                          "Delaying allreduces to the end of backward()")
        ssd300 = DDP(ssd300,
                     delay_allreduce=args.delay_allreduce,
                     retain_allreduce_buffers=args.use_fp16)

    # Create optimizer.  This must also be done after network_to_half.
    global_batch_size = (N_gpu * args.batch_size)

    # mlperf only allows base_lr scaled by an integer
    base_lr = 1e-3
    requested_lr_multiplier = args.lr / base_lr
    adjusted_multiplier = max(
        1, round(requested_lr_multiplier * global_batch_size / 32))

    current_lr = base_lr * adjusted_multiplier
    current_momentum = 0.9
    current_weight_decay = 5e-4
    static_loss_scale = 128.
    if args.use_fp16:
        if args.distributed and not args.delay_allreduce:
            # We can't create the flat master params yet, because we need to
            # imitate the flattened bucket structure that DDP produces.
            optimizer_created = False
        else:
            model_buckets = [
                [
                    p for p in ssd300.parameters()
                    if p.requires_grad and p.type() == "torch.cuda.HalfTensor"
                ],
                [
                    p for p in ssd300.parameters()
                    if p.requires_grad and p.type() == "torch.cuda.FloatTensor"
                ]
            ]
            flat_master_buckets = create_flat_master(model_buckets)
            optim = torch.optim.SGD(flat_master_buckets,
                                    lr=current_lr,
                                    momentum=current_momentum,
                                    weight_decay=current_weight_decay)
            optimizer_created = True
    else:
        optim = torch.optim.SGD(ssd300.parameters(),
                                lr=current_lr,
                                momentum=current_momentum,
                                weight_decay=current_weight_decay)
        optimizer_created = True

    # Add LARC if desired
    if args.use_larc:
        optim = LARC(optim)

    ssd_print(key=mlperf_log.OPT_NAME, value="SGD")
    ssd_print(key=mlperf_log.OPT_LR, value=current_lr)
    ssd_print(key=mlperf_log.OPT_MOMENTUM, value=current_momentum)
    ssd_print(key=mlperf_log.OPT_WEIGHT_DECAY, value=current_weight_decay)
    if args.warmup is not None:
        ssd_print(key=mlperf_log.OPT_LR_WARMUP_STEPS, value=args.warmup)

    # Model is completely finished -- need to create separate copies, preserve parameters across
    # them, and jit
    ssd300_eval = SSD300(val_coco.labelnum,
                         backbone=args.backbone,
                         use_nhwc=args.nhwc,
                         pad_input=args.pad_input).cuda()
    if args.use_fp16:
        ssd300_eval = network_to_half(ssd300_eval)

    # Get the existing state from the train model
    # * if we use distributed, then we want .module
    train_model = ssd300.module if args.distributed else ssd300

    ssd300_eval.load_state_dict(train_model.state_dict())

    ssd300_eval.eval()

    if args.jit:
        input_c = 4 if args.pad_input else 3
        example_shape = [
            args.batch_size, 300, 300, input_c
        ] if args.nhwc else [args.batch_size, input_c, 300, 300]
        example_input = torch.randn(*example_shape).cuda()
        if args.use_fp16:
            example_input = example_input.half()
        # DDP has some Python-side control flow.  If we JIT the entire DDP-wrapped module,
        # the resulting ScriptModule will elide this control flow, resulting in allreduce
        # hooks not being called.  If we're running distributed, we need to extract and JIT
        # the wrapped .module.
        # Replacing a DDP-ed ssd300 with a script_module might also cause the AccumulateGrad hooks
        # to go out of scope, and therefore silently disappear.
        module_to_jit = ssd300.module if args.distributed else ssd300
        if args.distributed:
            ssd300.module = torch.jit.trace(module_to_jit, example_input)
        else:
            ssd300 = torch.jit.trace(module_to_jit, example_input)

    print_message(args.local_rank, "epoch", "nbatch", "loss")
    eval_points = np.array(args.evaluation) * 32 / global_batch_size
    eval_points = list(map(int, list(eval_points)))

    iter_num = args.iteration
    avg_loss = 0.0
    inv_map = {v: k for k, v in val_coco.label_map.items()}

    start_elapsed_time = time.time()
    last_printed_iter = args.iteration
    num_elapsed_samples = 0

    # Generate normalization tensors
    mean, std = generate_mean_std(args)

    def step_maybe_fp16_maybe_distributed(optim):
        if args.use_fp16:
            if args.distributed:
                for flat_master, allreduce_buffer in zip(
                        flat_master_buckets, ssd300.allreduce_buffers):
                    if allreduce_buffer is None:
                        raise RuntimeError("allreduce_buffer is None")
                    flat_master.grad = allreduce_buffer.float()
                    flat_master.grad.data.mul_(1. / static_loss_scale)
            else:
                for flat_master, model_bucket in zip(flat_master_buckets,
                                                     model_buckets):
                    flat_grad = apex_C.flatten(
                        [m.grad.data for m in model_bucket])
                    flat_master.grad = flat_grad.float()
                    flat_master.grad.data.mul_(1. / static_loss_scale)
        optim.step()
        if args.use_fp16:
            for model_bucket, flat_master in zip(model_buckets,
                                                 flat_master_buckets):
                for model, master in zip(
                        model_bucket,
                        apex_C.unflatten(flat_master.data, model_bucket)):
                    model.data.copy_(master.data)

    ssd_print(key=mlperf_log.TRAIN_LOOP)
    for epoch in range(args.epochs):
        ssd_print(key=mlperf_log.TRAIN_EPOCH, value=epoch)
        for p in ssd300.parameters():
            p.grad = None
        for i, data in enumerate(train_loader):
            img = data[0][0][0]
            bbox = data[0][1][0]
            label = data[0][2][0]
            label = label.type(torch.cuda.LongTensor)
            bbox_offsets = data[0][3][0]

            # handle random flipping outside of DALI for now
            bbox_offsets = bbox_offsets.cuda()
            img, bbox = C.random_horiz_flip(img, bbox, bbox_offsets, 0.5,
                                            args.nhwc)
            img.sub_(mean).div_(std)

            if args.profile is not None and iter_num == args.profile:
                return
            if args.warmup is not None and optimizer_created:
                lr_warmup(optim, args.warmup, iter_num, epoch, current_lr,
                          args)
            if iter_num == ((args.decay1 * 1000 * 32) // global_batch_size):
                print_message(args.local_rank, "lr decay step #1")
                current_lr *= 0.1
                for param_group in optim.param_groups:
                    param_group['lr'] = current_lr
                ssd_print(key=mlperf_log.OPT_LR, value=current_lr)

            if iter_num == ((args.decay2 * 1000 * 32) // global_batch_size):
                print_message(args.local_rank, "lr decay step #2")
                current_lr *= 0.1
                for param_group in optim.param_groups:
                    param_group['lr'] = current_lr
                ssd_print(key=mlperf_log.OPT_LR, value=current_lr)

            if use_cuda:
                img = img.cuda()
                # NHWC direct from DALI now if necessary
                bbox = bbox.cuda()
                label = label.cuda()
                bbox_offsets = bbox_offsets.cuda()

            # Now run the batched box encoder
            N = img.shape[0]
            if bbox_offsets[-1].item() == 0:
                print("No labels in batch")
                continue
            bbox, label = C.box_encoder(N, bbox, bbox_offsets, label,
                                        encoder.dboxes.cuda(), 0.5)

            # output is ([N*8732, 4], [N*8732]); need ([N, 8732, 4], [N, 8732]) respectively
            M = bbox.shape[0] // N
            bbox = bbox.view(N, M, 4)
            label = label.view(N, M)
            # print(img.shape, bbox.shape, label.shape)

            ploc, plabel = ssd300(img)
            ploc, plabel = ploc.float(), plabel.float()

            trans_bbox = bbox.transpose(1, 2).contiguous().cuda()
            label = label.cuda()
            gloc, glabel = Variable(trans_bbox, requires_grad=False), \
                           Variable(label, requires_grad=False)
            loss = loss_func(ploc, plabel, gloc, glabel)

            if not np.isinf(loss.item()):
                avg_loss = 0.999 * avg_loss + 0.001 * loss.item()

            num_elapsed_samples += N
            if args.local_rank == 0 and iter_num % args.print_interval == 0:
                end_elapsed_time = time.time()
                elapsed_time = end_elapsed_time - start_elapsed_time

                avg_samples_per_sec = num_elapsed_samples * N_gpu / elapsed_time

                print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}, avg. samples / sec: {:.2f}"\
                            .format(iter_num, loss.item(), avg_loss, avg_samples_per_sec), end="\n")

                last_printed_iter = iter_num
                start_elapsed_time = time.time()
                num_elapsed_samples = 0

            # loss scaling
            if args.use_fp16:
                loss = loss * static_loss_scale
            loss.backward()

            if not optimizer_created:
                # Imitate the model bucket structure created by DDP.
                # These will already be split by type (float or half).
                model_buckets = []
                for bucket in ssd300.active_i_buckets:
                    model_buckets.append([])
                    for active_i in bucket:
                        model_buckets[-1].append(
                            ssd300.active_params[active_i])
                flat_master_buckets = create_flat_master(model_buckets)
                optim = torch.optim.SGD(flat_master_buckets,
                                        lr=current_lr,
                                        momentum=current_momentum,
                                        weight_decay=current_weight_decay)
                optimizer_created = True
                # Skip this first iteration because flattened allreduce buffers are not yet created.
                # step_maybe_fp16_maybe_distributed(optim)
            else:
                step_maybe_fp16_maybe_distributed(optim)

            # Likely a decent skew here, let's take this opportunity to set the gradients to None.
            # After DALI integration, playing with the placement of this is worth trying.
            for p in ssd300.parameters():
                p.grad = None

            if iter_num in eval_points:
                if args.local_rank == 0:
                    if not args.no_save:
                        print("saving model...")
                        torch.save(
                            {
                                "model": ssd300.state_dict(),
                                "label_map": val_coco.label_info
                            }, "./models/iter_{}.pt".format(iter_num))

                # Get the existing state from the train model
                # * if we use distributed, then we want .module
                train_model = ssd300.module if args.distributed else ssd300

                ssd300_eval.load_state_dict(train_model.state_dict())
                if coco_eval(
                        ssd300_eval,
                        val_dataloader,
                        cocoGt,
                        encoder,
                        inv_map,
                        args.threshold,
                        epoch,
                        iter_num,
                        args.eval_batch_size,
                        use_fp16=args.use_fp16,
                        local_rank=args.local_rank if args.distributed else -1,
                        N_gpu=N_gpu,
                        use_nhwc=args.nhwc,
                        pad_input=args.pad_input):
                    return True

            iter_num += 1

        train_loader.reset()
    return False
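The create_flat_master helper used on the FP16 path is not included in this snippet. A minimal sketch, assuming it turns each bucket of model parameters into a single flat FP32 master parameter for the optimizer, mirroring the apex_C.flatten / apex_C.unflatten calls in step_maybe_fp16_maybe_distributed above (illustration only):

import torch
import apex_C  # apex's flatten/unflatten extension, as used above

def create_flat_master(model_buckets):
    # Hypothetical sketch: one flat FP32 master parameter per model bucket.
    flat_master_buckets = []
    for bucket in model_buckets:
        flat = apex_C.flatten([p.detach().clone().float() for p in bucket])
        flat_master_buckets.append(torch.nn.Parameter(flat))
    return flat_master_buckets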