Example #1
    def setup_environment(self) -> None:
        # This function is used to load Habana libraries required for PyTorch
        # to register HPU as one of the available devices.
        load_habana_module()

        os.environ["ID"] = str(self.local_rank)
        if self._process_group_backend == "hccl":
            # this env var is checked in the HCCL overrides to confirm the backend was initialized
            os.environ["HCCL_DISTRIBUTED_BACKEND"] = str(1)
        super().setup_environment()
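A minimal sketch (not part of the original strategy class) of what this hook amounts to at runtime for local_rank 0 with the "hccl" backend; load_habana_module is assumed to be the loader from habana_frameworks.torch.utils.library_loader used in the later examples:

import os
from habana_frameworks.torch.utils.library_loader import load_habana_module

load_habana_module()                          # registers "hpu" as an available PyTorch device
os.environ["ID"] = "0"                        # rank id consumed by the Habana overrides
os.environ["HCCL_DISTRIBUTED_BACKEND"] = "1"  # tells the overrides that the hccl backend is in use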
Example #2
    def __init__(
        self,
        device: _DEVICE = "hpu",
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        checkpoint_io: Optional[HPUCheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):

        if not _HPU_AVAILABLE:
            raise MisconfigurationException("`SingleHPUStrategy` requires HPU devices to run")

        # This function is used to load Habana libraries required for PyTorch
        # to register HPU as one of the available devices.
        load_habana_module()

        super().__init__(
            accelerator=accelerator,
            device=device,
            checkpoint_io=checkpoint_io or HPUCheckpointIO(),
            precision_plugin=precision_plugin,
        )
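A hedged usage sketch; the Trainer wiring below is an assumption about how this strategy is typically plugged into PyTorch Lightning and is not part of the original snippet:

import pytorch_lightning as pl

# Hypothetical wiring; requires an HPU device and the Habana plugins installed.
strategy = SingleHPUStrategy(device="hpu", checkpoint_io=HPUCheckpointIO())
trainer = pl.Trainer(strategy=strategy, devices=1)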
Example #3
def main():

    if args.dl_worker_type == "MP":
        try:
            # Default 'fork' doesn't work with synapse. Use 'forkserver' or 'spawn'
            torch.multiprocessing.set_start_method('spawn')
            # Workaround for multi-process data loading in single-card training; for multi-card training, use habana_torch_dataloader
        except RuntimeError:
            pass
    elif args.dl_worker_type == "HABANA":
        try:
            import habana_dataloader
        except ImportError:
            assert False, "Could Not import habana dataloader package"

    utils.init_distributed_mode(args)
    print(args)
    if args.enable_lazy:
        os.environ["PT_HPU_LAZY_MODE"]="1"
        import habana_frameworks.torch.core as htcore

    if args.is_hmp:
        from habana_frameworks.torch.hpex import hmp
        hmp.convert(opt_level=args.hmp_opt_level, bf16_file_path=args.hmp_bf16,
                    fp32_file_path=args.hmp_fp32, isVerbose=args.hmp_verbose)
    if args.device == 'hpu':
        from habana_frameworks.torch.utils.library_loader import load_habana_module
        load_habana_module()
        device = torch.device('hpu')
    elif args.device == 'gpu':
        if torch.cuda.is_available():
            device_name = "cuda:" + str(args.gpu)
            print(device_name)
            device = torch.device(device_name)
        else:
            assert False, "No GPU device"
    elif args.device == 'cpu':
        device = torch.device('cpu')
    else:
        assert False, "Need device type"
    print('Using', device)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
    # Data loading code
    traindir = os.path.join(args.data_path, 'train')
    valdir = os.path.join(args.data_path, 'val')

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
    train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    else:
        train_sampler = None
        val_sampler = None

    if args.device != 'gpu' and args.workers > 0:
        # patch torch cuda functions that are being unconditionally invoked
        # in the multiprocessing data loader
        torch.cuda.current_device = lambda: None
        torch.cuda.set_device = lambda x: None
    if args.dl_worker_type == "MP":
        data_loader_type = torch.utils.data.DataLoader
    elif args.dl_worker_type == "HABANA":
        data_loader_type = habana_dataloader.HabanaDataLoader
    train_loader = data_loader_type(
        train_dataset, batch_size=args.batch_size,
        shuffle=(args.dl_worker_type == "MP" and not args.distributed),
        num_workers=args.workers, pin_memory=(args.device != 'cpu'), sampler=train_sampler)
    val_loader = data_loader_type(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=(args.device != 'cpu'), sampler=val_sampler)


    global best_acc1
    # create model
    print("Creating model ", args.model)
    model = get_model(args.model, args.pretrained, args.resume, args.no_aux_logits)
    model.to(device)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if args.enable_lazy:
        from habana_frameworks.torch.hpex.optimizers import FusedSGD
        sgd_optimizer = FusedSGD
    else:
        sgd_optimizer = torch.optim.SGD

    optimizer = sgd_optimizer(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    if args.device=='gpu' and args.is_amp:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))

            checkpoint = torch.load(args.resume, map_location='cpu')
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.device == 'hpu':
        permute_params(model, True, args.enable_lazy)
        permute_momentum(optimizer, True, args.enable_lazy)

    model_for_eval = model

    model_without_ddp = model

    if args.distributed:
        if args.device == 'hpu':
            bucket_size_mb = 100
            model = torch.nn.parallel.DistributedDataParallel(model, bucket_cap_mb=bucket_size_mb, broadcast_buffers=False,
                    gradient_as_bucket_view=True)
        else:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_for_train = model

    if args.evaluate:
        validate(val_loader, model_for_eval, criterion, device, args)
        return

    epoch_time = AverageMeter('EpochTime', ':6.3f')
    start_time = time.time()
    e_time = start_time
    for epoch in range(args.start_epoch, args.epochs):
        #adjust_learning_rate(optimizer, epoch, args)
        # train for one epoch
        if args.distributed and args.dl_worker_type != "HABANA":
            train_sampler.set_epoch(epoch)

        print("training epoch ", epoch)
        train(train_loader, model_for_train, criterion, optimizer, epoch, device, args)
        lr_scheduler.step()

        # evaluate on validation set
        print("validating epoch ", epoch)
        acc1 = validate(val_loader, model_for_eval, criterion, device, args)

        # measure elapsed time
        epoch_time.update(time.time() - e_time)
        e_time = time.time()
        epoch_progress = ProgressMeter(
            len(range(args.start_epoch, args.epochs)),
            [epoch_time],
            prefix="END OF EPOCH [{}]:".format(epoch))
        epoch_progress.display(epoch - args.start_epoch + 1)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        if args.save_checkpoint > 0 and (epoch + 1) % args.save_checkpoint == 0:
            print("saving ckpt epoch ", epoch)
            if args.device == 'hpu':
                # Permute model parameters from RSCK to KCRS
                permute_params(model_without_ddp, False, args.enable_lazy)
                # Use this model only to copy the state_dict of the actual model
                # (equivalent to models.__dict__[args.model](pretrained=args.pretrained))
                copy_model = get_model(args.model, args.pretrained, args.resume, args.no_aux_logits)
                state_dict = model_without_ddp.state_dict()
                for k,v in state_dict.items():
                    if 'num_batches_tracked' in k and v.dim() == 1:
                        state_dict[k] = v.squeeze(0)

                copy_model.load_state_dict(state_dict)
                # Permute the weight momentum buffer before saving in checkpoint
                permute_momentum(optimizer, False, args.enable_lazy)

                # Bring all model parameters and optimizer parameters to CPU
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to('cpu')

            # Save model parameters in checkpoint
            filename = 'checkpoint_'+str(epoch)+'_'+args.device+'.pth.tar'
            save_checkpoint({
                'epoch': epoch,
                'arch': args.model,
                'state_dict': copy_model.state_dict() if args.device=='hpu' else model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
            }, is_best, filename)
            if args.device == 'hpu':
                # Move model and optimizer parameters back to the HPU
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to('hpu')
                # Permute back from KCRS to RSCK
                permute_params(model, True, args.enable_lazy)
                permute_momentum(optimizer, True, args.enable_lazy)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
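The checkpoint block above moves every optimizer state tensor to the CPU before torch.save and back to the HPU afterwards; a hedged sketch of that round trip in isolation (optimizer_state_to is a hypothetical helper name; optimizer is any torch.optim optimizer):

import torch

def optimizer_state_to(optimizer, device):
    # Move every tensor held in the optimizer state dict to the given device.
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)

# optimizer_state_to(optimizer, 'cpu')   # before torch.save(...)
# optimizer_state_to(optimizer, 'hpu')   # afterwards, so training can continue on the HPU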
Example #4
import torch
from habana_frameworks.torch.utils.library_loader import load_habana_module
load_habana_module()


def test_custom_div_op_function(custom_op_lib_path):
    torch.ops.load_library(custom_op_lib_path)
    print(torch.ops.custom_op.custom_topk)
    a_cpu = torch.rand((6, 6))
    a_hpu = a_cpu.to('hpu')
    a_topk_hpu, a_topk_indices_hpu = torch.ops.custom_op.custom_topk(
        a_hpu, 3, 1, False)
    a_topk_cpu, a_topk_indices_cpu = a_cpu.topk(3, 1)
    assert torch.equal(a_topk_hpu.detach().cpu(), a_topk_cpu.detach().cpu())
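Despite its name, the test above exercises the custom_topk op and checks it against the CPU topk result. A hedged invocation sketch; the shared-object path is purely hypothetical and depends on how the custom kernel was built:

# test_custom_div_op_function("build/lib/hpu_custom_topk.so")   # hypothetical .so path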
Example #5
def main(args, model_args):
    if args.dl_worker_type == "MP":
        try:
            # Default 'fork' doesn't work with synapse. Use 'forkserver' or 'spawn'
            torch.multiprocessing.set_start_method('spawn')
        except RuntimeError:
            pass
    elif args.dl_worker_type == "HABANA":
        try:
            import habana_dataloader
        except ImportError:
            assert False, "Could Not import habana dataloader package"

    #if args.apex:
    #    if sys.version_info < (3, 0):
    #        raise RuntimeError("Apex currently only supports Python 3. Aborting.")
    #    if amp is None:
    #        raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
    #                           "to enable mixed-precision training.")
    hb_utils.init_distributed_mode(args)
    if hasattr(args, "rank"):
        args.local_rank = args.rank
    print('####################### These are args: ######################')
    print(args)

    model_args.dl_worker_type = args.dl_worker_type
    model_args.world_size = args.world_size
    model_args.process_per_node = args.process_per_node
    model_args.distributed = args.distributed
    model_args.dist_url = args.dist_url

    args.is_master = False
    if args.local_rank in [-1, 0]:
        args.is_master = True
    model_args.is_master = args.is_master
    model_args.local_rank = args.local_rank
    print("############### local_rank is_master #############", model_args.local_rank, model_args.is_master)

    if args.lazy_mode:
        print('######### Lazy Mode ########')
        os.environ["PT_HPU_LAZY_MODE"] = "1"
    model_args.lazy_mode = args.lazy_mode


    if model_args.use_habana:
        from habana_frameworks.torch.utils.library_loader import load_habana_module
        load_habana_module()
        device = torch.device("hpu")
        args.n_gpu = 0
        print("########## HPU ##########")

    if not args.no_cuda:
        if torch.cuda.is_available():
            n_gpu = torch.cuda.device_count()
            if n_gpu > 1:
                torch.cuda.set_device(args.local_rank)
                device = torch.device("cuda", args.local_rank)
            else:
                device = torch.device("cuda")
            args.n_gpu = n_gpu
            print("########## GPU n_gpu ##########", args.n_gpu)
        else:
            device = torch.device("cpu")
            args.n_gpu = 0
            print("########## CPU ##########")

    model_args.device = device
    model_args.n_gpu = args.n_gpu

    #if args.deterministic:
    #    seed = args.seed
    #    random.seed(seed)
    #    if args.device == 'cuda':
    #        torch.cuda.manual_seed(seed)
    #else:
    #    seed = None


    train_df, eval_df = load_train_val_data()

    if str(model_args.device) == 'hpu' and model_args.workers > 0:
        # patch torch cuda functions that are being unconditionally invoked
        # in the multiprocessing data loader
        torch.cuda.current_device = lambda: None
        torch.cuda.set_device = lambda x: None

    model = Seq2SeqModel(
        encoder_decoder_type="bart",
        encoder_decoder_name="facebook/bart-base",
        args=model_args,
        use_cuda=args.n_gpu > 0,
        cuda_device=args.local_rank if args.n_gpu > 0 else -1,
    )

    start_time = time.time()

    model.train_model(train_df, eval_data=eval_df, output_dir=args.output_dir)

    ####################### prediction #######################
    if args.predict and args.local_rank in [-1, 0]:
        to_predict = [
            prefix + ": " + str(input_text)
            for prefix, input_text in zip(
                eval_df["prefix"].tolist(), eval_df["input_text"].tolist()
            )
        ]
        truth = eval_df["target_text"].tolist()

        print("Start testing")
        start_time = time.time()
        #
        preds = model.predict(to_predict)
        #
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Testing time {}'.format(total_time_str))

        os.makedirs(os.path.join(args.output_dir, "predictions"), exist_ok=True)
        pred_time = f"_{datetime.datetime.now()}"
        pred_text = os.path.join(args.output_dir, "predictions", "pred_text"+pred_time+".txt")

        with open(pred_text, "w") as f:
            for i, text in enumerate(eval_df["input_text"].tolist()):
                f.write(str(text) + "\n\n")

                f.write("Truth:\n")
                f.write(truth[i] + "\n\n")

                f.write("Prediction:\n")
                for pred in preds[i]:
                    f.write(str(pred) + "\n")
                f.write(
                    "________________________________________________________________________________\n"
                )

        results = model.compute_metrics(truth, preds)
        print('Prediction results:')
        print(results)

        pred_results = os.path.join(args.output_dir, "predictions", "pred_results"+pred_time+".csv")
        report = pd.DataFrame(results, index=[0])
        report.to_csv(pred_results, index=False)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Total time {}'.format(total_time_str))
    if args.lazy_mode:
        os.environ.pop("PT_HPU_LAZY_MODE")
Example #6
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=1,
                        metavar='N',
                        help='number of epochs to train (default: 1)')
    parser.add_argument('--lr',
                        type=float,
                        default=1.0,
                        metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run',
                        action='store_true',
                        default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    parser.add_argument('--hpu',
                        action='store_true',
                        default=False,
                        help='Use hpu device')
    parser.add_argument(
        '--use_lazy_mode',
        action='store_true',
        default=False,
        help='Enable lazy mode on hpu device, default eager mode')
    parser.add_argument('--hmp',
                        dest='is_hmp',
                        action='store_true',
                        help='enable hmp mode')
    parser.add_argument('--hmp-bf16',
                        default='ops_bf16_mnist.txt',
                        help='path to bf16 ops list in hmp O1 mode')
    parser.add_argument('--hmp-fp32',
                        default='ops_fp32_mnist.txt',
                        help='path to fp32 ops list in hmp O1 mode')
    parser.add_argument('--hmp-opt-level',
                        default='O1',
                        help='choose optimization level for hmp')
    parser.add_argument('--hmp-verbose',
                        action='store_true',
                        help='enable verbose mode for hmp')
    parser.add_argument('--dl-worker-type',
                        default='MP',
                        type=lambda x: x.upper(),
                        choices=["MT", "MP"],
                        help='select multithreading or multiprocessing')
    parser.add_argument('--world_size',
                        default=1,
                        type=int,
                        metavar='N',
                        help='number of total workers (default: 1)')
    parser.add_argument('--process-per-node',
                        default=8,
                        type=int,
                        metavar='N',
                        help='Number of process per node')

    parser.add_argument(
        '--distributed',
        action='store_true',
        help='whether to enable distributed mode and run on multiple devices')
    parser.add_argument('--dist-url',
                        default='env://',
                        help='url used to set up distributed training')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    torch.multiprocessing.set_start_method('spawn')
    if args.hpu:
        from habana_frameworks.torch.utils.library_loader import load_habana_module
        load_habana_module()
        device = torch.device("hpu")
        # patch torch cuda functions that are being unconditionally invoked
        # in the multiprocessing data loader
        torch.cuda.current_device = lambda: None
        torch.cuda.set_device = lambda x: None

    if args.use_lazy_mode:
        os.environ["PT_HPU_LAZY_MODE"] = "1"
        import habana_frameworks.torch.core as htcore

    if args.is_hmp:
        from habana_frameworks.torch.hpex import hmp
        hmp.convert(opt_level=args.hmp_opt_level,
                    bf16_file_path=args.hmp_bf16,
                    fp32_file_path=args.hmp_fp32,
                    isVerbose=args.hmp_verbose)

    utils.init_distributed_mode(args)

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1, 'pin_memory': True, 'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    dataset1 = datasets.MNIST('../data',
                              train=True,
                              download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False, transform=transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset1)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset2)

        train_loader = torch.utils.data.DataLoader(dataset1,
                                                   batch_size=args.batch_size,
                                                   sampler=train_sampler,
                                                   num_workers=12,
                                                   pin_memory=True,
                                                   drop_last=True)
        test_loader = torch.utils.data.DataLoader(dataset2,
                                                  batch_size=args.batch_size,
                                                  sampler=test_sampler,
                                                  num_workers=12,
                                                  pin_memory=True,
                                                  drop_last=True)
    else:
        train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
        test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    if args.hpu:
        permute_params(model, True, args.use_lazy_mode)
        permute_momentum(optimizer, True, args.use_lazy_mode)

    if args.distributed and args.hpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            bucket_cap_mb=100,
            broadcast_buffers=False,
            gradient_as_bucket_view=True)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        scheduler.step()

    if args.save_model:
        if args.hpu:
            torch.save(model.cpu().state_dict(), "mnist_cnn.pt")
        else:
            torch.save(model.state_dict(), "mnist_cnn.pt")
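Hedged launch examples for the script above (the file name mnist.py is an assumption; the flags are the ones defined in the parser):

# python mnist.py --hpu --use_lazy_mode --epochs 1
# python mnist.py --hpu --hmp --hmp-bf16 ops_bf16_mnist.txt --hmp-fp32 ops_fp32_mnist.txt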
Example #7
def main(cfg: FairseqConfig):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    start_time = time.time()
    total_translate_time = 0

    utils.import_user_module(cfg.common)

    if cfg.common.use_habana:
        from habana_frameworks.torch.utils.library_loader import load_habana_module
        load_habana_module()
        device = torch.device("hpu")
        os.environ["PT_HPU_POOL_STRATEGY"] = "0"

    if cfg.common.use_lazy_mode:
        try:
            import habana_frameworks.torch.core as htcore
            os.environ["PT_HPU_LAZY_MODE"] = "1"
        except ImportError:
            assert False, "Could Not import habana_frameworks.torch.core"
    else:
        if os.environ.get('PT_HPU_LAZY_MODE') is None:
            os.environ["PT_HPU_LAZY_MODE"] = "2"

    if cfg.interactive.buffer_size < 1:
        cfg.interactive.buffer_size = 1
    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.batch_size = 1

    assert (not cfg.generation.sampling
            or cfg.generation.nbest == cfg.generation.beam
            ), "--sampling requires --nbest to be equal to --beam"
    assert (not cfg.dataset.batch_size
            or cfg.dataset.batch_size <= cfg.interactive.buffer_size
            ), "--batch-size cannot be larger than --buffer-size"

    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(cfg.task)

    # Load ensemble
    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        if cfg.common.use_habana and cfg.common.bf16:
            model = model.to(device=device, dtype=torch.bfloat16)
        elif cfg.common.use_habana:
            model = model.to(device=device)

        model.prepare_for_inference_(cfg)

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models])

    if cfg.generation.constraints:
        logger.warning(
            "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
        )

    if cfg.interactive.buffer_size > 1:
        logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size)
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info("Type the input sentence and press return:")
    start_id = 0
    for inputs in buffered_read(cfg.interactive.input,
                                cfg.interactive.buffer_size):
        results = []
        for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }

            if cfg.common.use_habana:
                sample = utils.move_to_habana(sample, device=device)
                if constraints is not None:
                    constraints = constraints.to(device)

            translate_start_time = time.time()
            translations = task.inference_step(generator,
                                               models,
                                               sample,
                                               constraints=constraints)
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time
            list_constraints = [[] for _ in range(bsz)]
            if cfg.generation.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]
            for i, (id,
                    hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                constraints = list_constraints[i]
                results.append((
                    start_id + id,
                    src_tokens_i,
                    hypos,
                    {
                        "constraints": constraints,
                        "time": translate_time / len(translations),
                    },
                ))

        # sort output to match input order
        for id_, src_tokens, hypos, info in sorted(results,
                                                   key=lambda x: x[0]):
            src_str = ''
            if src_dict is not None:
                src_str = src_dict.string(src_tokens,
                                          cfg.common_eval.post_process)
                print("S-{}\t{}".format(id_, src_str))
                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
                for constraint in info["constraints"]:
                    print("C-{}\t{}".format(
                        id_,
                        tgt_dict.string(constraint,
                                        cfg.common_eval.post_process)))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), cfg.generation.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                score = hypo["score"] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print("H-{}\t{}\t{}".format(id_, score, hypo_str))
                # detokenized hypothesis
                print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str))
                print("P-{}\t{}".format(
                    id_,
                    " ".join(
                        map(
                            lambda x: "{:.4f}".format(x),
                            # convert from base e to base 2
                            hypo["positional_scores"].div_(math.log(2)
                                                           ).tolist(),
                        )),
                ))
                if cfg.generation.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print("A-{}\t{}".format(id_, alignment_str))

        # update running id_ counter
        start_id += len(inputs)

    logger.info("Total time: {:.3f} seconds; translation time: {:.3f}".format(
        time.time() - start_time, total_translate_time))
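A hedged sketch of the bf16-on-HPU cast applied to the model ensemble above, shown on a toy module (the Linear layer is a stand-in and not part of the original; an HPU and a prior load_habana_module() call are assumed):

import torch
import torch.nn as nn

device = torch.device("hpu")
model = nn.Linear(8, 8).to(device=device, dtype=torch.bfloat16)  # cast weights to bf16 on the HPU
x = torch.randn(1, 8).to(device=device, dtype=torch.bfloat16)
y = model(x)                                                     # matmul runs in bf16 on the HPU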
Example #8
def main(args):

    if args.dl_worker_type == "MP":
        try:
            # Default 'fork' doesn't work with synapse. Use 'forkserver' or 'spawn'
            torch.multiprocessing.set_start_method('spawn')
        except RuntimeError:
            pass
    elif args.dl_worker_type == "HABANA":
        try:
            import habana_dataloader
        except ImportError:
            assert False, "Could Not import habana dataloader package"

    if args.run_lazy_mode:
        os.environ["PT_HPU_LAZY_MODE"] = "1"
    if args.is_hmp:
        from habana_frameworks.torch.hpex import hmp
        hmp.convert(opt_level=args.hmp_opt_level,
                    bf16_file_path=args.hmp_bf16,
                    fp32_file_path=args.hmp_fp32,
                    isVerbose=args.hmp_verbose)

    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError(
                "Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    if args.device == 'hpu':
        from habana_frameworks.torch.utils.library_loader import load_habana_module
        load_habana_module()

    torch.manual_seed(args.seed)

    if args.deterministic:
        seed = args.seed
        random.seed(seed)
        if args.device == 'cuda':
            torch.cuda.manual_seed(seed)
    else:
        seed = None

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')
    dataset, dataset_test, train_sampler, test_sampler = load_data(
        train_dir, val_dir, args.cache_dataset, args.distributed)
    if args.device == 'hpu' and args.workers > 0:
        # patch torch cuda functions that are being unconditionally invoked
        # in the multiprocessing data loader
        torch.cuda.current_device = lambda: None
        torch.cuda.set_device = lambda x: None

    if args.dl_worker_type == "MP":
        data_loader_type = torch.utils.data.DataLoader
    elif args.dl_worker_type == "HABANA":
        data_loader_type = habana_dataloader.HabanaDataLoader

    data_loader = data_loader_type(dataset,
                                   batch_size=args.batch_size,
                                   sampler=train_sampler,
                                   num_workers=args.workers,
                                   pin_memory=True)

    data_loader_test = data_loader_type(dataset_test,
                                        batch_size=args.batch_size,
                                        sampler=test_sampler,
                                        num_workers=args.workers,
                                        pin_memory=True)

    print("Creating model")
    # Import only resnext101_32x4d from a local copy since the torchvision
    # package doesn't support the resnext101_32x4d variant
    if 'resnext101_32x4d' in args.model:
        model = resnet_models.__dict__[args.model](pretrained=args.pretrained)
    else:
        model = torchvision.models.__dict__[args.model](
            pretrained=args.pretrained)
    model.to(device)
    if args.channels_last:
        if (device == torch.device('cuda')):
            print('Converting model to channels_last format on CUDA')
            model.to(memory_format=torch.channels_last)
        elif (args.device == 'hpu'):
            print('Converting model params to channels_last format on Habana')
            # TODO:
            # model.to(device).to(memory_format=torch.channels_last)
            # The above model conversion doesn't change the model params
            # to channels_last for many components - e.g. convolution.
            # So we are forced to rearrange such tensors ourselves.

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    if args.run_lazy_mode:
        from habana_frameworks.torch.hpex.optimizers import FusedSGD
        sgd_optimizer = FusedSGD
    else:
        sgd_optimizer = torch.optim.SGD
    optimizer = sgd_optimizer(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    if (args.device == 'hpu'):
        permute_params(model, True, args.run_lazy_mode)
        permute_momentum(optimizer, True, args.run_lazy_mode)

    if args.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.apex_opt_level)
    if args.custom_lr_values is not None:
        lr_vec = lr_vec_fcn([args.lr] + args.custom_lr_values,
                            [0] + args.custom_lr_milestones + [args.epochs])
        lr_scheduler = None
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    model_for_eval = model

    # TBD: pass the right module for ddp
    model_without_ddp = model

    if args.distributed:
        if args.device == 'hpu':
            # To improve resnext101 dist performance, decrease number of all_reduce calls to 1 by increasing bucket size to 200
            bucket_size_mb = 200 if 'resnext101' in args.model else 100
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                bucket_cap_mb=bucket_size_mb,
                broadcast_buffers=False,
                gradient_as_bucket_view=True)

        else:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_for_train = model

    if args.resume:
        if (args.device == 'hpu'):
            permute_params(model_without_ddp, False, args.run_lazy_mode)
            permute_momentum(optimizer, False, args.run_lazy_mode)

        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

        # Permute the weight momentum buffer back after loading it from the checkpoint
        if (args.device == 'hpu'):
            permute_momentum(optimizer, True, args.run_lazy_mode)

        args.start_epoch = checkpoint['epoch'] + 1
        if (args.device == 'hpu'):
            permute_params(model_without_ddp, True, args.run_lazy_mode)

    if args.test_only:
        evaluate(model_for_eval,
                 criterion,
                 data_loader_test,
                 device=device,
                 print_freq=args.print_freq)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        # Setting epoch is done by Habana dataloader internally
        if args.distributed and args.dl_worker_type != "HABANA":
            train_sampler.set_epoch(epoch)

        if lr_scheduler is None:
            adjust_learning_rate(optimizer, epoch, lr_vec)

        train_one_epoch(model_for_train,
                        criterion,
                        optimizer,
                        data_loader,
                        device,
                        epoch,
                        print_freq=args.print_freq,
                        apex=args.apex)
        if lr_scheduler is not None:
            lr_scheduler.step()
        evaluate(model_for_eval,
                 criterion,
                 data_loader_test,
                 device=device,
                 print_freq=args.print_freq)

        if (args.output_dir and args.save_checkpoint):
            if args.device == 'hpu':
                permute_params(model_without_ddp, False, args.run_lazy_mode)
                # Use this model only to copy the state_dict of the actual model
                copy_model = resnet_models.__dict__[args.model](
                    pretrained=args.pretrained
                ) if 'resnext101_32x4d' in args.model else torchvision.models.__dict__[
                    args.model](pretrained=args.pretrained)

                copy_model.load_state_dict(model_without_ddp.state_dict())
                # Permute the weight momentum buffer before saving in checkpoint
                permute_momentum(optimizer, False, args.run_lazy_mode)

                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to('cpu')

                checkpoint = {
                    'model': copy_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': None if lr_scheduler is None else lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args
                }
                utils.save_on_master(
                    checkpoint,
                    os.path.join(args.output_dir,
                                 'model_{}.pth'.format(epoch)))
                utils.save_on_master(
                    checkpoint, os.path.join(args.output_dir,
                                             'checkpoint.pth'))

                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to('hpu')
                permute_params(model_without_ddp, True, args.run_lazy_mode)
                permute_momentum(optimizer, True, args.run_lazy_mode)

            else:
                checkpoint = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': None if lr_scheduler is None else lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args
                }
                utils.save_on_master(
                    checkpoint,
                    os.path.join(args.output_dir,
                                 'model_{}.pth'.format(epoch)))
                utils.save_on_master(
                    checkpoint, os.path.join(args.output_dir,
                                             'checkpoint.pth'))

    if args.run_lazy_mode:
        os.environ.pop("PT_HPU_LAZY_MODE")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
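A hedged, isolated sketch of the HPU DistributedDataParallel wrapping used above; it assumes model is an nn.Module already placed on the hpu device and that torch.distributed was initialized with the hccl backend:

import torch

model = torch.nn.parallel.DistributedDataParallel(
    model,
    bucket_cap_mb=100,               # 200 is used above for resnext101 to reduce all_reduce calls
    broadcast_buffers=False,
    gradient_as_bucket_view=True)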
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " +
        ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help=
        "Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written.",
    )

    # Distillation parameters (optional)
    parser.add_argument(
        "--teacher_type",
        default=None,
        type=str,
        help=
        "Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for distillation.",
    )
    parser.add_argument(
        "--teacher_name_or_path",
        default=None,
        type=str,
        help=
        "Path to the already SQuAD fine-tuned teacher model. Only for distillation.",
    )
    parser.add_argument(
        "--alpha_ce",
        default=0.5,
        type=float,
        help="Distillation loss linear weight. Only for distillation.")
    parser.add_argument(
        "--alpha_squad",
        default=0.5,
        type=float,
        help="True SQuAD loss linear weight. Only for distillation.")
    parser.add_argument(
        "--temperature",
        default=2.0,
        type=float,
        help="Distillation temperature. Only for distillation.")

    # Other parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help="The input data dir. Should contain the .json files for the task."
        +
        "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help=
        "The input training file. If a data dir is specified, will look for the file there"
        +
        "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help=
        "The input evaluation file. If a data dir is specified, will look for the file there"
        +
        "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from huggingface.co",
    )

    parser.add_argument(
        "--version_2_with_negative",
        action="store_true",
        help=
        "If true, the SQuAD examples contain some that do not have an answer.",
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help=
        "If null_score - best_non_null is greater than the threshold predict null.",
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded.",
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.",
    )
    parser.add_argument("--do_train",
                        action="store_true",
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action="store_true",
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training",
        action="store_true",
        help="Rul evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json output file.",
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.",
    )
    parser.add_argument(
        "--verbose_logging",
        action="store_true",
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.",
    )

    parser.add_argument("--logging_steps",
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument("--save_steps",
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda",
                        action="store_true",
                        help="Whether not to use CUDA when available")
    parser.add_argument("--overwrite_output_dir",
                        action="store_true",
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--world_size", type=int, default=1, help="world size")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--server_ip",
                        type=str,
                        default="",
                        help="Can be used for distant debugging.")
    parser.add_argument("--server_port",
                        type=str,
                        default="",
                        help="Can be used for distant debugging.")

    parser.add_argument(
        "--threads",
        type=int,
        default=1,
        help="multiple threads for converting example to features")
    parser.add_argument("--hpu", action="store_true", help="HPU run")
    parser.add_argument("--hmp", action="store_true", help="Enable HMP")
    parser.add_argument("--hmp_bf16",
                        type=str,
                        default='',
                        help="List of ops to be run in BF16 for HPU")
    parser.add_argument("--hmp_fp32",
                        type=str,
                        default='',
                        help="List of ops to be run in FP32 for HPU")
    parser.add_argument("--hmp_opt_level",
                        type=str,
                        default='O1',
                        help="Optimization level for HMP")
    parser.add_argument("--hmp_verbose",
                        action="store_true",
                        help="Optimization level for HMP")
    parser.add_argument('--use_lazy_mode',
                        action='store_true',
                        help='run model in lazy execution mode')
    parser.add_argument("--optimizer",
                        type=str,
                        default="AdamW",
                        help="type of optimizer.")

    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )

    if args.use_lazy_mode:
        os.environ["PT_HPU_LAZY_MODE"] = "1"
        os.environ["PT_HPU_LOWER_AS_STRIDED"] = "1"

    if (os.path.exists(args.output_dir) and os.listdir(args.output_dir)
            and args.do_train and not args.overwrite_output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.hpu:
        from habana_frameworks.torch.utils.library_loader import load_habana_module
        load_habana_module()
        os.environ["MAX_WAIT_ATTEMPTS"] = "50"
        device = torch.device("hpu")
        args.n_gpu = 1
        world_size = 1
        rank = -1
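        # Resolve world size / rank from torch.distributed.launch env vars,
        # OpenMPI env vars, or mpi4py as a fallback; otherwise run single-node.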
        if ('WORLD_SIZE' in os.environ and 'RANK' in os.environ):
            world_size = int(os.environ["WORLD_SIZE"])
            rank = int(os.environ["RANK"])
            if 'LOCAL_RANK' in os.environ:
                args.local_rank = int(os.environ["LOCAL_RANK"])
            if args.local_rank in [-1, 0]:
                logger.info("Torch distributed launch used")
        elif ('OMPI_COMM_WORLD_LOCAL_RANK' in os.environ
              and 'OMPI_COMM_WORLD_SIZE' in os.environ
              and 'OMPI_COMM_WORLD_RANK' in os.environ):
            args.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
            world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
            rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
            if args.local_rank in [-1, 0]:
                logger.info("MPI environment variables set")
        else:
            try:
                global mpi_comm
                from mpi4py import MPI
                mpi_comm = MPI.COMM_WORLD
                world_size = mpi_comm.Get_size()
                if world_size > 1:
                    rank = mpi_comm.Get_rank()
                    args.local_rank = rank
                else:
                    raise ("Single MPI process")
            except Exception as e:
                if args.local_rank in [-1, 0]:
                    logger.info("Single node run")

        if args.local_rank != -1:
            try:
                import habana_frameworks.torch.core.hccl
            except ImportError as e:
                raise ImportError(
                    "Could not import habana_frameworks.torch.core.hccl") from e
            os.environ["ID"] = str(args.local_rank)
            torch.distributed.init_process_group(backend="hccl",
                                                 rank=args.local_rank,
                                                 world_size=world_size)
            if args.local_rank in [-1, 0]:
                logger.info("Enable distributed run")
        if args.use_lazy_mode:
            if args.local_rank in [-1, 0]:
                logger.info("Enable habana lazy mode")

        if args.hmp:
            from habana_frameworks.torch.hpex import hmp
            hmp.convert(opt_level=args.hmp_opt_level,
                        bf16_file_path=args.hmp_bf16,
                        fp32_file_path=args.hmp_fp32,
                        isVerbose=args.hmp_verbose)
    elif args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
        logger.info("GPU enabled")

    args.device = device

    if args.local_rank in [-1, 0]:
        logger.warning(
            "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
            args.local_rank,
            device,
            args.n_gpu,
            bool(args.local_rank != -1),
            args.fp16,
        )
    # Set the Transformers logger verbosity to INFO (on the main process only):
    if is_main_process(args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
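    # have the model return plain tuples instead of ModelOutput objects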
    config.return_dict = False
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name
        if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if args.teacher_type is not None:
        assert args.teacher_name_or_path is not None
        assert args.alpha_ce > 0.0
        assert args.alpha_ce + args.alpha_squad > 0.0
        assert args.teacher_type != "distilbert", "We constrain teachers not to be of type DistilBERT."
        teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[
            args.teacher_type]
        teacher_config = teacher_config_class.from_pretrained(
            args.teacher_name_or_path,
            cache_dir=args.cache_dir if args.cache_dir else None)
        teacher_config.return_dict = False
        teacher = teacher_model_class.from_pretrained(
            args.teacher_name_or_path,
            config=teacher_config,
            cache_dir=args.cache_dir if args.cache_dir else None)
        teacher.to(args.device)
    else:
        teacher = None

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)

    if args.local_rank in [-1, 0]:
        logger.info("Training/evaluation parameters %s", args)

    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex

            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args,
                                                tokenizer,
                                                evaluate=False,
                                                output_examples=False)
        global_step, tr_loss = train(args,
                                     train_dataset,
                                     model,
                                     tokenizer,
                                     teacher=teacher)
        if args.local_rank in [-1, 0]:
            logger.info(" global_step = %s, average loss = %s", global_step,
                        tr_loss)

    # Save the trained model and the tokenizer
    if args.do_train and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):
        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if args.hpu:
            model_to_save = (model.module.to("cpu") if hasattr(
                model, "module") else model.to("cpu"))
        else:
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(
            args.output_dir, do_lower_case=args.do_lower_case)
        model.to(args.device)

    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        if args.do_train:
            logger.info(
                "Loading checkpoints saved during training for evaluation")
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(
                    glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME,
                              recursive=True)))

        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            # Reload the model
            global_step = checkpoint.split(
                "-")[-1] if len(checkpoints) > 1 else ""
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)

            # Evaluate
            result = evaluate(args, model, tokenizer, prefix=global_step)

            result = dict(
                (k + ("_{}".format(global_step) if global_step else ""), v)
                for k, v in result.items())
            results.update(result)

    logger.info("Results: {}".format(results))

    return results
Exemplo n.º 10
0
def train300_mlperf_coco(args):
    global torch
    from coco import COCO
    # Check that GPUs are actually available
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    args.distributed = False
    if args.use_hpu:
        if 'WORLD_SIZE' in os.environ:
            args.distributed = int(os.environ['WORLD_SIZE']) > 1
            args.world_size = int(os.environ['WORLD_SIZE'])
            print("world_size = {}".format(args.world_size))
            print("distributed={}".format(args.distributed))
    if use_cuda:
        try:
            from apex.parallel import DistributedDataParallel as DDP
            if 'WORLD_SIZE' in os.environ:
                args.distributed = int(os.environ['WORLD_SIZE']) > 1
        except ImportError as e:
            raise ImportError(
                "Please install APEX from https://github.com/nvidia/apex") from e

    use_hpu = args.use_hpu
    hpu_channels_last = args.hpu_channels_last
    hpu_lazy_mode = args.hpu_lazy_mode
    is_hmp = args.is_hmp
    device = torch.device('cpu')
    data_loader_type = DataLoader
    if use_hpu:
        device = torch.device('hpu')
        if args.distributed:
            os.environ["MAX_WAIT_ATTEMPTS"] = "90"
        if hpu_lazy_mode:
            os.environ["PT_HPU_LAZY_MODE"] = "1"
        else:
            os.environ["PT_HPU_LAZY_MODE"] = "2"
        if is_hmp:
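            # Habana Mixed Precision: ops listed in the BF16 file run in
            # bfloat16, ops listed in the FP32 file are kept in float32.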
            if not args.hmp_bf16:
                raise IOError("Please provide list of BF16 ops")
            if not args.hmp_fp32:
                raise IOError("Please provide list of FP32 ops")
            from habana_frameworks.torch.hpex import hmp
            hmp.convert(opt_level=args.hmp_opt_level,
                        bf16_file_path=args.hmp_bf16,
                        fp32_file_path=args.hmp_fp32,
                        isVerbose=args.hmp_verbose)
        from habana_frameworks.torch.utils.library_loader import load_habana_module
        load_habana_module()
        # TODO - add dataloader

    local_seed = args.seed
    if args.distributed:
        # necessary pytorch imports
        import torch.utils.data.distributed
        import torch.distributed as dist
        if use_hpu:
            args.dist_backend = 'hccl'
            import habana_frameworks.torch.core.hccl
            os.environ["ID"] = os.environ["RANK"]
            dist.init_process_group(args.dist_backend, init_method='env://')
            # set seeds properly
            args.seed = broadcast_seeds(args.seed, device, use_hpu=True)
            local_seed = (args.seed + dist.get_rank()) % 2**32
        elif args.no_cuda:
            device = torch.device('cpu')
        else:
            torch.cuda.set_device(args.local_rank)
            device = torch.device('cuda')
            dist.init_process_group(backend='nccl', init_method='env://')
            # set seeds properly
            args.seed = broadcast_seeds(args.seed, device)
            local_seed = (args.seed + dist.get_rank()) % 2**32
    mllogger.event(key=mllog_const.SEED, value=local_seed)
    torch.manual_seed(local_seed)
    np.random.seed(seed=local_seed)
    random.seed(local_seed)  # amorgenstern
    torch.cuda.manual_seed(local_seed)  # amorgenstern

    args.rank = dist.get_rank() if args.distributed else args.local_rank
    print("args.rank = {}".format(args.rank))
    print("local rank = {}".format(args.local_rank))
    print("distributed={}".format(args.distributed))

    if use_hpu and is_hmp:
        with hmp.disable_casts():
            dboxes = dboxes300_coco()
            encoder = Encoder(dboxes)
    else:
        dboxes = dboxes300_coco()
        encoder = Encoder(dboxes)

    input_size = 300
    if use_hpu and is_hmp:
        with hmp.disable_casts():
            train_trans = SSDTransformer(
                dboxes, (input_size, input_size),
                val=False,
                num_cropping_iterations=args.num_cropping_iterations)
            val_trans = SSDTransformer(dboxes, (input_size, input_size),
                                       val=True)
    else:
        train_trans = SSDTransformer(
            dboxes, (input_size, input_size),
            val=False,
            num_cropping_iterations=args.num_cropping_iterations)
        val_trans = SSDTransformer(dboxes, (input_size, input_size), val=True)

    val_annotate = os.path.join(args.data,
                                "annotations/instances_val2017.json")
    val_coco_root = os.path.join(args.data, "val2017")
    train_annotate = os.path.join(args.data,
                                  "annotations/instances_train2017.json")
    train_coco_root = os.path.join(args.data, "train2017")

    if use_hpu and is_hmp:
        with hmp.disable_casts():
            cocoGt = COCO(annotation_file=val_annotate)
            train_coco = COCODetection(train_coco_root, train_annotate,
                                       train_trans)
            val_coco = COCODetection(val_coco_root, val_annotate, val_trans)
    else:
        cocoGt = COCO(annotation_file=val_annotate)
        train_coco = COCODetection(train_coco_root, train_annotate,
                                   train_trans)
        val_coco = COCODetection(val_coco_root, val_annotate, val_trans)
    mllogger.event(key=mllog_const.TRAIN_SAMPLES, value=len(train_coco))
    mllogger.event(key=mllog_const.EVAL_SAMPLES, value=len(val_coco))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_coco)
    else:
        train_sampler = None
    if use_hpu:
        # patch torch cuda functions that are being unconditionally invoked
        # in the multiprocessing data loader
        torch.cuda.current_device = lambda: None
        torch.cuda.set_device = lambda x: None
    train_dataloader = data_loader_type(train_coco,
                                        batch_size=args.batch_size,
                                        shuffle=(train_sampler is None),
                                        sampler=train_sampler,
                                        num_workers=args.num_workers)
    # shuffle is enabled above only when no DistributedSampler is used
    if args.rank == 0:
        val_dataloader = data_loader_type(val_coco,
                                          batch_size=args.val_batch_size
                                          or args.batch_size,
                                          shuffle=False,
                                          sampler=None,
                                          num_workers=args.num_workers)
    else:
        val_dataloader = None

    ssd300 = SSD300(train_coco.labelnum, model_path=args.pretrained_backbone)
    if args.checkpoint is not None:
        print("loading model checkpoint", args.checkpoint)
        od = torch.load(args.checkpoint, map_location=torch.device('cpu'))
        ssd300.load_state_dict(od["model"])
    ssd300.train()
    if use_cuda:
        ssd300.cuda()
    if use_hpu and is_hmp:
        with hmp.disable_casts():
            loss_func = Loss(dboxes, use_hpu=use_hpu, hpu_device=device)
    else:
        loss_func = Loss(dboxes, use_hpu=use_hpu, hpu_device=device)
    if use_cuda:
        loss_func.cuda()

    if use_hpu:
        ssd300.to(device)
        loss_func.to(device)

    if args.distributed:
        N_gpu = torch.distributed.get_world_size()
    else:
        N_gpu = 1

    global_batch_size = N_gpu * args.batch_size
    mllogger.event(key=mllog_const.GLOBAL_BATCH_SIZE, value=global_batch_size)
    # Reference doesn't support group batch norm, so bn_span==local_batch_size
    mllogger.event(key=mllog_const.MODEL_BN_SPAN, value=args.batch_size)
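    # linear LR scaling: args.lr is defined for a reference global batch size of 32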
    current_lr = args.lr * (global_batch_size / 32)

    assert args.batch_size % args.batch_splits == 0, "--batch-size must be divisible by --batch-splits"
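    # gradient accumulation: each batch is split into --batch-splits fragments,
    # each forwarded/backwarded separately before a single optimizer step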
    fragment_size = args.batch_size // args.batch_splits
    if args.batch_splits != 1:
        print("using gradient accumulation with fragments of size {}".format(
            fragment_size))

    current_momentum = 0.9
    sgd_optimizer = torch.optim.SGD
    if use_hpu and hpu_lazy_mode:
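        # use Habana's fused SGD implementation in lazy mode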
        from habana_frameworks.torch.hpex.optimizers import FusedSGD
        sgd_optimizer = FusedSGD
    optim = sgd_optimizer(ssd300.parameters(),
                          lr=current_lr,
                          momentum=current_momentum,
                          weight_decay=args.weight_decay)
    if use_hpu:
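        # convert conv weights and their momentum buffers to the filters-last
        # layout (to_filters_last=True) preferred on HPU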
        permute_params(model=ssd300,
                       to_filters_last=True,
                       lazy_mode=hpu_lazy_mode)
        permute_momentum(optimizer=optim,
                         to_filters_last=True,
                         lazy_mode=hpu_lazy_mode)

    ssd_print(device=device,
              use_hpu=use_hpu,
              key=mllog_const.OPT_BASE_LR,
              value=current_lr)
    ssd_print(device=device,
              use_hpu=use_hpu,
              key=mllog_const.OPT_WEIGHT_DECAY,
              value=args.weight_decay)

    # parallelize
    if args.distributed:
        if use_hpu:
            ssd300 = torch.nn.parallel.DistributedDataParallel(
                ssd300,
                bucket_cap_mb=100,
                broadcast_buffers=False,
                gradient_as_bucket_view=True)
        else:
            ssd300 = DDP(ssd300)

    iter_num = args.iteration
    end_iter_num = args.end_iteration
    if end_iter_num:
        print("--end-iteration set to: {}".format(end_iter_num))
        assert end_iter_num > iter_num, "--end-iteration must have a value > --iteration"
    avg_loss = 0.0
    if use_hpu:
        loss_iter = list()
    inv_map = {v: k for k, v in val_coco.label_map.items()}
    success = torch.zeros(1)
    if use_cuda:
        success = success.cuda()
    if use_hpu:
        success = success.to(device)

    if args.warmup:
        nonempty_imgs = len(train_coco)
        wb = int(args.warmup * nonempty_imgs / (N_gpu * args.batch_size))
        ssd_print(device=device,
                  use_hpu=use_hpu,
                  key=mllog_const.OPT_LR_WARMUP_STEPS,
                  value=wb)
        warmup_step = lambda iter_num, current_lr: lr_warmup(
            optim, wb, iter_num, current_lr, args)
    else:
        warmup_step = lambda iter_num, current_lr: None

    ssd_print(device=device,
              use_hpu=use_hpu,
              key=mllog_const.OPT_LR_WARMUP_FACTOR,
              value=args.warmup_factor)
    ssd_print(device=device,
              use_hpu=use_hpu,
              key=mllog_const.OPT_LR_DECAY_BOUNDARY_EPOCHS,
              value=args.lr_decay_schedule)
    mllogger.start(key=mllog_const.BLOCK_START,
                   metadata={
                       mllog_const.FIRST_EPOCH_NUM: 1,
                       mllog_const.EPOCH_COUNT: args.epochs
                   })

    optim.zero_grad(set_to_none=True)
    if use_hpu:
        start = time.time()
    for epoch in range(args.epochs):
        mllogger.start(key=mllog_const.EPOCH_START,
                       metadata={mllog_const.EPOCH_NUM: epoch})
        # set the epoch for the sampler
        if args.distributed:
            train_sampler.set_epoch(epoch)

        if epoch in args.lr_decay_schedule:
            current_lr *= 0.1
            print("")
            print("lr decay step #{num}".format(
                num=args.lr_decay_schedule.index(epoch) + 1))
            for param_group in optim.param_groups:
                param_group['lr'] = current_lr

        for nbatch, (img, img_id, img_size, bbox,
                     label) in enumerate(train_dataloader):
            current_batch_size = img.shape[0]
            # Split batch for gradient accumulation
            img = torch.split(img, fragment_size)
            bbox = torch.split(bbox, fragment_size)
            label = torch.split(label, fragment_size)

            for (fimg, fbbox, flabel) in zip(img, bbox, label):
                current_fragment_size = fimg.shape[0]
                if not use_hpu:
                    trans_bbox = fbbox.transpose(1, 2).contiguous()
                if use_cuda:
                    fimg = fimg.cuda()
                    trans_bbox = trans_bbox.cuda()
                    flabel = flabel.cuda()
                if use_hpu:
                    fimg = fimg.to(device)
                    if hpu_channels_last:
                        fimg = fimg.contiguous(
                            memory_format=torch.channels_last)
                        if hpu_lazy_mode:
                            mark_step()
                    if is_hmp:
                        with hmp.disable_casts():
                            #TODO revert after SW-58188 is fixed
                            trans_bbox = fbbox.to(device).transpose(
                                1, 2).contiguous()
                            flabel = flabel.to(device)
                    else:
                        #TODO revert after SW-58188 is fixed
                        trans_bbox = fbbox.to(device).transpose(
                            1, 2).contiguous()
                        flabel = flabel.to(device)
                fimg = Variable(fimg, requires_grad=True)
                if args.lowp:  # amorgenstern
                    import lowp
                    with lowp.Lowp(mode='BF16',
                                   warn_patched=True,
                                   warn_not_patched=True):
                        ploc, plabel = ssd300(fimg)
                        gloc, glabel = Variable(trans_bbox, requires_grad=False), \
                                Variable(flabel, requires_grad=False)
                        loss = loss_func(ploc, plabel, gloc, glabel)
                else:
                    ploc, plabel = ssd300(fimg)
                    if use_hpu and is_hmp:
                        with hmp.disable_casts():
                            gloc, glabel = Variable(trans_bbox, requires_grad=False), \
                                    Variable(flabel, requires_grad=False)
                            loss = loss_func(ploc.float(), plabel.float(),
                                             gloc, glabel)
                    else:
                        gloc, glabel = Variable(trans_bbox, requires_grad=False), \
                                Variable(flabel, requires_grad=False)
                        loss = loss_func(ploc, plabel, gloc, glabel)
                # weighted mean over the gradient-accumulation fragments
                loss = loss * (current_fragment_size / current_batch_size)
                if use_hpu and hpu_lazy_mode and args.distributed:
                    mark_step()
                loss.backward()
                if use_hpu and hpu_lazy_mode:
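                    # mark_step() triggers execution of the accumulated
                    # lazy-mode graph (here, the backward pass) on the HPU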
                    mark_step()

            warmup_step(iter_num, current_lr)
            if use_hpu and is_hmp:
                with hmp.disable_casts():
                    optim.step()
            else:
                optim.step()
            optim.zero_grad(set_to_none=True)
            if use_hpu:
                loss_iter.append(loss.clone().detach())
            else:
                if not np.isinf(loss.item()):
                    avg_loss = 0.999 * avg_loss + 0.001 * loss.item()
            if use_hpu and hpu_lazy_mode:
                mark_step()
            if use_hpu:
                if args.log_interval and not iter_num % args.log_interval:
                    cur_loss = 0.0
                    for x in loss_iter:
                        cur_loss = x.cpu().item()
                        if not np.isinf(cur_loss):
                            avg_loss = 0.999 * avg_loss + 0.001 * cur_loss
                    if args.rank == 0:
                        print("Rank: {:6d}, Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}"\
                            .format(args.rank, iter_num, cur_loss, avg_loss))
                    loss_iter = list()
            else:
                if args.rank == 0 and args.log_interval and not iter_num % args.log_interval:
                    print("Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}"\
                        .format(iter_num, loss.item(), avg_loss))
            iter_num += 1
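            # restart the timer at iteration 50, presumably to exclude warm-up
            # and graph-compilation overhead from the reported training time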
            if use_hpu and iter_num == 50:
                start = time.time()
            if end_iter_num and iter_num >= end_iter_num:
                if use_hpu:
                    print("Training Ended, total time: {:.2f} s".format(
                        time.time() - start))
                break

        if (args.val_epochs and (epoch+1) in args.val_epochs) or \
           (args.val_interval and not (epoch+1) % args.val_interval):
            if args.distributed:
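                # average BatchNorm running statistics across ranks before eval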
                world_size = float(dist.get_world_size())
                for bn_name, bn_buf in ssd300.module.named_buffers(
                        recurse=True):
                    if ('running_mean' in bn_name) or ('running_var'
                                                       in bn_name):
                        dist.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                        bn_buf /= world_size
                        ssd_print(device=device,
                                  use_hpu=use_hpu,
                                  key=mllog_const.MODEL_BN_SPAN,
                                  value=bn_buf)
            if args.rank == 0:
                if use_hpu:
                    print("Training Ended, total time: {:.2f} s".format(
                        time.time() - start))
                if not args.no_save:
                    print("")
                    print("saving model...")
                    if use_hpu:
                        permute_params(model=ssd300,
                                       to_filters_last=False,
                                       lazy_mode=hpu_lazy_mode)
                        ssd300_copy = SSD300(
                            train_coco.labelnum,
                            model_path=args.pretrained_backbone)
                        if args.distributed:
                            ssd300_copy.load_state_dict(
                                ssd300.module.state_dict())
                        else:
                            ssd300_copy.load_state_dict(ssd300.state_dict())
                        torch.save(
                            {
                                "model": ssd300_copy.state_dict(),
                                "label_map": train_coco.label_info
                            }, "./models/iter_{}.pt".format(iter_num))
                        permute_params(model=ssd300,
                                       to_filters_last=True,
                                       lazy_mode=hpu_lazy_mode)
                    else:
                        torch.save(
                            {
                                "model": ssd300.state_dict(),
                                "label_map": train_coco.label_info
                            }, "./models/iter_{}.pt".format(iter_num))

                if coco_eval(ssd300,
                             val_dataloader,
                             cocoGt,
                             encoder,
                             inv_map,
                             args.threshold,
                             epoch + 1,
                             iter_num,
                             log_interval=args.log_interval,
                             use_cuda=use_cuda,
                             use_hpu=use_hpu,
                             hpu_device=device,
                             is_hmp=is_hmp,
                             hpu_channels_last=hpu_channels_last,
                             hpu_lazy_mode=hpu_lazy_mode,
                             nms_valid_thresh=args.nms_valid_thresh):
                    success = torch.ones(1)
                    if use_cuda:
                        success = success.cuda()
                    if use_hpu:
                        success = success.to(device)
            if args.distributed:
                dist.broadcast(success, 0)
            if success[0]:
                return True
            mllogger.end(key=mllog_const.EPOCH_STOP,
                         metadata={mllog_const.EPOCH_NUM: epoch})
    mllogger.end(key=mllog_const.BLOCK_STOP,
                 metadata={
                     mllog_const.FIRST_EPOCH_NUM: 1,
                     mllog_const.EPOCH_COUNT: args.epochs
                 })

    return False
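

# ---------------------------------------------------------------------------
# Illustration only (not part of the scraped example above): a minimal sketch
# of the HPU lazy-mode training-step pattern these snippets rely on. It
# assumes a SynapseAI install where load_habana_module and htcore.mark_step
# are available under these module paths; exact paths can differ by release,
# and the function/tensor names below are hypothetical.
# ---------------------------------------------------------------------------
import os

import torch
import torch.nn as nn


def hpu_lazy_step_sketch():
    # enable lazy execution before any HPU work is issued
    os.environ["PT_HPU_LAZY_MODE"] = "1"
    from habana_frameworks.torch.utils.library_loader import load_habana_module
    import habana_frameworks.torch.core as htcore

    load_habana_module()  # load the Habana PyTorch integration libraries
    device = torch.device("hpu")

    model = nn.Linear(8, 2).to(device)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    data = torch.randn(4, 8).to(device)
    target = torch.randint(0, 2, (4,)).to(device)

    loss = nn.functional.cross_entropy(model(data), target)
    loss.backward()
    htcore.mark_step()  # flush the graph so the backward pass runs on the HPU
    optim.step()
    optim.zero_grad(set_to_none=True)
    htcore.mark_step()  # flush the optimizer update as well
    print("loss:", loss.item())  # .item() synchronizes with the device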