Example #1
def train_mlperf_coco(args):
    from coco import COCO
    # Check that GPUs are actually available
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    dboxes = dboxes_coco(args.image_size, args.strides)
    encoder = Encoder(dboxes)
    train_trans = SSDTransformer(dboxes, tuple(args.image_size), val=False)
    val_trans = SSDTransformer(dboxes, tuple(args.image_size), val=True)

    val_annotate = os.path.join(args.data,
                                "annotations/instances_val2017.json")
    val_coco_root = os.path.join(args.data, "val2017")
    train_annotate = os.path.join(args.data,
                                  "annotations/instances_train2017.json")
    train_coco_root = os.path.join(args.data, "train2017")

    cocoGt = COCO(annotation_file=val_annotate)
    val_coco = COCODetection(val_coco_root, val_annotate, val_trans)
    train_coco = COCODetection(train_coco_root, train_annotate, train_trans)

    train_dataloader = DataLoader(train_coco,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4)

    ssd_r34 = SSD_R34(train_coco.labelnum, strides=args.strides)
    if args.checkpoint is not None:
        print("loading model checkpoint", args.checkpoint)
        od = torch.load(args.checkpoint)
        ssd_r34.load_state_dict(od["model"])
    ssd_r34.train()
    if use_cuda:
        ssd_r34.to('cuda')
        if args.device_ids and len(args.device_ids) > 1:
            ssd_r34 = nn.DataParallel(ssd_r34, args.device_ids)

    loss_func = Loss(dboxes)
    if use_cuda:
        loss_func.to('cuda')
        loss_func = nn.DataParallel(loss_func, args.device_ids)

    optim = torch.optim.SGD(ssd_r34.parameters(),
                            lr=1e-3,
                            momentum=0.9,
                            weight_decay=5e-4)
    print("epoch", "nbatch", "loss")

    iter_num = args.iteration
    avg_loss = 0.0
    last_loss = [0.0] * 10
    inv_map = {v: k for k, v in val_coco.label_map.items()}

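    # Main training loop: step the LR down at fixed iteration milestones, keep a
    # running average of the loss, and checkpoint/evaluate at the iterations
    # listed in args.evaluation.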
    for epoch in range(args.epochs):

        for nbatch, (img, img_size, bbox,
                     label) in enumerate(train_dataloader):

            if iter_num == 160000:
                print("")
                print("lr decay step #1")
                for param_group in optim.param_groups:
                    param_group['lr'] = 1e-4

            if iter_num == 200000:
                print("")
                print("lr decay step #2")
                for param_group in optim.param_groups:
                    param_group['lr'] = 1e-5

            img = Variable(img, requires_grad=True)
            if use_cuda:
                img = img.to('cuda')
            ploc, plabel, _ = ssd_r34(img)
            trans_bbox = bbox.transpose(1, 2).contiguous()

            gloc, glabel = Variable(trans_bbox, requires_grad=False), \
                           Variable(label, requires_grad=False)

            loss = loss_func(ploc, plabel, gloc, glabel).mean()

            if not np.isinf(loss.item()):
                avg_loss = 0.999 * avg_loss + 0.001 * loss.item()
            last_loss.pop()
            last_loss = [loss.item()] + last_loss
            avg_last_loss = sum(last_loss) / len(last_loss)
            print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}, Average Last 10 Loss: {:.3f}"\
                  .format(iter_num, loss.item(), avg_loss, avg_last_loss), end="\r")
            optim.zero_grad()
            loss.backward()
            optim.step()

            loss = None

            if iter_num in args.evaluation:
                if not args.no_save:
                    print("")
                    print("saving model...")
                    module = ssd_r34.module if isinstance(
                        ssd_r34, nn.DataParallel) else ssd_r34
                    torch.save(
                        {
                            "model": module.state_dict(),
                            "label_map": train_coco.label_info
                        }, args.save_path + "/iter_{}.pt".format(iter_num))
                if coco_eval(ssd_r34, val_coco, cocoGt, encoder, inv_map,
                             args.threshold, args.device_ids):
                    return

            iter_num += 1
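
A minimal sketch of how this function might be invoked, assuming the surrounding module already defines SSD_R34, COCODetection, Loss, coco_eval and the other helpers it references. The argument names mirror the attributes the function reads; every value below is an illustrative placeholder, not a recommended configuration.

import argparse

# Hypothetical argument namespace; field names match what train_mlperf_coco
# reads, values are placeholders only.
args = argparse.Namespace(
    data="/datasets/coco",                # root containing train2017/, val2017/, annotations/
    no_cuda=False,
    strides=[3, 3, 2, 2, 2, 2],           # placeholder stride configuration
    image_size=[1200, 1200],              # placeholder input resolution
    batch_size=32,
    checkpoint=None,                      # or a path to a previously saved iter_*.pt
    device_ids=[0],                       # more than one id enables nn.DataParallel
    iteration=0,                          # starting value of the iteration counter
    epochs=65,
    evaluation=[160000, 180000, 200000],  # iterations at which to save and run coco_eval
    threshold=0.23,                       # target mAP passed to coco_eval
    no_save=False,
    save_path="./models",
)
train_mlperf_coco(args)
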
Example #2
def train300_mlperf_coco(args):
    global torch
    from coco import COCO
    # Check that GPUs are actually available
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    args.distributed = False
    if args.use_hpu:
        if 'WORLD_SIZE' in os.environ:
            args.distributed = int(os.environ['WORLD_SIZE']) > 1
            args.world_size = int(os.environ['WORLD_SIZE'])
            print("world_size = {}".format(args.world_size))
            print("distributed={}".format(args.distributed))
    if use_cuda:
        try:
            from apex.parallel import DistributedDataParallel as DDP
            if 'WORLD_SIZE' in os.environ:
                args.distributed = int(os.environ['WORLD_SIZE']) > 1
        except ImportError:
            raise ImportError(
                "Please install APEX from https://github.com/nvidia/apex")

    use_hpu = args.use_hpu
    hpu_channels_last = args.hpu_channels_last
    hpu_lazy_mode = args.hpu_lazy_mode
    is_hmp = args.is_hmp
    device = torch.device('cpu')
    data_loader_type = DataLoader
    if use_hpu:
        device = torch.device('hpu')
        if args.distributed:
            os.environ["MAX_WAIT_ATTEMPTS"] = "90"
        if hpu_lazy_mode:
            os.environ["PT_HPU_LAZY_MODE"] = "1"
        else:
            os.environ["PT_HPU_LAZY_MODE"] = "2"
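        # Habana mixed precision: cast ops to BF16/FP32 according to the
        # user-supplied op lists passed to hmp.convert below.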
        if is_hmp:
            if not args.hmp_bf16:
                raise IOError("Please provide list of BF16 ops")
            if not args.hmp_fp32:
                raise IOError("Please provide list of FP32 ops")
            from habana_frameworks.torch.hpex import hmp
            hmp.convert(opt_level=args.hmp_opt_level,
                        bf16_file_path=args.hmp_bf16,
                        fp32_file_path=args.hmp_fp32,
                        isVerbose=args.hmp_verbose)
        from habana_frameworks.torch.utils.library_loader import load_habana_module
        load_habana_module()
        # TODO - add dataloader

    local_seed = args.seed
    if args.distributed:
        # necessary pytorch imports
        import torch.utils.data.distributed
        import torch.distributed as dist
        if use_hpu:
            args.dist_backend = 'hccl'
            import habana_frameworks.torch.core.hccl
            os.environ["ID"] = os.environ["RANK"]
            dist.init_process_group(args.dist_backend, init_method='env://')
            # set seeds properly
            args.seed = broadcast_seeds(args.seed, device, use_hpu=True)
            local_seed = (args.seed + dist.get_rank()) % 2**32
        elif args.no_cuda:
            device = torch.device('cpu')
        else:
            torch.cuda.set_device(args.local_rank)
            device = torch.device('cuda')
            dist.init_process_group(backend='nccl', init_method='env://')
            # set seeds properly
            args.seed = broadcast_seeds(args.seed, device)
            local_seed = (args.seed + dist.get_rank()) % 2**32
    mllogger.event(key=mllog_const.SEED, value=local_seed)
    torch.manual_seed(local_seed)
    np.random.seed(seed=local_seed)
    random.seed(local_seed)  # amorgenstern
    torch.cuda.manual_seed(local_seed)  # amorgenstern

    args.rank = dist.get_rank() if args.distributed else args.local_rank
    print("args.rank = {}".format(args.rank))
    print("local rank = {}".format(args.local_rank))
    print("distributed={}".format(args.distributed))

    if use_hpu and is_hmp:
        with hmp.disable_casts():
            dboxes = dboxes300_coco()
            encoder = Encoder(dboxes)
    else:
        dboxes = dboxes300_coco()
        encoder = Encoder(dboxes)

    input_size = 300
    if use_hpu and is_hmp:
        with hmp.disable_casts():
            train_trans = SSDTransformer(
                dboxes, (input_size, input_size),
                val=False,
                num_cropping_iterations=args.num_cropping_iterations)
            val_trans = SSDTransformer(dboxes, (input_size, input_size),
                                       val=True)
    else:
        train_trans = SSDTransformer(
            dboxes, (input_size, input_size),
            val=False,
            num_cropping_iterations=args.num_cropping_iterations)
        val_trans = SSDTransformer(dboxes, (input_size, input_size), val=True)

    val_annotate = os.path.join(args.data,
                                "annotations/instances_val2017.json")
    val_coco_root = os.path.join(args.data, "val2017")
    train_annotate = os.path.join(args.data,
                                  "annotations/instances_train2017.json")
    train_coco_root = os.path.join(args.data, "train2017")

    if use_hpu and is_hmp:
        with hmp.disable_casts():
            cocoGt = COCO(annotation_file=val_annotate)
            train_coco = COCODetection(train_coco_root, train_annotate,
                                       train_trans)
            val_coco = COCODetection(val_coco_root, val_annotate, val_trans)
    else:
        cocoGt = COCO(annotation_file=val_annotate)
        train_coco = COCODetection(train_coco_root, train_annotate,
                                   train_trans)
        val_coco = COCODetection(val_coco_root, val_annotate, val_trans)
    mllogger.event(key=mllog_const.TRAIN_SAMPLES, value=len(train_coco))
    mllogger.event(key=mllog_const.EVAL_SAMPLES, value=len(val_coco))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_coco)
    else:
        train_sampler = None
    if use_hpu:
        # patch torch cuda functions that are being unconditionally invoked
        # in the multiprocessing data loader
        torch.cuda.current_device = lambda: None
        torch.cuda.set_device = lambda x: None
    train_dataloader = data_loader_type(train_coco,
                                        batch_size=args.batch_size,
                                        shuffle=(train_sampler is None),
                                        sampler=train_sampler,
                                        num_workers=args.num_workers)
    # shuffle is enabled above only when no DistributedSampler is in use
    if args.rank == 0:
        val_dataloader = data_loader_type(
            val_coco,
            batch_size=args.val_batch_size or args.batch_size,
            shuffle=False,
            sampler=None,
            num_workers=args.num_workers)
    else:
        val_dataloader = None

    ssd300 = SSD300(train_coco.labelnum, model_path=args.pretrained_backbone)
    if args.checkpoint is not None:
        print("loading model checkpoint", args.checkpoint)
        od = torch.load(args.checkpoint, map_location=torch.device('cpu'))
        ssd300.load_state_dict(od["model"])
    ssd300.train()
    if use_cuda:
        ssd300.cuda()
    if use_hpu and is_hmp:
        with hmp.disable_casts():
            loss_func = Loss(dboxes, use_hpu=use_hpu, hpu_device=device)
    else:
        loss_func = Loss(dboxes, use_hpu=use_hpu, hpu_device=device)
    if use_cuda:
        loss_func.cuda()

    if use_hpu:
        ssd300.to(device)
        loss_func.to(device)

    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    global_batch_size = N_gpu * args.batch_size
    mllogger.event(key=mllog_const.GLOBAL_BATCH_SIZE, value=global_batch_size)
    # Reference doesn't support group batch norm, so bn_span==local_batch_size
    mllogger.event(key=mllog_const.MODEL_BN_SPAN, value=args.batch_size)
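    # Scale the base LR linearly with the global batch size (reference batch size of 32).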
    current_lr = args.lr * (global_batch_size / 32)

    assert args.batch_size % args.batch_splits == 0, "--batch-size must be divisible by --batch-splits"
    fragment_size = args.batch_size // args.batch_splits
    if args.batch_splits != 1:
        print("using gradient accumulation with fragments of size {}".format(
            fragment_size))

    current_momentum = 0.9
    sgd_optimizer = torch.optim.SGD
    if use_hpu and hpu_lazy_mode:
        from habana_frameworks.torch.hpex.optimizers import FusedSGD
        sgd_optimizer = FusedSGD
    optim = sgd_optimizer(ssd300.parameters(),
                          lr=current_lr,
                          momentum=current_momentum,
                          weight_decay=args.weight_decay)
    if use_hpu:
        permute_params(model=ssd300,
                       to_filters_last=True,
                       lazy_mode=hpu_lazy_mode)
        permute_momentum(optimizer=optim,
                         to_filters_last=True,
                         lazy_mode=hpu_lazy_mode)

    ssd_print(device=device,
              use_hpu=use_hpu,
              key=mllog_const.OPT_BASE_LR,
              value=current_lr)
    ssd_print(device=device,
              use_hpu=use_hpu,
              key=mllog_const.OPT_WEIGHT_DECAY,
              value=args.weight_decay)

    # parallelize
    if args.distributed:
        if use_hpu:
            ssd300 = torch.nn.parallel.DistributedDataParallel(
                ssd300,
                bucket_cap_mb=100,
                broadcast_buffers=False,
                gradient_as_bucket_view=True)
        else:
            ssd300 = DDP(ssd300)

    iter_num = args.iteration
    end_iter_num = args.end_iteration
    if end_iter_num:
        print("--end-iteration set to: {}".format(end_iter_num))
        assert end_iter_num > iter_num, "--end-iteration must have a value > --iteration"
    avg_loss = 0.0
    if use_hpu:
        loss_iter = list()
    inv_map = {v: k for k, v in val_coco.label_map.items()}
    success = torch.zeros(1)
    if use_cuda:
        success = success.cuda()
    if use_hpu:
        success = success.to(device)

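    # Translate the warmup setting into a number of warmup iterations from the
    # dataset size and global batch size, then wrap lr_warmup in a per-iteration step.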
    if args.warmup:
        nonempty_imgs = len(train_coco)
        wb = int(args.warmup * nonempty_imgs / (N_gpu * args.batch_size))
        ssd_print(device=device,
                  use_hpu=use_hpu,
                  key=mllog_const.OPT_LR_WARMUP_STEPS,
                  value=wb)
        warmup_step = lambda iter_num, current_lr: lr_warmup(
            optim, wb, iter_num, current_lr, args)
    else:
        warmup_step = lambda iter_num, current_lr: None

    ssd_print(device=device,
              use_hpu=use_hpu,
              key=mllog_const.OPT_LR_WARMUP_FACTOR,
              value=args.warmup_factor)
    ssd_print(device=device,
              use_hpu=use_hpu,
              key=mllog_const.OPT_LR_DECAY_BOUNDARY_EPOCHS,
              value=args.lr_decay_schedule)
    mllogger.start(key=mllog_const.BLOCK_START,
                   metadata={
                       mllog_const.FIRST_EPOCH_NUM: 1,
                       mllog_const.EPOCH_COUNT: args.epochs
                   })

    optim.zero_grad(set_to_none=True)
    if use_hpu:
        start = time.time()
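    # Main training loop: per-epoch LR decay, per-fragment gradient accumulation,
    # and periodic validation/checkpointing driven by args.val_epochs / args.val_interval.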
    for epoch in range(args.epochs):
        mllogger.start(key=mllog_const.EPOCH_START,
                       metadata={mllog_const.EPOCH_NUM: epoch})
        # set the epoch for the sampler
        if args.distributed:
            train_sampler.set_epoch(epoch)

        if epoch in args.lr_decay_schedule:
            current_lr *= 0.1
            print("")
            print("lr decay step #{num}".format(
                num=args.lr_decay_schedule.index(epoch) + 1))
            for param_group in optim.param_groups:
                param_group['lr'] = current_lr

        for nbatch, (img, img_id, img_size, bbox,
                     label) in enumerate(train_dataloader):
            current_batch_size = img.shape[0]
            # Split batch for gradient accumulation
            img = torch.split(img, fragment_size)
            bbox = torch.split(bbox, fragment_size)
            label = torch.split(label, fragment_size)

            for (fimg, fbbox, flabel) in zip(img, bbox, label):
                current_fragment_size = fimg.shape[0]
                if not use_hpu:
                    trans_bbox = fbbox.transpose(1, 2).contiguous()
                if use_cuda:
                    fimg = fimg.cuda()
                    trans_bbox = trans_bbox.cuda()
                    flabel = flabel.cuda()
                if use_hpu:
                    fimg = fimg.to(device)
                    if hpu_channels_last:
                        fimg = fimg.contiguous(
                            memory_format=torch.channels_last)
                        if hpu_lazy_mode:
                            mark_step()
                    if is_hmp:
                        with hmp.disable_casts():
                            #TODO revert after SW-58188 is fixed
                            trans_bbox = fbbox.to(device).transpose(
                                1, 2).contiguous()
                            flabel = flabel.to(device)
                    else:
                        #TODO revert after SW-58188 is fixed
                        trans_bbox = fbbox.to(device).transpose(
                            1, 2).contiguous()
                        flabel = flabel.to(device)
                fimg = Variable(fimg, requires_grad=True)
                if args.lowp:  # amorgenstern
                    import lowp
                    with lowp.Lowp(mode='BF16',
                                   warn_patched=True,
                                   warn_not_patched=True):
                        ploc, plabel = ssd300(fimg)
                        gloc, glabel = Variable(trans_bbox, requires_grad=False), \
                                Variable(flabel, requires_grad=False)
                        loss = loss_func(ploc, plabel, gloc, glabel)
                else:
                    ploc, plabel = ssd300(fimg)
                    if use_hpu and is_hmp:
                        with hmp.disable_casts():
                            gloc, glabel = Variable(trans_bbox, requires_grad=False), \
                                    Variable(flabel, requires_grad=False)
                            loss = loss_func(ploc.float(), plabel.float(),
                                             gloc, glabel)
                    else:
                        gloc, glabel = Variable(trans_bbox, requires_grad=False), \
                                Variable(flabel, requires_grad=False)
                        loss = loss_func(ploc, plabel, gloc, glabel)
                # weighted mean: scale by the fragment's share of the full batch
                loss = loss * (current_fragment_size / current_batch_size)
                if use_hpu and hpu_lazy_mode and args.distributed:
                    mark_step()
                loss.backward()
                if use_hpu and hpu_lazy_mode:
                    mark_step()

            warmup_step(iter_num, current_lr)
            if use_hpu and is_hmp:
                with hmp.disable_casts():
                    optim.step()
            else:
                optim.step()
            optim.zero_grad(set_to_none=True)
            if use_hpu:
                loss_iter.append(loss.clone().detach())
            else:
                if not np.isinf(loss.item()):
                    avg_loss = 0.999 * avg_loss + 0.001 * loss.item()
            if use_hpu and hpu_lazy_mode:
                mark_step()
            if use_hpu:
                if args.log_interval and not iter_num % args.log_interval:
                    cur_loss = 0.0
                    for x in loss_iter:
                        cur_loss = x.cpu().item()
                        if not np.isinf(cur_loss):
                            avg_loss = 0.999 * avg_loss + 0.001 * cur_loss
                    if args.rank == 0:
                        print("Rank: {:6d}, Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}"\
                            .format(args.rank, iter_num, cur_loss, avg_loss))
                    loss_iter = list()
            else:
                if args.rank == 0 and args.log_interval and not iter_num % args.log_interval:
                    print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}"\
                        .format(iter_num, loss.item(), avg_loss))
            iter_num += 1
            if use_hpu and iter_num == 50:
                start = time.time()
            if end_iter_num and iter_num >= end_iter_num:
                if use_hpu:
                    print("Training Ended, total time: {:.2f} s".format(
                        time.time() - start))
                break

        if (args.val_epochs and (epoch+1) in args.val_epochs) or \
           (args.val_interval and not (epoch+1) % args.val_interval):
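            # Average BatchNorm running statistics across ranks so every worker
            # evaluates with identical buffers.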
            if args.distributed:
                world_size = float(dist.get_world_size())
                for bn_name, bn_buf in ssd300.module.named_buffers(
                        recurse=True):
                    if ('running_mean' in bn_name) or ('running_var'
                                                       in bn_name):
                        dist.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                        bn_buf /= world_size
                        ssd_print(device=device,
                                  use_hpu=use_hpu,
                                  key=mllog_const.MODEL_BN_SPAN,
                                  value=bn_buf)
            if args.rank == 0:
                if use_hpu:
                    print("Training Ended, total time: {:.2f} s".format(
                        time.time() - start))
                if not args.no_save:
                    print("")
                    print("saving model...")
                    if use_hpu:
                        permute_params(model=ssd300,
                                       to_filters_last=False,
                                       lazy_mode=hpu_lazy_mode)
                        ssd300_copy = SSD300(
                            train_coco.labelnum,
                            model_path=args.pretrained_backbone)
                        if args.distributed:
                            ssd300_copy.load_state_dict(
                                ssd300.module.state_dict())
                        else:
                            ssd300_copy.load_state_dict(ssd300.state_dict())
                        torch.save(
                            {
                                "model": ssd300_copy.state_dict(),
                                "label_map": train_coco.label_info
                            }, "./models/iter_{}.pt".format(iter_num))
                        permute_params(model=ssd300,
                                       to_filters_last=True,
                                       lazy_mode=hpu_lazy_mode)
                    else:
                        torch.save(
                            {
                                "model": ssd300.state_dict(),
                                "label_map": train_coco.label_info
                            }, "./models/iter_{}.pt".format(iter_num))

                if coco_eval(ssd300,
                             val_dataloader,
                             cocoGt,
                             encoder,
                             inv_map,
                             args.threshold,
                             epoch + 1,
                             iter_num,
                             log_interval=args.log_interval,
                             use_cuda=use_cuda,
                             use_hpu=use_hpu,
                             hpu_device=device,
                             is_hmp=is_hmp,
                             hpu_channels_last=hpu_channels_last,
                             hpu_lazy_mode=hpu_lazy_mode,
                             nms_valid_thresh=args.nms_valid_thresh):
                    success = torch.ones(1)
                    if use_cuda:
                        success = success.cuda()
                    if use_hpu:
                        success = success.to(device)
            if args.distributed:
                dist.broadcast(success, 0)
            if success[0]:
                return True
        mllogger.end(key=mllog_const.EPOCH_STOP,
                     metadata={mllog_const.EPOCH_NUM: epoch})
    mllogger.end(key=mllog_const.BLOCK_STOP,
                 metadata={
                     mllog_const.FIRST_EPOCH_NUM: 1,
                     mllog_const.EPOCH_COUNT: args.epochs
                 })

    return False
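
The inner loop above accumulates gradients over batch fragments: each batch is split into pieces of size batch_size // batch_splits, each fragment's loss is scaled by its share of the full batch (the "weighted mean" above), backward() is called per fragment, and the optimizer steps once per batch. A minimal self-contained sketch of that pattern, using a stand-in nn.Linear model and MSELoss purely for illustration:

import torch
import torch.nn as nn

# Stand-in model, loss and data; only the accumulation pattern mirrors the
# training loop above.
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
optim = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

batch_size, batch_splits = 8, 4
fragment_size = batch_size // batch_splits  # batch_size must be divisible by batch_splits

x = torch.randn(batch_size, 10)
y = torch.randn(batch_size, 1)

optim.zero_grad(set_to_none=True)
for fx, fy in zip(torch.split(x, fragment_size), torch.split(y, fragment_size)):
    # Scale each fragment's mean loss by its share of the batch so the summed
    # gradients equal the gradient of the mean loss over the whole batch.
    loss = loss_fn(model(fx), fy) * (fx.shape[0] / batch_size)
    loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)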