Example #1
    def forward(self, outputs, targets):
        outputs_without_aux = {
            k: v
            for k, v in outputs.items() if k != "aux_outputs"
        }
        indices = self.matcher(outputs_without_aux, targets)

        num_boxes = sum([len(t["labels"]) for t in targets])
        num_boxes = torch.as_tensor([num_boxes],
                                    dtype=torch.float,
                                    device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        losses = {}
        for loss in self.losses:
            losses.update(
                self.get_loss(loss, outputs, targets, indices, num_boxes))

        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == "masks":
                        continue
                    kwargs = {}
                    if loss == "labels":
                        kwargs = {"log": False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices,
                                           num_boxes, **kwargs)
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
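All of these examples lean on the same small set of distributed helpers (get_world_size, get_rank, is_dist_avail_and_initialized, is_main_process). As a point of reference, here is a minimal sketch of what such helpers typically look like with plain torch.distributed; this is an assumption about the utility module, not code taken from any of the projects above:

import torch.distributed as dist

def is_dist_avail_and_initialized():
    # torch.distributed may be unavailable (CPU-only build) or simply not initialized yet
    return dist.is_available() and dist.is_initialized()

def get_world_size():
    # Number of participating processes; 1 when running without a process group
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1

def get_rank():
    # Rank of the current process; 0 when running without a process group
    return dist.get_rank() if is_dist_avail_and_initialized() else 0

def is_main_process():
    return get_rank() == 0
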
Example #2
def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):

    global skipped_steps
    if args.allreduce_post_accumulation:
        # manually allreduce gradients after all accumulation steps
        # check for Inf/NaN
        # 1. allocate an uninitialized buffer for flattened gradient
        loss_scale = _amp_state.loss_scalers[0].loss_scale() if args.fp16 else 1
        master_grads = [p.grad for p in amp.master_params(optimizer) if p.grad is not None]
        flat_grad_size = sum(p.numel() for p in master_grads)
        allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else torch.float32
        flat_raw = torch.empty(flat_grad_size, device='cuda', dtype=allreduce_dtype)
        # 2. combine unflattening and predivision of unscaled 'raw' gradient
        allreduced_views = apex_C.unflatten(flat_raw, master_grads)
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(65536,
            overflow_buf,
            [master_grads, allreduced_views],
            loss_scale / (get_world_size() * args.gradient_accumulation_steps))
        # 3. sum gradient across ranks. Because of the predivision, this averages the gradient
        torch.distributed.all_reduce(flat_raw)
        # 4. combine unscaling and unflattening of allreduced gradient
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(65536,
            overflow_buf,
            [allreduced_views, master_grads],
            1./loss_scale)
        # 5. update loss scale
        if args.fp16:
            scaler = _amp_state.loss_scalers[0]
            old_overflow_buf = scaler._overflow_buf
            scaler._overflow_buf = overflow_buf
            had_overflow = scaler.update_scale()
            scaler._overflow_buf = old_overflow_buf
        else:
            had_overflow = 0
        # 6. call optimizer step function
        if had_overflow == 0:
            optimizer.step()
            global_step += 1
        else:
            # Overflow detected, print message and clear gradients
            skipped_steps += 1
            if is_main_process():
                scaler = _amp_state.loss_scalers[0]
                dllogger.log(step="PARAMETER", data={"loss_scale": scaler.loss_scale()})
            if _amp_state.opt_properties.master_weights:
                for param in optimizer._amp_stash.all_fp32_from_fp16_params:
                    param.grad = None
        for param in model.parameters():
            param.grad = None
    else:
        if args.apply_optimizer > 0:
            optimizer.step()
        # optimizer.zero_grad()
        for param in model.parameters():
            param.grad = None
        global_step += 1

    return global_step
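The flatten / pre-divide / all_reduce / unscale sequence above relies on apex's fused kernels (apex_C.unflatten, amp_C.multi_tensor_scale). As an illustrative equivalent only, and not the code this example actually runs, the same gradient averaging can be sketched with stock PyTorch utilities (assuming all gradients share one dtype):

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_and_average_grads(model, world_size):
    # Flatten all gradients into one buffer, average across ranks, copy back.
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    flat = _flatten_dense_tensors(grads)
    flat.div_(world_size)                  # pre-divide so the all_reduce sum becomes a mean
    dist.all_reduce(flat)                  # sum across ranks
    for grad, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
        grad.copy_(synced)
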
Example #3
def get_dataset_from_features(features, batch_size, drop_remainder=True, ngpu=8, mode="train", v2=False):
    """Input function for training"""

    all_input_ids = tf.convert_to_tensor([f.input_ids for f in features], dtype=tf.int64)
    all_input_mask = tf.convert_to_tensor([f.attention_mask for f in features], dtype=tf.int64)
    all_segment_ids = tf.convert_to_tensor([f.token_type_ids for f in features], dtype=tf.int64)
    all_start_pos = tf.convert_to_tensor([f.start_position for f in features], dtype=tf.int64)
    all_end_pos = tf.convert_to_tensor([f.end_position for f in features], dtype=tf.int64)

    # cls_index, p_mask and is_impossible are only meaningful for SQuAD v2 (the v2 flag), but are built unconditionally here
    all_cls_index = tf.convert_to_tensor([f.cls_index for f in features], dtype=tf.int64)
    all_p_mask = tf.convert_to_tensor([f.p_mask for f in features], dtype=tf.float32)
    all_is_impossible = tf.convert_to_tensor([f.is_impossible for f in features], dtype=tf.float32)

    dataset = tf.data.Dataset.from_tensor_slices(
        (all_input_ids, all_input_mask, all_segment_ids, all_start_pos,
         all_end_pos, all_cls_index, all_p_mask, all_is_impossible))
    if ngpu > 1:
        dataset = dataset.shard(get_world_size(), get_rank())

    if mode == "train":
        dataset = dataset.shuffle(batch_size * 3)
    # dataset = dataset.map(self._preproc_samples,
    #                      num_parallel_calls=multiprocessing.cpu_count()//self._num_gpus)
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    dataset = dataset.prefetch(batch_size)

    return dataset
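A hedged usage sketch for the function above. The _Feature dataclass is a hypothetical stand-in for the SQuAD feature objects the caller would normally provide; its field names are inferred from the attribute accesses in the example:

from dataclasses import dataclass
from typing import List

@dataclass
class _Feature:  # hypothetical stand-in for the real feature objects
    input_ids: List[int]
    attention_mask: List[int]
    token_type_ids: List[int]
    start_position: int
    end_position: int
    cls_index: int
    p_mask: List[float]
    is_impossible: float

seq_len = 16
features = [_Feature([1] * seq_len, [1] * seq_len, [0] * seq_len, 0, 1, 0,
                     [0.0] * seq_len, 0.0) for _ in range(8)]
dataset = get_dataset_from_features(features, batch_size=4, ngpu=1, mode="train")
for batch in dataset.take(1):
    print([t.shape for t in batch])  # eight tensors, each with a leading batch dim of 4
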
Example #4
def take_optimizer_step(args, optimizer, grad_scaler, model, overflow_buf,
                        global_step):

    global skipped_steps
    if args.allreduce_post_accumulation:
        # manually allreduce gradients after all accumulation steps
        # check for Inf/NaN
        # 1. allocate an uninitialized buffer for flattened gradient
        loss_scale = grad_scaler._get_scale_async() if args.fp16 else 1.
        master_grads = [
            p.grad for p in model.parameters() if p.grad is not None
        ]
        flat_grad_size = sum(p.numel() for p in master_grads)
        allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else torch.float32
        flat_raw = torch.empty(flat_grad_size,
                               device='cuda',
                               dtype=allreduce_dtype)
        # 2. combine unflattening and predivision of unscaled 'raw' gradient
        allreduced_views = apex_C.unflatten(flat_raw, master_grads)
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(
            65536, overflow_buf, [master_grads, allreduced_views],
            loss_scale / (get_world_size() * args.gradient_accumulation_steps))
        # 3. sum gradient across ranks. Because of the predivision, this averages the gradient
        torch.distributed.all_reduce(flat_raw)
        # 4. combine unscaling and unflattening of allreduced gradient
        overflow_buf.zero_()
        amp_C.multi_tensor_scale(65536, overflow_buf,
                                 [allreduced_views, master_grads], 1.)
        # 5. update loss scale
        if args.fp16:
            had_overflow = overflow_buf.item()
        else:
            had_overflow = 0
        # 6. call optimizer step function
        if had_overflow == 0:
            global_step += 1
        else:
            # Overflow detected, print message and clear gradients
            skipped_steps += 1
    else:
        global_step += 1
    grad_scaler.step(optimizer)
    grad_scaler.update()
    optimizer.zero_grad(set_to_none=True)

    return global_step
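This variant defers overflow handling to the native GradScaler: grad_scaler.step(optimizer) internally skips the update when inf/NaN gradients are found, and grad_scaler.update() then adjusts the loss scale. A minimal single-GPU sketch of that contract (standard PyTorch AMP usage, independent of the example above):

import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

inputs = torch.randn(4, 8, device='cuda')
optimizer.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast():
    loss = model(inputs).float().mean()
scaler.scale(loss).backward()
scaler.step(optimizer)   # skipped internally if inf/NaN gradients were detected
scaler.update()          # loss scale grows after stable steps, shrinks after overflow
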
Example #5
def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses
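A typical call site, sketched under the assumption of the usual detection training loop (the loss values here are illustrative; the local, unreduced losses are what you backpropagate, while the reduced copy is for logging only):

import torch
import torch.distributed as dist

loss_dict = {"loss_cls": torch.tensor(0.7), "loss_box_reg": torch.tensor(0.3)}

loss_dict_reduced = reduce_loss_dict(loss_dict)  # rank 0 now holds the cross-rank mean
if get_world_size() < 2 or dist.get_rank() == 0:
    total = sum(v.item() for v in loss_dict_reduced.values())
    print(f"total loss (averaged over ranks): {total:.4f}")
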
Example #6
def inference(model, data_loader, dataset_name, device='cuda', output_folder=None,
              expected_results=(), expected_results_sigma_tol=4):
    device = torch.device(device)
    num_devices = get_world_size()
    logger = logging.getLogger("RetinaNet.inference")
    dataset = data_loader.dataset
    logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset)))
    total_timer = Timer()
    inference_timer = Timer()
    total_timer.tic()
    predictions = compute_on_dataset(model, data_loader, device, inference_timer)
    # wait for all processes to complete before measuring the time
    synchronize()
    total_time = total_timer.toc()
    total_time_str = get_time_str(total_time)
    logger.info(
        "Total run time: {} ({} s / img per device, on {} devices)".format(
            total_time_str, total_time * num_devices / len(dataset), num_devices
        )
    )

    predictions = accumulate_predictions_from_multiple_gpus(predictions)
    if not is_main_process():
        return

    if output_folder:
        torch.save(predictions, os.path.join(output_folder, "predictions.pth"))

    extra_args = dict(
        expected_results=expected_results,
        expected_results_sigma_tol=expected_results_sigma_tol,
    )

    return evaluate(dataset=dataset,
                    predictions=predictions,
                    output_folder=output_folder,
                    **extra_args)
Example #7
def main(args):
    utils.my_init_distributed_mode(args)
    print(args)
    # if args.distillation_type != 'none' and args.finetune and not args.eval:
    #     raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # if True:  # args.distributed:
    #     num_tasks = utils.get_world_size()
    #     global_rank = utils.get_rank()
    #     if args.repeated_aug:
    #         sampler_train = RASampler(
    #             dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    #         )
    #     else:
    #         sampler_train = torch.utils.data.DistributedSampler(
    #             dataset_train,
    #             # num_replicas=num_tasks,
    #             num_replicas=0,
    #             rank=global_rank, shuffle=True
    #         )
    #     if args.dist_eval:
    #         if len(dataset_val) % num_tasks != 0:
    #             print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
    #                   'This will slightly alter validation results as extra duplicate entries are added to achieve '
    #                   'equal num of samples per-process.')
    #         sampler_val = torch.utils.data.DistributedSampler(
    #             dataset_val,
    #             # num_replicas=num_tasks,
    #             num_replicas=0,
    #             rank=global_rank, shuffle=False)
    #     else:
    #         sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    # else:
    #     sampler_train = torch.utils.data.RandomSampler(dataset_train)
    #     sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    #
    # data_loader_train = torch.utils.data.DataLoader(
    #     dataset_train, sampler=sampler_train,
    #     batch_size=args.batch_size,
    #     num_workers=args.num_workers,
    #     pin_memory=args.pin_mem,
    #     drop_last=True,
    # )
    #
    # data_loader_val = torch.utils.data.DataLoader(
    #     dataset_val, sampler=sampler_val,
    #     batch_size=int(1.5 * args.batch_size),
    #     num_workers=args.num_workers,
    #     pin_memory=args.pin_mem,
    #     drop_last=False
    # )
    #

    if args.distributed:
        if args.cache_mode:
            sampler_train = samplers.NodeDistributedSampler(dataset_train)
            sampler_val = samplers.NodeDistributedSampler(dataset_val,
                                                          shuffle=False)
        else:
            sampler_train = samplers.DistributedSampler(dataset_train)
            sampler_val = samplers.DistributedSampler(dataset_val,
                                                      shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        args.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    data_loader_val = DataLoader(dataset_val,
                                 args.batch_size,
                                 sampler=sampler_val,
                                 drop_last=False,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )
    model_without_ddp = model

    # # there are bugs
    # if args.finetune:
    #     if args.finetune.startswith('https'):
    #         checkpoint = torch.hub.load_state_dict_from_url(
    #             args.finetune, map_location='cpu', check_hash=True)
    #     else:
    #         checkpoint = torch.load(args.finetune, map_location='cpu')
    #
    #     checkpoint_model = checkpoint['model']
    #
    #     state_dict = model.state_dict()
    #     for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
    #         if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
    #             print(f"Removing key {k} from pretrained checkpoint")
    #             del checkpoint_model[k]
    #
    #     _ = model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    # if args.model_ema:
    #     # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    #     model_ema = ModelEma(
    #         model,
    #         decay=args.model_ema_decay,
    #         device='cpu' if args.model_ema_force_cpu else '',
    #         resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    criterion = LabelSmoothingCrossEntropy()

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    criterion = DistillationLoss(criterion, None, 'none', 0, 0)

    output_dir = Path(args.output_dir)

    # for finetune
    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.finetune,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            if not os.path.exists(args.finetune):
                checkpoint = None
                print('NOTICE: ' + args.finetune + ' does not exist!')
            else:
                checkpoint = torch.load(args.finetune, map_location='cpu')

        if checkpoint is not None:
            if 'model' in checkpoint:
                check_model = checkpoint['model']
            else:
                check_model = checkpoint

            missing_keys = model_without_ddp.load_state_dict(
                check_model, strict=False).missing_keys
            skip_keys = model_without_ddp.no_weight_decay()
            # create optimizer manually
            param_dicts = [
                {
                    "params": [
                        p for n, p in model_without_ddp.named_parameters()
                        if n in missing_keys and n not in skip_keys
                    ],
                    "lr": args.lr,
                    "weight_decay": args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in model_without_ddp.named_parameters()
                        if n in missing_keys and n in skip_keys
                    ],
                    "lr": args.lr,
                    "weight_decay": 0,
                },
                {
                    "params": [
                        p for n, p in model_without_ddp.named_parameters()
                        if n not in missing_keys and n not in skip_keys
                    ],
                    "lr": args.lr * args.fine_factor,
                    "weight_decay": args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in model_without_ddp.named_parameters()
                        if n not in missing_keys and n in skip_keys
                    ],
                    "lr": args.lr * args.fine_factor,
                    "weight_decay": 0,
                },
            ]

            optimizer = torch.optim.AdamW(param_dicts,
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
            loss_scaler = NativeScaler()
            lr_scheduler, _ = create_scheduler(args, optimizer)

            # if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            #     optimizer.load_state_dict(checkpoint['optimizer'])
            #     lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            #     args.start_epoch = checkpoint['epoch'] + 1
            #     # if args.model_ema:
            #     #     utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            #     if 'scaler' in checkpoint:
            #         loss_scaler.load_state_dict(checkpoint['scaler'])

            print('finetune from ' + args.finetune)

            # for debug
            # lr_scheduler.step(10)
            # lr_scheduler.step(100)
            # lr_scheduler.step(200)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            if not os.path.exists(args.resume):
                checkpoint = None
                print('NOTICE: ' + args.resume + ' does not exist!')
            else:
                checkpoint = torch.load(args.resume, map_location='cpu')

        if checkpoint is not None:
            if 'model' in checkpoint:
                model_without_ddp.load_state_dict(checkpoint['model'])
            else:
                model_without_ddp.load_state_dict(checkpoint)

            if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                args.start_epoch = checkpoint['epoch'] + 1
                # if args.model_ema:
                #     utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
                if 'scaler' in checkpoint:
                    loss_scaler.load_state_dict(checkpoint['scaler'])

            print('resume from ' + args.resume)

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    max_epoch_dp_warm_up = 100
    if 'pvt_tiny' in args.model or 'pvt_small' in args.model:
        max_epoch_dp_warm_up = 0
    if args.start_epoch < max_epoch_dp_warm_up:
        model_without_ddp.reset_drop_path(0.0)
    for epoch in range(args.start_epoch, args.epochs):
        if args.fp32_resume and epoch > args.start_epoch + 1:
            args.fp32_resume = False
        loss_scaler._scaler = torch.cuda.amp.GradScaler(
            enabled=not args.fp32_resume)

        if epoch == max_epoch_dp_warm_up:
            model_without_ddp.reset_drop_path(args.drop_path)

        if epoch < args.warmup_epochs:
            optimizer.param_groups[2]['lr'] = 0
            optimizer.param_groups[3]['lr'] = 0

        if args.distributed:
            # data_loader_train.sampler.set_epoch(epoch)
            sampler_train.set_epoch(epoch)

        train_stats = my_train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            model_ema,
            mixup_fn,
            # set_training_mode=args.finetune == '',  # keep in eval mode during finetuning
            fp32=args.fp32_resume)

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        # 'model_ema': get_state_dict(model_ema),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    },
                    checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
            **{f'test_{k}': v
               for k, v in test_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
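Both this example and Example #9 apply the linear learning-rate scaling rule before building the optimizer: linear_scaled_lr = lr * batch_size * world_size / 512. A quick worked instance with illustrative numbers:

base_lr, batch_size_per_proc, world_size = 5e-4, 128, 8   # illustrative values only
linear_scaled_lr = base_lr * batch_size_per_proc * world_size / 512.0
print(linear_scaled_lr)   # 0.001 -- the lr grows in proportion to the global batch (128 * 8 = 1024)
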
Example #8
def main():
    global timeout_sent

    args = parse_arguments()

    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)
    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare optimizer
    model, optimizer, lr_scheduler, checkpoint, global_step, criterion = prepare_model_and_optimizer(
        args, device)
    gradient_accumulation_steps = torch.tensor(
        args.gradient_accumulation_steps, dtype=torch.float32).to(device)
    world_size = torch.tensor(get_world_size(), dtype=torch.float32).to(device)

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    raw_train_start = None
    if args.do_train:
        if is_main_process():
            dllogger.log(step="PARAMETER", data={"train_start": True})
            dllogger.log(step="PARAMETER",
                         data={"batch_size_per_pu": args.train_batch_size})
            dllogger.log(step="PARAMETER",
                         data={"learning_rate": args.learning_rate})

        model.train()
        most_recent_ckpts_paths = []
        average_loss = 0.0  # averaged loss every args.log_freq steps
        epoch = 0
        training_steps = 0
        model_traced = False

        if device.type == 'cuda':
            pool = ProcessPoolExecutor(1)

        # Note: We loop infinitely over epochs, termination is handled via iteration count
        while True:
            thread = None
            restored_data_loader = None
            if not args.resume_from_checkpoint or epoch > 0 or (
                    args.phase2 and global_step < 1) or args.init_checkpoint:
                files = [
                    os.path.join(args.input_dir, f)
                    for f in os.listdir(args.input_dir)
                    if os.path.isfile(os.path.join(args.input_dir, f))
                    and 'training' in f
                ]
                files.sort()
                num_files = len(files)
                random.Random(args.seed + epoch).shuffle(files)
                f_start_id = 0
            else:
                f_start_id = checkpoint['files'][0]
                files = checkpoint['files'][1:]
                args.resume_from_checkpoint = False
                num_files = len(files)
                # may not exist in all checkpoints
                epoch = checkpoint.get('epoch', 0)
                restored_data_loader = checkpoint.get('data_loader', None)

            shared_file_list = {}

            if torch.distributed.is_initialized() and get_world_size() > num_files:
                remainder = get_world_size() % num_files
                data_file = files[(f_start_id * get_world_size() + get_rank() +
                                   remainder * f_start_id) % num_files]
            else:
                data_file = files[(f_start_id * get_world_size() + get_rank())
                                  % num_files]

            previous_file = data_file

            if restored_data_loader is None:
                use_pin_memory = False if args.no_cuda or args.use_habana else True
                num_workers = 0 if args.use_habana else 4
                train_data = pretraining_dataset(data_file,
                                                 args.max_predictions_per_seq)
                train_sampler = RandomSampler(train_data)
                train_dataloader = DataLoader(
                    train_data,
                    sampler=train_sampler,
                    batch_size=args.train_batch_size * args.n_pu,
                    num_workers=num_workers,
                    worker_init_fn=worker_init,
                    pin_memory=use_pin_memory,
                    drop_last=True)
                # shared_file_list["0"] = (train_dataloader, data_file)
            else:
                train_dataloader = restored_data_loader
                restored_data_loader = None

            overflow_buf = None
            if args.allreduce_post_accumulation:
                overflow_buf = torch.cuda.IntTensor([0])

            for f_id in range(f_start_id + 1, len(files)):

                if get_world_size() > num_files:
                    data_file = files[(f_id * get_world_size() + get_rank() +
                                       remainder * f_id) % num_files]
                else:
                    data_file = files[(f_id * get_world_size() + get_rank()) %
                                      num_files]

                previous_file = data_file

                if device.type == 'cuda':
                    dataset_future = pool.submit(create_pretraining_dataset,
                                                 data_file,
                                                 args.max_predictions_per_seq,
                                                 shared_file_list, args,
                                                 worker_init)

                train_iter = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=args.disable_progress_bar
                                  ) if is_main_process() else train_dataloader

                if raw_train_start is None:
                    raw_train_start = time.time()
                for step, batch in enumerate(train_iter):

                    training_steps += 1
                    position_ids = compute_position_ids(batch[0])
                    if torch.distributed.is_initialized():
                        torch.distributed.barrier()
                    if args.use_habana:
                        batch = [t.to(dtype=torch.int32) for t in batch]
                        position_ids = position_ids.to(dtype=torch.int32)

                    position_ids = position_ids.to(device)
                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch

                    if args.use_jit_trace:
                        if not model_traced:
                            model = torch.jit.trace(model,
                                                    (input_ids, segment_ids,
                                                     input_mask, position_ids),
                                                    check_trace=False)
                            model_traced = True
                            if args.local_rank != -1 and not args.allreduce_post_accumulation:
                                if args.use_habana:
                                    model = DDP(model)
                                else:
                                    model = DDP(model,
                                                message_size=250000000,
                                                gradient_predivide_factor=get_world_size())
                        if args.local_rank != -1 and not args.allreduce_post_accumulation \
                                and (training_steps % args.gradient_accumulation_steps != 0):
                            with model.no_sync():
                                prediction_scores, seq_relationship_score = model(
                                    input_ids, segment_ids, input_mask,
                                    position_ids)
                        else:
                            prediction_scores, seq_relationship_score = model(
                                input_ids, segment_ids, input_mask,
                                position_ids)
                    else:
                        if args.local_rank != -1 and not args.allreduce_post_accumulation \
                                and (training_steps % args.gradient_accumulation_steps != 0):
                            with model.no_sync():
                                prediction_scores, seq_relationship_score = model(
                                    input_ids=input_ids,
                                    token_type_ids=segment_ids,
                                    attention_mask=input_mask,
                                    position_ids=position_ids)
                        else:
                            prediction_scores, seq_relationship_score = model(
                                input_ids=input_ids,
                                token_type_ids=segment_ids,
                                attention_mask=input_mask,
                                position_ids=position_ids)

                    loss = criterion(prediction_scores, seq_relationship_score,
                                     masked_lm_labels, next_sentence_labels)
                    if args.n_pu > 1:
                        loss = loss.mean()  # mean() to average on multi-pu.

                    divisor = args.gradient_accumulation_steps
                    if args.gradient_accumulation_steps > 1:
                        if not args.allreduce_post_accumulation:
                            # this division was merged into predivision
                            loss = loss / gradient_accumulation_steps
                            divisor = 1.0
                    if args.fp16:
                        with amp.scale_loss(
                                loss, optimizer,
                                delay_overflow_check=args.allreduce_post_accumulation) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    average_loss += loss.item()

                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(
                            args, optimizer, model, overflow_buf, global_step)

                    if global_step >= args.steps_this_run or timeout_sent:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = int(
                            training_steps /
                            args.gradient_accumulation_steps) % args.log_freq
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = average_loss / (last_num_steps *
                                                       divisor)
                        average_loss = torch.tensor(
                            average_loss, dtype=torch.float32).to(device)
                        if (torch.distributed.is_initialized()):
                            average_loss /= world_size
                            torch.distributed.all_reduce(average_loss)
                        final_loss = average_loss.item()
                        if is_main_process():
                            dllogger.log(step=(
                                epoch,
                                global_step,
                            ),
                                         data={"final_loss": final_loss})
                    elif training_steps % (
                            args.log_freq *
                            args.gradient_accumulation_steps) == 0:
                        if is_main_process():
                            dllogger.log(
                                step=(
                                    epoch,
                                    global_step,
                                ),
                                data={
                                    "average_loss":
                                    average_loss / (args.log_freq * divisor),
                                    "step_loss":
                                    loss.item() *
                                    args.gradient_accumulation_steps / divisor,
                                    "learning_rate":
                                    optimizer.param_groups[0]['lr']
                                })
                        average_loss = 0

                    if global_step >= args.steps_this_run or training_steps % (
                            args.num_steps_per_checkpoint *
                            args.gradient_accumulation_steps) == 0 or timeout_sent:
                        if is_main_process() and not args.skip_checkpoint:
                            # Save a trained model
                            dllogger.log(step="PARAMETER",
                                         data={"checkpoint_step": global_step})
                            # Only save the model itself
                            model_to_save = model.module if hasattr(model, 'module') else model
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step +
                                                        args.phase1_end_step))
                            checkpoint_dict = {}
                            if args.do_train:
                                if args.use_habana:
                                    config = modeling.BertConfig.from_json_file(
                                        args.config_file)

                                    # Padding for divisibility by 8
                                    if config.vocab_size % 8 != 0:
                                        config.vocab_size += 8 - (
                                            config.vocab_size % 8)

                                    model_copy = modeling.BertForPreTraining(
                                        config)
                                    model_copy.load_state_dict(
                                        model_to_save.state_dict())

                                    param_groups_copy = optimizer.state_dict()['param_groups']
                                    state_dict_copy = {}
                                    for st_key, st_val in optimizer.state_dict()['state'].items():
                                        st_val_copy = {}
                                        for k, v in st_val.items():
                                            if isinstance(v, torch.Tensor):
                                                st_val_copy[k] = v.to('cpu')
                                            else:
                                                st_val_copy[k] = v
                                            state_dict_copy[st_key] = st_val_copy
                                    optim_dict = {}
                                    optim_dict['state'] = state_dict_copy
                                    optim_dict['param_groups'] = param_groups_copy
                                    checkpoint_dict = {
                                        'model': model_copy.state_dict(),
                                        'optimizer': optim_dict,
                                        'files': [f_id] + files,
                                        'epoch': epoch,
                                        'data_loader': None if global_step >= args.max_steps else train_dataloader
                                    }
                                elif args.no_cuda:
                                    checkpoint_dict = {
                                        'model': model_to_save.state_dict(),
                                        'optimizer': optimizer.state_dict(),
                                        'files': [f_id] + files,
                                        'epoch': epoch,
                                        'data_loader': None if global_step >= args.max_steps else train_dataloader
                                    }
                                else:
                                    checkpoint_dict = {
                                        'model': model_to_save.state_dict(),
                                        'optimizer': optimizer.state_dict(),
                                        'master params': list(amp.master_params(optimizer)),
                                        'files': [f_id] + files,
                                        'epoch': epoch,
                                        'data_loader': None if global_step >= args.max_steps else train_dataloader
                                    }

                                torch.save(checkpoint_dict, output_save_file)
                                most_recent_ckpts_paths.append(
                                    output_save_file)
                                if len(most_recent_ckpts_paths) > 3:
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(
                                        0)
                                    os.remove(ckpt_to_be_removed)

                        # Exiting the training due to hitting max steps, or being sent a
                        # timeout from the cluster scheduler
                        if global_step >= args.steps_this_run or timeout_sent:
                            del train_dataloader
                            # thread.join()
                            return args, final_loss, train_time_raw, global_step

                del train_dataloader
                # thread.join()
                # Make sure pool has finished and switch train_dataloader
                # NOTE: Will block until complete
                if device.type == 'cuda':
                    train_dataloader, data_file = dataset_future.result(
                        timeout=None)
                else:
                    train_dataloader, data_file = create_pretraining_dataset(
                        data_file, args.max_predictions_per_seq,
                        shared_file_list, args, worker_init)

            epoch += 1
Example #9
def main(args):
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(dataset_train,
                                      num_replicas=num_tasks,
                                      rank=global_rank,
                                      shuffle=True)
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=True)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  batch_size=int(
                                                      1.5 * args.batch_size),
                                                  shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=args.pin_mem,
                                                  drop_last=False)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
    )

    # TODO: finetuning

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    criterion = LabelSmoothingCrossEntropy()

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema,
                                               checkpoint['model_ema'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print("Start training")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(model, criterion, data_loader_train,
                                      optimizer, device, epoch, loss_scaler,
                                      args.clip_grad, model_ema, mixup_fn)

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'model_ema': get_state_dict(model_ema),
                        'args': args,
                    }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
            **{f'test_{k}': v
               for k, v in test_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Ejemplo n.º 10
0
def main():
    global timeout_sent

    args = parse_arguments()

    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare optimizer
    model, optimizer, lr_scheduler, checkpoint, global_step, criterion = prepare_model_and_optimizer(args, device)

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    raw_train_start = None
    if is_main_process():
        dllogger.log(step="PARAMETER", data={"train_start": True})
        dllogger.log(step="PARAMETER", data={"batch_size_per_gpu": args.train_batch_size})
        dllogger.log(step="PARAMETER", data={"learning_rate": args.learning_rate})

    most_recent_ckpts_paths = []
    average_loss = 0.0  # averaged loss every args.log_freq steps
    epoch = 0
    training_steps = 0
    test_losses = []

    pool = ProcessPoolExecutor(1)

    # Note: We loop infinitely over epochs, termination is handled via iteration count
    while True:
        thread = None
        restored_data_loader = None
        if not args.resume_from_checkpoint or epoch > 0 or (args.phase2 and global_step < 1) or args.init_checkpoint:
            files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if
                    os.path.isfile(os.path.join(args.input_dir, f)) and 'training' in f]
            files.sort()
            num_files = len(files)
            random.Random(args.seed + epoch).shuffle(files)
            f_start_id = 0
        else:
            f_start_id = checkpoint['files'][0]
            files = checkpoint['files'][1:]
            args.resume_from_checkpoint = False
            num_files = len(files)
            # may not exist in all checkpoints
            epoch = checkpoint.get('epoch', 0)
            restored_data_loader = checkpoint.get('data_loader', None)

        shared_file_list = {}

        if torch.distributed.is_initialized() and get_world_size() > num_files:
            remainder = get_world_size() % num_files
            data_file = files[(f_start_id*get_world_size()+get_rank() + remainder*f_start_id)%num_files]
        else:
            data_file = files[(f_start_id*get_world_size()+get_rank())%num_files]

        previous_file = data_file

        if restored_data_loader is None:
            train_data = pretraining_dataset(data_file, args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                        batch_size=args.train_batch_size * args.n_gpu,
                                        num_workers=4, worker_init_fn=worker_init,
                                        pin_memory=True)
            # shared_file_list["0"] = (train_dataloader, data_file)
        else:
            train_dataloader = restored_data_loader
            restored_data_loader = None

        overflow_buf = None
        if args.allreduce_post_accumulation:
            overflow_buf = torch.cuda.IntTensor([0])

        for f_id in range(f_start_id + 1, len(files)):
            if get_world_size() > num_files:
                data_file = files[(f_id*get_world_size()+get_rank() + remainder*f_id)%num_files]
            else:
                data_file = files[(f_id*get_world_size()+get_rank())%num_files]

            previous_file = data_file

            dataset_future = pool.submit(create_pretraining_dataset, data_file, args.max_predictions_per_seq, shared_file_list, args, worker_init)

            train_iter = tqdm(train_dataloader, desc="Iteration", disable=args.disable_progress_bar) if is_main_process() else train_dataloader

            if raw_train_start is None:
                raw_train_start = time.time()

            for step, batch in enumerate(train_iter):
                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                if args.do_train:
                    from smdistributed.modelparallel.test.torch.utils import verify, dump_model
                    model.train()
                    if args.smp > 0:
                        loss_mbs = smp_step(args, device, input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, model, optimizer, criterion, step)
                        loss = loss_mbs.reduce_mean()
                        if smp.rank() == 0:
                            print("Loss:", loss.item())
                    else:
                        loss = train_step(args, device, input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, model, optimizer, criterion, step)
                    divisor = 1
                    average_loss += loss.item()

                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(args, optimizer, model, overflow_buf, global_step)

                    if global_step >= args.steps_this_run or timeout_sent:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = int(training_steps / args.gradient_accumulation_steps) % args.log_freq
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps * divisor)
                        if (torch.distributed.is_initialized()):
                            average_loss /= get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        final_loss = loss.item()
                        if is_main_process():
                            dllogger.log(step=(epoch, global_step, ), data={"final_loss": final_loss})
                    elif training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
                        if is_main_process():
                            dllogger.log(step=(epoch, global_step, ), data={"average_loss": average_loss / (args.log_freq * divisor),
                                                                            "step_loss": loss.item() * args.gradient_accumulation_steps / divisor,
                                                                            "learning_rate": optimizer.param_groups[0]['lr']})
                        average_loss = 0


                    if global_step >= args.steps_this_run or training_steps % (
                            args.num_steps_per_checkpoint * args.gradient_accumulation_steps) == 0 or timeout_sent:
                        if smp.dp_rank() == 0 and not args.skip_checkpoint:
                            # Save a trained model
                            dllogger.log(step="PARAMETER", data={"checkpoint_step": global_step})
                            # model_to_save = model.module if hasattr(model,
                            #                                         'module') else model  # Only save the model it-self
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step + args.phase1_end_step))
                            if args.do_train:
                                save_dict = {
                                            'model': model.local_state_dict(),
                                            'optimizer': optimizer.local_state_dict(),
                                            'files': [f_id] + files,
                                            'epoch': epoch,
                                            'data_loader': None if global_step >= args.steps_this_run else train_dataloader}
                                if args.fp16:
                                    save_dict['master params'] = list(amp.master_params(optimizer))
                                # SMP: Checkpoint mp_rank specific state
                                smp.save(save_dict, output_save_file, partial=True)

                                most_recent_ckpts_paths.append(output_save_file)
                                if len(most_recent_ckpts_paths) > 3 and (args.smp == 0 or smp.dp_rank() == 0):
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                                    os.remove(ckpt_to_be_removed+f"_{smp.mp_rank()}")

                        # Exiting the training due to hitting max steps, or being sent a
                        # timeout from the cluster scheduler
                        if global_step >= args.steps_this_run or timeout_sent:
                            del train_dataloader
                            # thread.join()
                            if smp.dp_rank() == 0 and args.save_full:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step))
                                save_dict = {
                                            'model': model.local_state_dict(),
                                            'optimizer': optimizer.local_state_dict(),
                                            'files': [f_id] + files,
                                            'epoch': epoch,
                                            'data_loader': None if global_step >= args.steps_this_run else train_dataloader}
                                if args.fp16:
                                    save_dict['master params'] = list(amp.master_params(optimizer))
                                # SMP: Save a single checkpoint containing entire model parameters
                                smp.save(save_dict, output_save_file, partial=False)
                            smp.barrier()
                            return args, final_loss, train_time_raw, global_step
                else:
                    model.eval()
                    with torch.no_grad():
                        loss = test_step(args, device, input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, model, criterion, step)
                        print(f"global_step {global_step} Test Loss:", loss)
                        test_losses.append(loss)
                    global_step += 1
                    if global_step >= args.steps_this_run:
                        return sum(test_losses) / len(test_losses)

            del train_dataloader
            # thread.join()
            # Make sure pool has finished and switch train_dataloader
            # NOTE: Will block until complete
            train_dataloader, data_file = dataset_future.result(timeout=None)
        epoch += 1
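
The shard-selection arithmetic in the loop above is dense, so here is a minimal, standalone sketch of the same indexing (shard names, rank, and world size are made-up placeholders). Each rank walks the shuffled file list in lockstep as f_id advances, and when there are more ranks than files the remainder * f_id term rotates the assignment between rounds.

# Minimal sketch of the rank-to-shard indexing used above (hypothetical shard names).
def pick_shard(files, f_id, rank, world_size):
    num_files = len(files)
    if world_size > num_files:
        # More ranks than shards: shards are reused, rotated by the remainder each round.
        remainder = world_size % num_files
        return files[(f_id * world_size + rank + remainder * f_id) % num_files]
    # Otherwise each rank reads a distinct shard in every round.
    return files[(f_id * world_size + rank) % num_files]

shards = [f"part_{i:04d}.hdf5" for i in range(4)]
for rank in range(2):  # pretend world_size == 2 and 4 shuffled shards
    print(rank, [pick_shard(shards, f_id, rank, world_size=2) for f_id in range(4)])
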
Example No. 11
def prepare_model_and_optimizer(args, device):

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [
                f for f in os.listdir(args.output_dir) if f.endswith(".pt")
            ]
            args.resume_step = max([
                int(x.split('.pt')[0].split('_')[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(
                args.output_dir, "ckpt_{}.pt".format(global_step)),
                                    map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)

        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    # BERT modeling uses weight sharing between the word embedding and the prediction decoder,
    # so make sure both storages point to the same tensor even after the model is moved to the device.
    if args.use_habana:
        model.cls.predictions.decoder.weight = model.bert.embeddings.word_embeddings.weight

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.use_habana:
        if args.use_fused_lamb:
            try:
                from hb_custom import FusedLamb
            except ImportError:
                raise ImportError("Please install hbopt.")
            optimizer = FusedLamb(optimizer_grouped_parameters,
                                  lr=args.learning_rate)
        else:
            optimizer = NVLAMB(optimizer_grouped_parameters,
                               lr=args.learning_rate)
    else:
        if torch.cuda.is_available():
            optimizer = FusedLAMB(optimizer_grouped_parameters,
                                  lr=args.learning_rate)
        else:
            optimizer = NVLAMB(optimizer_grouped_parameters,
                               lr=args.learning_rate)

    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)
    if args.fp16:

        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale="dynamic",
                                              cast_model_outputs=torch.float16)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              loss_scale=args.loss_scale,
                                              cast_model_outputs=torch.float16)
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    model.checkpoint_activations(args.checkpoint_activations)

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint['optimizer']['state'].keys())
            # Override hyperparameters from the previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_step
            for iter, item in enumerate(
                    checkpoint['optimizer']['param_groups']):
                checkpoint['optimizer']['param_groups'][iter][
                    'step'] = global_step
                checkpoint['optimizer']['param_groups'][iter][
                    't_total'] = args.max_steps
                checkpoint['optimizer']['param_groups'][iter][
                    'warmup'] = args.warmup_proportion
                checkpoint['optimizer']['param_groups'][iter][
                    'lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            if not args.use_jit_trace:
                if args.use_habana:
                    model = DDP(model)
                else:
                    model = DDP(model,
                                message_size=250000000,
                                gradient_predivide_factor=get_world_size())
        else:
            flat_dist_call([param.data for param in model.parameters()],
                           torch.distributed.broadcast, (0, ))
    elif args.n_pu > 1:
        model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
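
The two optimizer parameter groups built above follow the usual convention of exempting biases and LayerNorm/gamma/beta parameters from weight decay. Below is a minimal sketch of the same grouping on a toy module; the module, the 0.01 decay value, and the AdamW optimizer are illustrative stand-ins, not the BERT model or the LAMB variants used here.

import torch

# Toy module whose parameter names contain "bias" and "LayerNorm", like the BERT model.
class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(8, 8)
        self.LayerNorm = torch.nn.LayerNorm(8)

model = Toy()
no_decay = ['bias', 'LayerNorm']
grouped_params = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(grouped_params, lr=1e-4)
print([len(g['params']) for g in optimizer.param_groups])  # [1, 3]: only dense.weight is decayed
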
Example No. 12
def main():
    global timeout_sent

    args = parse_arguments()
        
    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)
    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare optimizer
    model, optimizer, lr_scheduler, checkpoint, global_step, criterion = prepare_model_and_optimizer(args, device)

    if args.disable_weight_tying:
        # Sanity check that the (now untied) decoder weight is registered with the optimizer
        print("SANITY CHECK OPTIMIZER: ", id(model.module.cls.predictions.decoder.weight) in [id(g) for g in optimizer.param_groups[0]['params']])
        assert id(model.module.cls.predictions.decoder.weight) in [id(g) for g in optimizer.param_groups[0]['params']]

    print(f"SAVING EVERY {args.num_steps_per_checkpoint} STEPS!")

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    raw_train_start = None


    if args.do_train:
        if is_main_process():
            dllogger.log(step="PARAMETER", data={"train_start": True})
            dllogger.log(step="PARAMETER", data={"batch_size_per_gpu": args.train_batch_size})
            dllogger.log(step="PARAMETER", data={"learning_rate": args.learning_rate})

        model.train()
        most_recent_ckpts_paths = []
        average_loss = 0.0  # averaged loss every args.log_freq steps
        epoch = 0
        training_steps = 0

        pool = ProcessPoolExecutor(1)

        # Note: We loop infinitely over epochs, termination is handled via iteration count
        while True:
            thread = None
            restored_data_loader = None
            if not args.resume_from_checkpoint or epoch > 0 or (args.phase2 and global_step < 1) or args.init_checkpoint:
                files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if
                         os.path.isfile(os.path.join(args.input_dir, f)) and ('training' in f or 'train' in f)]
                files.sort()
                num_files = len(files)
                random.Random(args.seed + epoch).shuffle(files)
                f_start_id = 0
            else:
                f_start_id = checkpoint['files'][0]
                files = checkpoint['files'][1:]
                args.resume_from_checkpoint = False
                num_files = len(files)
                # may not exist in all checkpoints
                epoch = checkpoint.get('epoch', 0)
                restored_data_loader = checkpoint.get('data_loader', None)

            shared_file_list = {}

            if torch.distributed.is_initialized() and get_world_size() > num_files:
                remainder = get_world_size() % num_files
                data_file = files[(f_start_id*get_world_size()+get_rank() + remainder*f_start_id)%num_files]
            else:
                data_file = files[(f_start_id*get_world_size()+get_rank())%num_files]

            previous_file = data_file

            if restored_data_loader is None:
                train_data = pretraining_dataset(data_file, args.max_predictions_per_seq)
                train_sampler = RandomSampler(train_data)
                train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                              batch_size=args.train_batch_size * args.n_gpu,
                                              num_workers=4, worker_init_fn=worker_init,
                                              pin_memory=True)
                # shared_file_list["0"] = (train_dataloader, data_file)
            else:
                train_dataloader = restored_data_loader
                restored_data_loader = None

            overflow_buf = None
            if args.allreduce_post_accumulation:
                overflow_buf = torch.cuda.IntTensor([0])

            for f_id in range(f_start_id + 1, len(files)):
                if get_world_size() > num_files:
                    data_file = files[(f_id*get_world_size()+get_rank() + remainder*f_id)%num_files]
                else:
                    data_file = files[(f_id*get_world_size()+get_rank())%num_files]

                previous_file = data_file

                dataset_future = pool.submit(create_pretraining_dataset, data_file, args.max_predictions_per_seq, shared_file_list, args, worker_init)

                train_iter = tqdm(train_dataloader, desc="Iteration", disable=args.disable_progress_bar) if is_main_process() else train_dataloader

                if raw_train_start is None:
                    raw_train_start = time.time()
                for step, batch in enumerate(train_iter):
                    training_steps += 1
                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                    prediction_scores, seq_relationship_score = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
                    loss, mlm_loss, ns_loss = criterion(prediction_scores, seq_relationship_score, masked_lm_labels, next_sentence_labels)
                    if args.n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.
                        mlm_loss = mlm_loss.detach().mean()
                        ns_loss = ns_loss.detach().mean()

                    divisor = args.gradient_accumulation_steps
                    if args.gradient_accumulation_steps > 1:
                        if not args.allreduce_post_accumulation:
                            # this division was merged into predivision
                            loss = loss / args.gradient_accumulation_steps
                            mlm_loss = mlm_loss.detach() / args.gradient_accumulation_steps
                            ns_loss = ns_loss.detach() / args.gradient_accumulation_steps
                            divisor = 1.0
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer, delay_overflow_check=args.allreduce_post_accumulation) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

                    average_loss += loss.item()

                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(args, optimizer, model, overflow_buf, global_step)

                    if global_step >= args.steps_this_run or timeout_sent:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = int(training_steps / args.gradient_accumulation_steps) % args.log_freq
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps * divisor)
                        if (torch.distributed.is_initialized()):
                            average_loss /= get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        final_loss = average_loss.item()
                        if is_main_process():
                            dllogger.log(step=(epoch, global_step, ), data={"final_loss": final_loss})
                    elif training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
                        if is_main_process():
                            dllogger.log(step=(epoch, global_step, ), data={"average_loss": average_loss / (args.log_freq * divisor),
                                                                            "step_loss": loss.item() * args.gradient_accumulation_steps / divisor,
                                                                            "learning_rate": optimizer.param_groups[0]['lr'],
                                                                            "mlm_loss" : mlm_loss.item(),
                                                                            "ns_loss" : ns_loss.item()})
                        average_loss = 0


                    if global_step >= args.steps_this_run or training_steps % (
                            args.num_steps_per_checkpoint * args.gradient_accumulation_steps) == 0 or timeout_sent:
                        if is_main_process() and not args.skip_checkpoint:
                            # Save a trained model
                            dllogger.log(step="PARAMETER", data={"checkpoint_step": global_step})
                            model_to_save = model.module if hasattr(model,
                                                                    'module') else model  # Only save the model it-self
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step + args.phase1_end_step))
                            if args.do_train:
                                torch.save({'model': model_to_save.state_dict(),
                                            'optimizer': optimizer.state_dict(),
                                            'master params': list(amp.master_params(optimizer)),
                                            'files': [f_id] + files,
                                            'epoch': epoch,
                                            'data_loader': None if global_step >= args.max_steps else train_dataloader}, output_save_file)

                                most_recent_ckpts_paths.append(output_save_file)

                        # Exiting the training due to hitting max steps, or being sent a 
                        # timeout from the cluster scheduler
                        if global_step >= args.steps_this_run or timeout_sent:
                            del train_dataloader
                            # thread.join()
                            return args, final_loss, train_time_raw, global_step

                del train_dataloader
                # thread.join()
                # Make sure pool has finished and switch train_dataloader
                # NOTE: Will block until complete
                train_dataloader, data_file = dataset_future.result(timeout=None)

            epoch += 1
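
A small numeric sketch of the divisor bookkeeping above, with made-up loss values: when allreduce_post_accumulation is off, each micro-batch loss is pre-divided by the accumulation count before backward(), so the logging divisor collapses to 1; when it is on, the raw losses are accumulated and the division is deferred (the predivision happens later inside the post-accumulation allreduce). Both paths report the same average.

# Hypothetical micro-batch losses for one optimizer step with 4 accumulation steps.
grad_accum_steps = 4
micro_losses = [2.0, 2.2, 1.8, 2.0]
log_freq = 1  # pretend we log every optimizer step

# Path 1: losses pre-divided before backward() (allreduce_post_accumulation off).
divisor = 1.0
average_loss = sum(l / grad_accum_steps for l in micro_losses)

# Path 2: raw losses accumulated, division deferred to logging time.
divisor_post = grad_accum_steps
average_loss_post = sum(micro_losses)

print(average_loss / (log_freq * divisor))            # 2.0
print(average_loss_post / (log_freq * divisor_post))  # 2.0
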
Example No. 13
def main(e2e_start_time):
    # Parse essential arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", required=True)
    parser.add_argument("--model_size",
                        default="base",
                        type=str,
                        help="base or large")
    parser.add_argument("--pretrain_tfrecords", type=str)
    parser.add_argument("--phase2", action='store_true')
    parser.add_argument("--fp16_compression", action='store_true')
    parser.add_argument("--amp",
                        action='store_true',
                        help="Whether to use fp16.")
    parser.add_argument("--xla",
                        action='store_true',
                        help="Whether to use xla.")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--num_train_steps", type=int)
    parser.add_argument("--num_warmup_steps", type=int)
    parser.add_argument("--learning_rate", type=float)
    parser.add_argument("--train_batch_size", type=int)
    parser.add_argument("--max_seq_length", type=int)

    parser.add_argument("--mask_prob", type=float)
    parser.add_argument("--disc_weight", type=float)
    parser.add_argument("--generator_hidden_size", type=float)

    parser.add_argument("--log_freq",
                        type=int,
                        default=10,
                        help="Training metrics logging frequency")
    parser.add_argument("--save_checkpoints_steps", type=int)
    parser.add_argument("--keep_checkpoint_max", type=int)
    parser.add_argument("--restore_checkpoint", default=None, type=str)
    parser.add_argument("--load_weights", action='store_true')
    parser.add_argument("--weights_dir")

    parser.add_argument("--optimizer",
                        default="adam",
                        type=str,
                        help="adam or lamb")
    parser.add_argument(
        "--skip_adaptive",
        action='store_true',
        help="Whether to apply adaptive LR on LayerNorm and biases")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=1,
                        help="Number of Gradient Accumulation steps")
    parser.add_argument("--lr_decay_power",
                        type=float,
                        default=0.5,
                        help="LR decay power")
    parser.add_argument("--opt_beta_1",
                        type=float,
                        default=0.878,
                        help="Optimizer beta1")
    parser.add_argument("--opt_beta_2",
                        type=float,
                        default=0.974,
                        help="Optimizer beta2")
    parser.add_argument("--end_lr", type=float, default=0.0, help="Ending LR")
    parser.add_argument("--log_dir",
                        type=str,
                        default=None,
                        help="Path to store logs")
    parser.add_argument("--results_dir",
                        type=str,
                        default=None,
                        help="Path to store all model results")
    parser.add_argument("--skip_checkpoint",
                        action='store_true',
                        default=False,
                        help="Path to store logs")
    parser.add_argument(
        '--json-summary',
        type=str,
        default=None,
        help=
        'If provided, the json summary will be written to the specified file.')
    args = parser.parse_args()
    config = PretrainingConfig(**args.__dict__)
    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    # Set up tensorflow
    hvd.init()

    args.log_dir = config.log_dir
    # DLLogger
    setup_logger(args)

    set_affinity(hvd.local_rank())
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()],
                                                   'GPU')
    tf.config.optimizer.set_jit(config.xla)
    #tf.config.optimizer.set_experimental_options({"auto_mixed_precision": config.amp})

    if config.amp:
        policy = tf.keras.mixed_precision.experimental.Policy(
            "mixed_float16", loss_scale="dynamic")
        tf.keras.mixed_precision.experimental.set_policy(policy)
        print('Compute dtype: %s' %
              policy.compute_dtype)  # Compute dtype: float16
        print('Variable dtype: %s' %
              policy.variable_dtype)  # Variable dtype: float32

    #tf.random.set_seed(config.seed)

    # Set up config (continued)
    if config.load_weights and config.restore_checkpoint:
        raise ValueError(
            "`load_weights` and `restore_checkpoint` should not be on at the same time."
        )
    if config.phase2 and not config.restore_checkpoint:
        raise ValueError(
            "`phase2` cannot be used without `restore_checkpoint`.")
    utils.heading("Config:")
    log_config(config)

    # Save pretrain configs
    pretrain_config_json = os.path.join(config.checkpoints_dir,
                                        'pretrain_config.json')
    if is_main_process():
        utils.write_json(config.__dict__, pretrain_config_json)
        log("Configuration saved in {}".format(pretrain_config_json))

    # Set up model
    model = PretrainingModel(config)

    # Set up metrics
    metrics = dict()
    metrics["train_perf"] = tf.keras.metrics.Mean(name="train_perf")
    metrics["total_loss"] = tf.keras.metrics.Mean(name="total_loss")
    metrics["masked_lm_accuracy"] = tf.keras.metrics.Accuracy(
        name="masked_lm_accuracy")
    metrics["masked_lm_loss"] = tf.keras.metrics.Mean(name="masked_lm_loss")
    if config.electra_objective:
        metrics["sampled_masked_lm_accuracy"] = tf.keras.metrics.Accuracy(
            name="sampled_masked_lm_accuracy")
        if config.disc_weight > 0:
            metrics["disc_loss"] = tf.keras.metrics.Mean(name="disc_loss")
            metrics["disc_auc"] = tf.keras.metrics.AUC(name="disc_auc")
            metrics["disc_accuracy"] = tf.keras.metrics.Accuracy(
                name="disc_accuracy")
            metrics["disc_precision"] = tf.keras.metrics.Accuracy(
                name="disc_precision")
            metrics["disc_recall"] = tf.keras.metrics.Accuracy(
                name="disc_recall")

    # Set up tensorboard
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = os.path.join(
        config.log_dir, current_time,
        'train_' + str(get_rank()) + '_of_' + str(get_world_size()))
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)

    # Set up dataset
    dataset = pretrain_utils.get_dataset(config,
                                         config.train_batch_size,
                                         world_size=get_world_size(),
                                         rank=get_rank())
    train_iterator = iter(dataset)

    # Set up optimizer
    optimizer = create_optimizer(init_lr=config.learning_rate,
                                 num_train_steps=config.num_train_steps,
                                 num_warmup_steps=config.num_warmup_steps,
                                 weight_decay_rate=config.weight_decay_rate,
                                 optimizer=config.optimizer,
                                 skip_adaptive=config.skip_adaptive,
                                 power=config.lr_decay_power,
                                 beta_1=config.opt_beta_1,
                                 beta_2=config.opt_beta_2,
                                 end_lr=config.end_lr)

    accumulator = GradientAccumulator()
    if config.amp:
        optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
            optimizer, "dynamic")

    # Set up model checkpoint
    checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                     phase2=tf.Variable(False),
                                     optimizer=optimizer,
                                     model=model)
    manager = tf.train.CheckpointManager(
        checkpoint,
        config.checkpoints_dir,
        max_to_keep=config.keep_checkpoint_max)
    if config.restore_checkpoint and config.restore_checkpoint != "latest":
        checkpoint.restore(config.restore_checkpoint)
        log(" ** Restored model checkpoint from {}".format(
            config.restore_checkpoint))
    elif config.restore_checkpoint and config.restore_checkpoint == "latest" and manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        log(" ** Restored model checkpoint from {}".format(
            manager.latest_checkpoint))
    elif config.load_weights:
        model.generator(model.generator.dummy_inputs)
        model.discriminator(model.discriminator.dummy_inputs)
        model.generator.load_weights(
            os.path.join(config.weights_dir, 'generator', 'tf_model.h5'))
        model.discriminator.load_weights(
            os.path.join(config.weights_dir, 'discriminator', 'tf_model.h5'))
    else:
        log(" ** Initializing from scratch.")

    restore_iterator = bool(
        config.restore_checkpoint) and config.restore_checkpoint == "latest"
    # Initialize global step for phase2
    if config.phase2 and not bool(checkpoint.phase2):
        optimizer.iterations.assign(0)
        checkpoint.step.assign(0)
        checkpoint.phase2.assign(True)
        restore_iterator = False
    if bool(checkpoint.phase2):
        manager = tf.train.CheckpointManager(
            checkpoint,
            config.checkpoints_dir,
            checkpoint_name='ckpt-p2',
            max_to_keep=config.keep_checkpoint_max)

    # Set up iterator checkpoint
    iter_checkpoint = tf.train.Checkpoint(train_iterator=train_iterator,
                                          world_size=tf.Variable(
                                              get_world_size()),
                                          rank=tf.Variable(get_rank()))
    iter_manager = tf.train.CheckpointManager(
        iter_checkpoint,
        os.path.join(config.checkpoints_dir,
                     'iter_ckpt_rank_' + '{:02}'.format(get_rank())),
        checkpoint_name='iter_ckpt_rank_' + '{:02}'.format(get_rank()),
        max_to_keep=config.keep_checkpoint_max)
    if restore_iterator and iter_manager.latest_checkpoint:
        ckpt_world_size = tf.train.load_variable(
            iter_manager.latest_checkpoint,
            'world_size/.ATTRIBUTES/VARIABLE_VALUE')
        if ckpt_world_size == get_world_size():
            iter_checkpoint.restore(iter_manager.latest_checkpoint)
            log(" ** Restored iterator checkpoint from {}".format(
                iter_manager.latest_checkpoint),
                all_rank=True)

    utils.heading("Running training")
    accumulator.reset()
    train_start, start_step = time.time(), int(checkpoint.step) - 1
    local_step = 0
    saved_ckpt = False
    while int(checkpoint.step) <= config.num_train_steps:
        saved_ckpt = False
        step = int(checkpoint.step)
        features = next(train_iterator)
        iter_start = time.time()

        # if step == 200: tf.profiler.experimental.start(logdir=train_log_dir)
        total_loss, eval_fn_inputs = train_one_step(
            config,
            model,
            optimizer,
            features,
            accumulator,
            local_step == 1,
            take_step=local_step % args.gradient_accumulation_steps == 0)
        # if step == 300: tf.profiler.experimental.stop()

        metrics["train_perf"].update_state(config.train_batch_size *
                                           get_world_size() /
                                           (time.time() - iter_start))
        metrics["total_loss"].update_state(values=total_loss)
        metric_fn(config, metrics, eval_fn_inputs)

        if (step % args.log_freq
                == 0) and (local_step % args.gradient_accumulation_steps == 0):
            log_info_dict = {
                k: float(v.result().numpy() *
                         100) if "accuracy" in k else float(v.result().numpy())
                for k, v in metrics.items()
            }
            dllogger.log(step=(step, ), data=log_info_dict, verbosity=0)
            log('Step:{step:6d}, Loss:{total_loss:10.6f}, Gen_loss:{masked_lm_loss:10.6f}, Disc_loss:{disc_loss:10.6f}, Gen_acc:{masked_lm_accuracy:6.2f}, '
                'Disc_acc:{disc_accuracy:6.2f}, Perf:{train_perf:4.0f}, Loss Scaler: {loss_scale}, Elapsed: {elapsed}, ETA: {eta}, '
                .format(step=step,
                        **log_info_dict,
                        loss_scale=optimizer.loss_scale if config.amp else 1,
                        elapsed=utils.get_readable_time(time.time() -
                                                        train_start),
                        eta=utils.get_readable_time(
                            (time.time() - train_start) / (step - start_step) *
                            (config.num_train_steps - step))),
                all_rank=True)

            with train_summary_writer.as_default():
                for key, m in metrics.items():
                    tf.summary.scalar(key, m.result(), step=step)

            if int(checkpoint.step) < config.num_train_steps:
                for m in metrics.values():
                    m.reset_states()

        # Print allreduced metrics on the last step
        if int(checkpoint.step) == config.num_train_steps and (
                local_step % args.gradient_accumulation_steps == 0):
            log_info_dict = {
                k: float(hvd.allreduce(v.result()).numpy() * 100) if "accuracy"
                in k else float(hvd.allreduce(v.result()).numpy())
                for k, v in metrics.items()
            }
            log_info_dict["training_sequences_per_second"] = log_info_dict[
                "train_perf"]
            log_info_dict["final_loss"] = log_info_dict["total_loss"]
            log_info_dict["e2e_train_time"] = time.time() - e2e_start_time
            dllogger.log(step=(), data=log_info_dict, verbosity=0)
            log('<FINAL STEP METRICS> Step:{step:6d}, Loss:{total_loss:10.6f}, Gen_loss:{masked_lm_loss:10.6f}, Disc_loss:{disc_loss:10.6f}, Gen_acc:{masked_lm_accuracy:6.2f}, '
                'Disc_acc:{disc_accuracy:6.2f}, Perf:{train_perf:4.0f},'.
                format(step=step, **log_info_dict),
                all_rank=False)

        if local_step % args.gradient_accumulation_steps == 0:
            checkpoint.step.assign(int(optimizer.iterations))

        local_step += 1
        if not config.skip_checkpoint and (
                local_step %
            (config.save_checkpoints_steps * args.gradient_accumulation_steps)
                == 0):
            saved_ckpt = True
            if is_main_process():
                save_path = manager.save(checkpoint_number=step)
                log(" ** Saved model checkpoint for step {}: {}".format(
                    step, save_path))
            iter_save_path = iter_manager.save(checkpoint_number=step)
            log(" ** Saved iterator checkpoint for step {}: {}".format(
                step, iter_save_path),
                all_rank=True)

    step = (int(checkpoint.step) - 1)
    dllogger.flush()
    if not config.skip_checkpoint and not saved_ckpt:
        if is_main_process():
            save_path = manager.save(checkpoint_number=step)
            log(" ** Saved model checkpoint for step {}: {}".format(
                step, save_path))
        iter_save_path = iter_manager.save(checkpoint_number=step)
        log(" ** Saved iterator checkpoint for step {}: {}".format(
            step, iter_save_path),
            all_rank=True)

    return args
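
The restore logic above distinguishes an explicit checkpoint path from the special value "latest". Here is a minimal, standalone sketch of that restore-or-initialize pattern with tf.train.Checkpoint and tf.train.CheckpointManager; the directory name and the single tracked variable are illustrative (a real run would also track the model and optimizer, as the code above does).

import tensorflow as tf

# Track only a step counter for the sake of the sketch.
step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(ckpt, './example_ckpts', max_to_keep=3)

if manager.latest_checkpoint:
    ckpt.restore(manager.latest_checkpoint)
    print(" ** Restored checkpoint from {}".format(manager.latest_checkpoint))
else:
    print(" ** Initializing from scratch.")

step.assign_add(1)
save_path = manager.save(checkpoint_number=int(step))
print(" ** Saved checkpoint for step {}: {}".format(int(step), save_path))
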
Example No. 14
def main():
    global timeout_sent

    args = parse_arguments()

    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)

    device, args = setup_training(args)
    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare optimizer
    model, optimizer, grad_scaler, lr_scheduler, checkpoint, global_step, criterion, epoch = prepare_model_and_optimizer(args, device, sequence_output_is_dense=not args.no_dense_sequence_output)
    # Prepare the data loader.
    if is_main_process():
        tic = time.perf_counter()
    train_dataloader = lddl.torch.get_bert_pretrain_data_loader(
        args.input_dir,
        local_rank=max(args.local_rank, 0),
        vocab_file=args.vocab_file,
        data_loader_kwargs={
            'batch_size': args.train_batch_size * args.n_gpu,
            'num_workers': args.num_workers,
            'pin_memory': True,
        },
        base_seed=args.seed,
        log_dir=None if args.output_dir is None else os.path.join(args.output_dir, 'lddl_log'),
        log_level=logging.WARNING,
        start_epoch=epoch,
    )
    if is_main_process():
        print('get_bert_pretrain_data_loader took {} s!'.format(time.perf_counter() - tic))

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})
        dllogger.log(step="PARAMETER", data={"train_start": True})
        dllogger.log(step="PARAMETER", data={"batch_size_per_gpu": args.train_batch_size})
        dllogger.log(step="PARAMETER", data={"learning_rate": args.learning_rate})

    model.train()
    most_recent_ckpts_paths = []

    stats = SyncFreeStats()
    # Host Only Stats
    stats.add_stat('model_step')
    # Device/Host Sync-ed Stats
    stats.add_stat('optimizer_step', dtype=torch.int32, device_func=(lambda: optimizer.param_groups[0]['step']))
    stats.add_stat('average_loss', dtype=torch.float32, device_tensor=torch.zeros(1, dtype=torch.float32, device=device))
    stats.add_stat('learning_rate', dtype=torch.float32, device_func=(lambda: optimizer.param_groups[0]['lr']))
    if grad_scaler.is_enabled():
        # This stat only indicates a skipped step occurred.  It does not accumulate the number of skipped steps
        stats.add_stat('skip_optimizer_step', dtype=torch.float32, device_func=(lambda: grad_scaler._found_inf_per_device(optimizer)[device]))
        stats.add_stat('skipped_optimizer_steps', dtype=torch.float32, device_tensor=torch.zeros(1, dtype=torch.float32, device=device),
                                                  device_func=(lambda x: x.add_(grad_scaler._found_inf_per_device(optimizer)[device])))
    else:
        stats.add_stat('skip_optimizer_step', dtype=torch.float32)
        stats.add_stat('skipped_optimizer_steps', dtype=torch.float32)

    static_gpu_batch = None
    full_cudagraph = None
    grad_accum_cudagraph = None
    if args.cuda_graphs:
        static_gpu_batch = {
            'input_ids': torch.ones(args.train_batch_size, args.max_seq_length, dtype=torch.int64, device=device),
            'token_type_ids': torch.ones(args.train_batch_size, args.max_seq_length, dtype=torch.int64, device=device),
            'attention_mask': torch.ones(args.train_batch_size, args.max_seq_length, dtype=torch.int64, device=device),
            'labels': torch.ones(args.train_batch_size, args.max_seq_length, dtype=torch.int64, device=device),
            'next_sentence_labels': torch.ones(args.train_batch_size, dtype=torch.int64, device=device),
        }

        # Warmup Steps - includes jitting fusions
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for _ in range(11):
                take_training_step(args, grad_scaler, model, criterion, static_gpu_batch, stats)
                take_optimizer_step(args, lr_scheduler, optimizer, grad_scaler, device, stats)
        torch.cuda.current_stream().wait_stream(side_stream)

        # Capture Graph
        full_cudagraph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(full_cudagraph):
            take_training_step(args, grad_scaler, model, criterion, static_gpu_batch, stats)
            take_optimizer_step(args, lr_scheduler, optimizer, grad_scaler, device, stats)

        # Warmup Steps - includes jitting fusions
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for _ in range(3):
                with model.no_sync():
                    take_training_step(args, grad_scaler, model, criterion, static_gpu_batch, stats)
        torch.cuda.current_stream().wait_stream(side_stream)

        # Capture Graph
        grad_accum_cudagraph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(grad_accum_cudagraph):
            with model.no_sync():
                take_training_step(args, grad_scaler, model, criterion, static_gpu_batch, stats)

    train_iter = tqdm(
        train_dataloader,
        desc="Iteration",
        disable=args.disable_progress_bar,
        total=len(train_dataloader),
    ) if is_main_process() else train_dataloader


    raw_train_start = None

    # avoid nvfuser compilation times in measuring perf with phase2 binning
    # ideally skip > 3 * num_bins fwd+bwd iterations to start measuring perf 
    skip_fwd_bwd_for_perf = 4
    if args.phase2: #we use 8 bins with phase2
        skip_fwd_bwd_for_perf = 50 

    while True:
        for step, batch in enumerate(train_iter):
            # The first training step is 1 and not 0 when gradient accumulating
            # in order to avoid an optimizer step on the very first step
            stats.host_stat('model_step').add_(1)
            grad_accumulation_step = (stats.host_stat_value('model_step') % args.gradient_accumulation_steps) != 0

            if raw_train_start is None and step == skip_fwd_bwd_for_perf:
                raw_train_start = time.time()

            # Execute Model Step
            if args.cuda_graphs:
                for k in batch.keys():
                    static_gpu_batch[k].copy_(batch[k], non_blocking=True)
                if grad_accumulation_step:
                    grad_accum_cudagraph.replay()
                else:
                    full_cudagraph.replay()
            else:
                batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

                if args.allreduce_post_accumulation and grad_accumulation_step:
                    with model.no_sync():
                        take_training_step(args, grad_scaler, model, criterion, batch, stats)
                else:
                    take_training_step(args, grad_scaler, model, criterion, batch, stats)

                if not grad_accumulation_step:
                    take_optimizer_step(args, lr_scheduler, optimizer, grad_scaler, device, stats)

            # Log Optimizer Step
            if (not grad_accumulation_step) or timeout_sent:
                static_optimizer_step = stats.host_stat_value('model_step') // args.gradient_accumulation_steps
                dynamic_optimizer_step = static_optimizer_step - int(stats.host_stat_value('skipped_optimizer_steps'))
                no_log_steps = static_optimizer_step % args.log_freq

                # Log Final Step (MAYBE)
                # Since the stats are asynchronously pushed from the GPU to CPU, they are not always reliable
                # Therefore, a synchronization is required to guarantee you see the intended value.
                # Without a synchronization, it is possible for some GPUs to go through the exit conditional
                # and others to not because they accidentally see a different value for `skipped_optimizer_steps`.
                # In order to remove most device syncs, synchronizations only begin in the last few steps
                # where the skipped step count matters.
                if static_optimizer_step >= args.steps_this_run or timeout_sent:
                    torch.cuda.synchronize()
                    dynamic_optimizer_step = static_optimizer_step - int(stats.host_stat_value('skipped_optimizer_steps'))
                    if dynamic_optimizer_step >= args.steps_this_run or timeout_sent:
                        train_time_raw = time.time() - raw_train_start

                        last_num_steps = args.log_freq if no_log_steps == 0 else no_log_steps
                        stats.device_stat('average_loss').div_(last_num_steps * args.gradient_accumulation_steps)
                        if (torch.distributed.is_initialized()):
                            stats.device_stat('average_loss').div_(get_world_size())
                            torch.distributed.all_reduce(stats.device_stat('average_loss'))

                        # We block on this copy to ensure the final value
                        stats.host_stat('average_loss').copy_(stats.device_stat('average_loss'))
                        if is_main_process():
                            dllogger.log(step=(epoch, dynamic_optimizer_step,), data={"final_loss": stats.host_stat_value('average_loss')})

                        checkpoint_step(args, epoch, dynamic_optimizer_step, model, optimizer, grad_scaler, most_recent_ckpts_paths)

                        return args, train_time_raw, stats, skip_fwd_bwd_for_perf

                if no_log_steps == 0:
                    if is_main_process():
                        dllogger.log(step=(epoch, dynamic_optimizer_step,),
                                     data={"average_loss": stats.host_stat_value('average_loss') / (args.log_freq * args.gradient_accumulation_steps),
                                           "learning_rate": stats.host_stat_value('learning_rate'),
                                           "skipped_steps": int(stats.host_stat_value('skipped_optimizer_steps'))})
                        if stats.host_stat_value('skip_optimizer_step') > 0.:
                            dllogger.log(step="PARAMETER", data={"loss_scale": grad_scaler._get_scale_async().item()})

                    stats.device_stat('average_loss').zero_()

                    if not args.skip_checkpoint and (dynamic_optimizer_step % args.num_steps_per_checkpoint == 0):
                        checkpoint_step(args, epoch, dynamic_optimizer_step, model, optimizer, grad_scaler, most_recent_ckpts_paths)

        epoch += 1
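
The CUDA Graphs setup above follows PyTorch's warmup-then-capture recipe: run a few iterations on a side stream so one-time allocations and JIT fusions are not recorded, capture one full forward/backward/optimizer step over static tensors, then replay the graph after copying fresh data into those tensors. Below is a minimal standalone sketch of that pattern on a toy model; the shapes, the MSE loss, and the SGD optimizer are illustrative, and a CUDA device with a recent PyTorch (>= 1.10) is assumed.

import torch

device = torch.device('cuda')
model = torch.nn.Linear(16, 16).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
static_input = torch.randn(8, 16, device=device)
static_target = torch.randn(8, 16, device=device)

# 1. Warm up on a side stream so lazy initialization is not captured.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        loss = torch.nn.functional.mse_loss(model(static_input), static_target)
        loss.backward()
        opt.step()
torch.cuda.current_stream().wait_stream(side_stream)

# 2. Capture one full step into a graph over the static tensors.
graph = torch.cuda.CUDAGraph()
opt.zero_grad(set_to_none=True)
with torch.cuda.graph(graph):
    static_loss = torch.nn.functional.mse_loss(model(static_input), static_target)
    static_loss.backward()
    opt.step()

# 3. Replay: copy new data into the static tensors, then replay the captured work.
static_input.copy_(torch.randn(8, 16, device=device))
static_target.copy_(torch.randn(8, 16, device=device))
graph.replay()
print(float(static_loss))
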
Example No. 15
def eval_linear(args):
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    train_transform = pth_transforms.Compose([
        pth_transforms.RandomResizedCrop(224),
        pth_transforms.RandomHorizontalFlip(),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    val_transform = pth_transforms.Compose([
        pth_transforms.Resize(256, interpolation=3),
        pth_transforms.CenterCrop(224),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
    dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")

    # ============ building network ... ============
    model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
    model.cuda()
    model.eval()
    print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    # load weights to evaluate
    utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)

    linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)), num_labels=args.num_labels)
    linear_classifier = linear_classifier.cuda()
    linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])

    # set optimizer
    optimizer = torch.optim.SGD(
        linear_classifier.parameters(),
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
        momentum=0.9,
        weight_decay=0, # we do not apply weight decay
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)

    # Optionally resume from a checkpoint
    to_restore = {"epoch": 0, "best_acc": 0.}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=linear_classifier,
        optimizer=optimizer,
        scheduler=scheduler,
    )
    start_epoch = to_restore["epoch"]
    best_acc = to_restore["best_acc"]

    for epoch in range(start_epoch, args.epochs):
        train_loader.sampler.set_epoch(epoch)

        train_stats = train(model, linear_classifier, optimizer, train_loader, epoch, args.n_last_blocks, args.avgpool_patchtokens)
        scheduler.step()

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch}
        if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
            test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
            print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            best_acc = max(best_acc, test_stats["acc1"])
            print(f'Max accuracy so far: {best_acc:.2f}%')
            log_stats = {**{k: v for k, v in log_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()}}
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": linear_classifier.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "best_acc": best_acc,
            }
            torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
    print("Training of the supervised linear classifier on frozen features completed.\n"
                "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))
Example No. 16
def main():
    global timeout_sent

    args = parse_arguments()
        
    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)
    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare optimizer
    model, optimizer, lr_scheduler, checkpoint, global_step, criterion = prepare_model_and_optimizer(args, device)

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    raw_train_start = None
    if args.do_train:
        if is_main_process():
            dllogger.log(step="PARAMETER", data={"train_start": True})
            dllogger.log(step="PARAMETER", data={"batch_size_per_gpu": args.train_batch_size})
            dllogger.log(step="PARAMETER", data={"learning_rate": args.learning_rate})

        model.train()
        most_recent_ckpts_paths = []
        average_loss = 0.0  # averaged loss every args.log_freq steps
        ave_mask_acc = 0.0
        ave_cgp_acc = 0.0
        epoch = 0
        training_steps = 0

        pool = ProcessPoolExecutor(1)

        # Note: We loop infinitely over epochs, termination is handled via iteration count
        while True:
            thread = None
            restored_data_loader = None
            if not args.resume_from_checkpoint or epoch > 0 or (args.phase2 and global_step < 1) or args.init_checkpoint:
                files = [i for i in range(256)]
                files.sort()
                num_files = len(files)
                random.Random(args.seed + epoch).shuffle(files)
                f_start_id = 0
            else:
                f_start_id = checkpoint['files'][0]
                files = checkpoint['files'][1:]
                args.resume_from_checkpoint = False
                num_files = len(files)
                # may not exist in all checkpoints
                epoch = checkpoint.get('epoch', 0)
                restored_data_loader = checkpoint.get('data_loader', None)

            shared_file_list = {}

            if torch.distributed.is_initialized() and get_world_size() > num_files:
                remainder = get_world_size() % num_files
                data_file = files[(f_start_id*get_world_size()+get_rank() + remainder*f_start_id)%num_files]
            else:
                data_file = files[(f_start_id*get_world_size()+get_rank())%num_files]

            previous_file = data_file

            if restored_data_loader is None:
                train_data = PretrainDataset("data/" + args.dataset, rank=data_file)
                train_sampler = RandomSampler(train_data)
                train_dataloader = DataLoaderMasking(train_data, sampler=train_sampler,
                                              batch_size=args.train_batch_size * args.n_gpu,
                                              num_workers=4, worker_init_fn=worker_init,
                                              pin_memory=True)
                # shared_file_list["0"] = (train_dataloader, data_file)
            else:
                train_dataloader = restored_data_loader
                restored_data_loader = None

            overflow_buf = None
            if args.allreduce_post_accumulation:
                overflow_buf = torch.cuda.IntTensor([0])

            for f_id in range(f_start_id + 1, len(files)):

                if get_world_size() > num_files:
                    data_file = files[(f_id*get_world_size()+get_rank() + remainder*f_id)%num_files]
                else:
                    data_file = files[(f_id*get_world_size()+get_rank())%num_files]

                previous_file = data_file

                dataset_future = pool.submit(create_pretraining_dataset, data_file, args, worker_init)

                train_iter = tqdm(train_dataloader, desc="Iteration", disable=args.disable_progress_bar) if is_main_process() else train_dataloader

                if raw_train_start is None:
                    raw_train_start = time.time()
                for step, batch in enumerate(train_iter):
                    training_steps += 1
                    batch = batch.to(device)
                    pred_node, pred_graph = model(batch)
                    loss = criterion(pred_node, pred_graph, batch.mask_node_label, batch.ngp_y)
                    if args.n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.

                    divisor = args.gradient_accumulation_steps
                    if args.gradient_accumulation_steps > 1:
                        if not args.allreduce_post_accumulation:
                            # this division was merged into predivision
                            loss = loss / args.gradient_accumulation_steps
                            divisor = 1.0
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer, delay_overflow_check=args.allreduce_post_accumulation) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    average_loss += loss.item()
                    acc_node = compute_accuracy(pred_node, batch.mask_node_label)
                    acc_cgp = compute_accuracy(pred_graph, batch.ngp_y)
                    ave_mask_acc += acc_node
                    ave_cgp_acc += acc_cgp

                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(args, optimizer, model, overflow_buf, global_step)

                    if global_step >= args.steps_this_run or timeout_sent:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = int(training_steps / args.gradient_accumulation_steps) % args.log_freq
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps * divisor)

                        ave_mask_acc = torch.tensor(ave_mask_acc, dtype=torch.float32).cuda()
                        ave_mask_acc = ave_mask_acc / (last_num_steps * divisor)
                        ave_cgp_acc = torch.tensor(ave_cgp_acc, dtype=torch.float32).cuda()
                        ave_cgp_acc = ave_cgp_acc / (last_num_steps * divisor)

                        if (torch.distributed.is_initialized()):
                            average_loss /= get_world_size()
                            torch.distributed.all_reduce(average_loss)

                            ave_mask_acc /= get_world_size()
                            torch.distributed.all_reduce(ave_mask_acc)

                            ave_cgp_acc /= get_world_size()
                            torch.distributed.all_reduce(ave_cgp_acc)
                        final_loss = average_loss.item()
                        final_maskacc = ave_mask_acc.item()
                        final_cgpacc = ave_cgp_acc.item()

                        if is_main_process():
                            dllogger.log(step=(epoch, global_step, ), data={"final_loss": final_loss,"final_maskacc":final_maskacc,
                                                                            "final_cgpacc":final_cgpacc})
                    elif training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
                        if is_main_process():
                            dllogger.log(step=(epoch, global_step, ), data={"average_loss": average_loss / (args.log_freq * divisor),
                                                                            "average_mask_acc":ave_mask_acc / (args.log_freq * divisor),
                                                                            "average_cgp_acc": ave_cgp_acc / (args.log_freq * divisor),
                                                                            "step_loss": loss.item() * args.gradient_accumulation_steps / divisor,
                                                                            "learning_rate": optimizer.param_groups[0]['lr']})
                        average_loss = 0
                        ave_mask_acc = 0
                        ave_cgp_acc = 0


                    if global_step >= args.steps_this_run or training_steps % (
                            args.num_steps_per_checkpoint * args.gradient_accumulation_steps) == 0 or timeout_sent:
                        if is_main_process() and not args.skip_checkpoint:
                            # Save a trained model
                            dllogger.log(step="PARAMETER", data={"checkpoint_step": global_step})
                            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step + args.phase1_end_step))
                            if args.do_train:
                                torch.save({'model': model_to_save.state_dict(),
                                            'gnn':model_to_save.gnn.state_dict(),
                                            'linear_atom':model_to_save.linear_pred_atoms.state_dict(),
                                            'optimizer': optimizer.state_dict(),
                                            'master params': list(amp.master_params(optimizer)),
                                            'files': [f_id] + files,
                                            'epoch': epoch,
                                            'data_loader': None if global_step >= args.max_steps else train_dataloader
                                            }, output_save_file)

                                most_recent_ckpts_paths.append(output_save_file)
                                if len(most_recent_ckpts_paths) > 3:
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                                    os.remove(ckpt_to_be_removed)

                        # Exiting the training due to hitting max steps, or being sent a 
                        # timeout from the cluster scheduler
                        if global_step >= args.steps_this_run or timeout_sent:
                            del train_dataloader
                            # thread.join()
                            return args, final_loss, final_maskacc, final_cgpacc, train_time_raw, global_step

                del train_dataloader
                # thread.join()
                # Make sure pool has finished and switch train_dataloader
                # NOTE: Will block until complete
                train_dataloader, data_file = dataset_future.result(timeout=None)

            epoch += 1
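
A minimal sketch of the cross-rank averaging pattern used in the logging branch above: pre-divide the scalar by the world size, then all_reduce with the default SUM op, so every rank ends up holding the global mean. The helper name and the assumption that torch.distributed is already initialized are mine, not part of the original script.

import torch
import torch.distributed as dist

def average_across_ranks(value: float, device: torch.device) -> float:
    # Pre-divide by the number of ranks, then sum across ranks; the result on
    # every rank is the mean of the per-rank values.
    t = torch.tensor(value, dtype=torch.float32, device=device)
    if dist.is_available() and dist.is_initialized():
        t /= dist.get_world_size()
        dist.all_reduce(t)  # default reduce op is SUM
    return t.item()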
Ejemplo n.º 17
0
def main(args):
    args.fp16 = args.fp16 or args.amp
    if args.server_ip and args.server_port:
        # Distant debugging - see
        # https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        logger.info("Waiting for debugger attach")
        ptvsd.enable_attach(
            address=(args.server_ip, args.server_port),
            redirect_output=True,
        )
        ptvsd.wait_for_attach()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs.
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, "
                "16-bits training: {}".format(
                    device,
                    n_gpu,
                    bool(args.local_rank != -1),
                    args.fp16,
                ))

    if not args.do_train and not args.do_eval and not args.do_predict:
        raise ValueError("At least one of `do_train`, `do_eval` or "
                         "`do_predict` must be True.")

    if is_main_process():
        if (os.path.exists(args.output_dir) and os.listdir(args.output_dir)
                and args.do_train):
            logger.warning("Output directory ({}) already exists and is not "
                           "empty.".format(args.output_dir))
    mkdir_by_main_process(args.output_dir)

    if is_main_process():
        dllogger.init(backends=[
            dllogger.JSONStreamBackend(
                verbosity=dllogger.Verbosity.VERBOSE,
                filename=os.path.join(args.output_dir, 'dllogger.json'),
            ),
            dllogger.StdOutBackend(
                verbosity=dllogger.Verbosity.VERBOSE,
                step_format=format_step,
            ),
        ])
    else:
        dllogger.init(backends=[])

    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             args.gradient_accumulation_steps))
    if args.gradient_accumulation_steps > args.train_batch_size:
        raise ValueError("gradient_accumulation_steps ({}) cannot be larger "
                         "train_batch_size ({}) - there cannot be a fraction "
                         "of one sample.".format(
                             args.gradient_accumulation_steps,
                             args.train_batch_size,
                         ))
    args.train_batch_size = (args.train_batch_size //
                             args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    processor = PROCESSORS[args.task_name]()
    num_labels = len(processor.get_labels())

    #tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    tokenizer = BertTokenizer(
        args.vocab_file,
        do_lower_case=args.do_lower_case,
        max_len=512,
    )  # for bert large

    num_train_optimization_steps = None
    if args.do_train:
        train_features = get_train_features(
            args.data_dir,
            args.bert_model,
            args.max_seq_length,
            args.do_lower_case,
            args.local_rank,
            args.train_batch_size,
            args.gradient_accumulation_steps,
            args.num_train_epochs,
            tokenizer,
            processor,
        )
        num_train_optimization_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = (num_train_optimization_steps //
                                            torch.distributed.get_world_size())

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)
    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    # modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForSequenceClassification(
        config,
        num_labels=num_labels,
    )
    logger.info("USING CHECKPOINT from {}".format(args.init_checkpoint))

    checkpoint = torch.load(args.init_checkpoint, map_location='cpu')
    checkpoint = checkpoint["model"] if "model" in checkpoint.keys(
    ) else checkpoint
    model.load_state_dict(checkpoint, strict=False)
    logger.info("USED CHECKPOINT from {}".format(args.init_checkpoint))
    dllogger.log(
        step="PARAMETER",
        data={
            "num_parameters":
            sum([p.numel() for p in model.parameters() if p.requires_grad]),
        },
    )

    model.to(device)
    # Prepare optimizer
    model, optimizer, scheduler = init_optimizer_and_amp(
        model,
        args.learning_rate,
        args.loss_scale,
        args.warmup_proportion,
        num_train_optimization_steps,
        args.fp16,
    )

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from "
                              "https://www.github.com/nvidia/apex to use "
                              "distributed and fp16 training.")
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    loss_fct = torch.nn.CrossEntropyLoss()

    results = {}
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        train_data = gen_tensor_dataset(train_features)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(
            train_data,
            sampler=train_sampler,
            batch_size=args.train_batch_size,
        )

        global_step = 0
        nb_tr_steps = 0
        tr_loss = 0
        latency_train = 0.0
        nb_tr_examples = 0
        model.train()
        tic_train = time.perf_counter()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                if args.max_steps > 0 and global_step > args.max_steps:
                    break
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                logits = model(input_ids, segment_ids, input_mask)
                loss = loss_fct(
                    logits.view(-1, num_labels),
                    label_ids.view(-1),
                )
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up for BERT
                        # which FusedAdam doesn't do
                        scheduler.step()

                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
        latency_train = time.perf_counter() - tic_train
        tr_loss = tr_loss / nb_tr_steps
        results.update({
            'global_step': global_step,
            'train:loss': tr_loss,
            'train:latency': latency_train,
            'train:num_samples_per_gpu': nb_tr_examples,
            'train:num_steps': nb_tr_steps,
            'train:throughput': get_world_size() * nb_tr_examples / latency_train,
        })
        if is_main_process() and not args.skip_checkpoint:
            model_to_save = model.module if hasattr(model, 'module') else model
            torch.save(
                {"model": model_to_save.state_dict()},
                os.path.join(args.output_dir, modeling.WEIGHTS_NAME),
            )
            with open(
                    os.path.join(args.output_dir, modeling.CONFIG_NAME),
                    'w',
            ) as f:
                f.write(model_to_save.config.to_json_string())

    if (args.do_eval or args.do_predict) and is_main_process():
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features, label_map = convert_examples_to_features(
            eval_examples,
            processor.get_labels(),
            args.max_seq_length,
            tokenizer,
        )
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_data = gen_tensor_dataset(eval_features)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(
            eval_data,
            sampler=eval_sampler,
            batch_size=args.eval_batch_size,
        )

        model.eval()
        preds = None
        out_label_ids = None
        eval_loss = 0
        nb_eval_steps, nb_eval_examples = 0, 0
        cuda_events = [(torch.cuda.Event(enable_timing=True),
                        torch.cuda.Event(enable_timing=True))
                       for _ in range(len(eval_dataloader))]
        for i, (input_ids, input_mask, segment_ids, label_ids) in tqdm(
                enumerate(eval_dataloader),
                desc="Evaluating",
        ):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                cuda_events[i][0].record()
                logits = model(input_ids, segment_ids, input_mask)
                cuda_events[i][1].record()
                if args.do_eval:
                    eval_loss += loss_fct(
                        logits.view(-1, num_labels),
                        label_ids.view(-1),
                    ).mean().item()

            nb_eval_steps += 1
            nb_eval_examples += input_ids.size(0)
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = label_ids.detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(
                    out_label_ids,
                    label_ids.detach().cpu().numpy(),
                    axis=0,
                )
        torch.cuda.synchronize()
        eval_latencies = [
            event_start.elapsed_time(event_end)
            for event_start, event_end in cuda_events
        ]
        eval_latencies = list(sorted(eval_latencies))

        def infer_latency_sli(threshold):
            index = int(len(eval_latencies) * threshold) - 1
            index = min(max(index, 0), len(eval_latencies) - 1)
            return eval_latencies[index]

        eval_throughput = (args.eval_batch_size /
                           (np.mean(eval_latencies) / 1000))

        results.update({
            'eval:num_samples_per_gpu': nb_eval_examples,
            'eval:num_steps': nb_eval_steps,
            'infer:latency(ms):50%': infer_latency_sli(0.5),
            'infer:latency(ms):90%': infer_latency_sli(0.9),
            'infer:latency(ms):95%': infer_latency_sli(0.95),
            'infer:latency(ms):99%': infer_latency_sli(0.99),
            'infer:latency(ms):100%': infer_latency_sli(1.0),
            'infer:latency(ms):avg': np.mean(eval_latencies),
            'infer:latency(ms):std': np.std(eval_latencies),
            'infer:latency(ms):sum': np.sum(eval_latencies),
            'infer:throughput(samples/s):avg': eval_throughput,
        })
        preds = np.argmax(preds, axis=1)
        if args.do_predict:
            dump_predictions(
                os.path.join(args.output_dir, 'predictions.json'),
                label_map,
                preds,
                eval_examples,
            )
        if args.do_eval:
            results['eval:loss'] = eval_loss / nb_eval_steps
            eval_result = compute_metrics(args.task_name, preds, out_label_ids)
            results.update(eval_result)

    if is_main_process():
        logger.info("***** Results *****")
        for key in sorted(results.keys()):
            logger.info("  %s = %s", key, str(results[key]))
        with open(os.path.join(args.output_dir, "results.txt"), "w") as writer:
            json.dump(results, writer)
        dllogger_queries_from_results = {
            'exact_match': 'acc',
            'F1': 'f1',
            'e2e_train_time': 'train:latency',
            'training_sequences_per_second': 'train:throughput',
            'e2e_inference_time': ('infer:latency(ms):sum', lambda x: x / 1000),
            'inference_sequences_per_second': 'infer:throughput(samples/s):avg',
        }
        for key, query in dllogger_queries_from_results.items():
            results_key, convert = (query if isinstance(query, tuple) else
                                    (query, lambda x: x))
            if results_key not in results:
                continue
            dllogger.log(
                step=tuple(),
                data={key: convert(results[results_key])},
            )
    dllogger.flush()
    return results
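
The evaluation loop above measures per-batch inference latency with paired CUDA events. Below is a minimal sketch of that timing pattern, with a hypothetical helper name and a generic model/inputs signature.

import torch

def time_forward_ms(model, *inputs):
    # Record one event before and one after the forward pass, synchronize once
    # the GPU work is done, then read the elapsed GPU time in milliseconds.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        start.record()
        output = model(*inputs)
        end.record()
    torch.cuda.synchronize()
    return output, start.elapsed_time(end)  # milliseconds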
Ejemplo n.º 18
0
def prepare_model_and_optimizer(args, device):

    # Prepare model
    # model = Pretrain_model(args.num_layer, args.emb_dim, args.heads, args.num_message_passing, args.dropout_ratio)
    model = Pretrain_MolGNet(args.num_layer, args.emb_dim, args.heads, args.num_message_passing, args.dropout_ratio)
    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [f for f in os.listdir(args.output_dir) if f.endswith(".pt")]
            args.resume_step = max([int(x.split('.pt')[0].split('_')[1].strip()) for x in model_names])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)), map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)
        
        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
    
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    optimizer = FusedLAMB(optimizer_grouped_parameters, 
                          lr=args.learning_rate)
    lr_scheduler = PolyWarmUpScheduler(optimizer, 
                                       warmup=args.warmup_proportion, 
                                       total_steps=args.max_steps)
    if args.fp16:
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic", cast_model_outputs=torch.float16)
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale, cast_model_outputs=torch.float16)
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint['optimizer']['state'].keys())
            #Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_step
            for iter, item in enumerate(checkpoint['optimizer']['param_groups']):
                checkpoint['optimizer']['param_groups'][iter]['step'] = global_step
                checkpoint['optimizer']['param_groups'][iter]['t_total'] = args.max_steps
                checkpoint['optimizer']['param_groups'][iter]['warmup'] = args.warmup_proportion
                checkpoint['optimizer']['param_groups'][iter]['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

        # Restore AMP master parameters          
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer), checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())
        else:
            flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    criterion = PretrainingCriterion()

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
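
The no_decay grouping above (bias/gamma/beta/LayerNorm excluded from weight decay) is a common pattern; here is a small self-contained sketch of the same grouping with an illustrative helper name.

def build_param_groups(model, weight_decay=0.01,
                       no_decay=('bias', 'gamma', 'beta', 'LayerNorm')):
    # Parameters whose names contain any no_decay substring get weight_decay 0.0,
    # everything else gets the configured value; any optimizer that accepts
    # parameter groups (FusedLAMB, AdamW, ...) can consume the result.
    decay, skip = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            skip.append(param)
        else:
            decay.append(param)
    return [{'params': decay, 'weight_decay': weight_decay},
            {'params': skip, 'weight_decay': 0.0}]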
Ejemplo n.º 19
0
def main(args):
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=True)
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  sampler=sampler_val,
                                                  batch_size=int(1.0 * args.batch_size),
                                                  num_workers=args.num_workers,
                                                  pin_memory=args.pin_mem,
                                                  drop_last=False)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    # model = create_model(
    #     args.model,
    #     pretrained=False,
    #     num_classes=args.nb_classes,
    #     drop_rate=args.drop,
    #     drop_path_rate=args.drop_path,
    #     drop_block_rate=None,
    # )
    model = getattr(SwinTransformer, args.model)(num_classes=args.nb_classes,
                                                 drop_path_rate=args.drop_path)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr

    linear_scaled_warmup_lr = args.warmup_lr * args.batch_size * utils.get_world_size() / 512.0
    args.warmup_lr = linear_scaled_warmup_lr

    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    # criterion = LabelSmoothingCrossEntropy()

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        lr_scheduler.step(epoch + 1)
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            mixup_fn,
            set_training_mode=True  # keep the model in train mode for every epoch
        )

        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
            **{f'test_{k}': v
               for k, v in test_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
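
The learning-rate adjustment above follows the linear scaling rule (effective batch size over a reference batch of 512, as in the snippet). A short sketch of that calculation with illustrative names:

def scale_learning_rates(base_lr, base_warmup_lr, batch_size_per_gpu, world_size,
                         reference_batch=512.0):
    # effective batch = per-GPU batch * number of processes; both the peak LR
    # and the warmup LR are scaled by effective_batch / reference_batch.
    scale = batch_size_per_gpu * world_size / reference_batch
    return base_lr * scale, base_warmup_lr * scale

# e.g. a base LR of 5e-4 with 128 images per GPU on 8 GPUs scales by 1024/512 = 2.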
Ejemplo n.º 20
0
def main(args):
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    model = get_model(args)
    patch_size = model.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
    args.patch_size = patch_size

    # get dataset
    dataset_train = build_beit_pretraining_dataset(args)

    # prepare discrete vae
    d_vae = utils.create_d_vae(
        weight_path=args.discrete_vae_weight_path, d_vae_type=args.discrete_vae_type,
        device=device, image_size=args.second_input_size)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        sampler_rank = global_rank
        num_training_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks

        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    model.to(device)
    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params:', n_parameters)

    total_batch_size = args.batch_size * utils.get_world_size()
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Number of training steps = %d" % num_training_steps_per_epoch)
    print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch))

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    optimizer = create_optimizer(
        args, model_without_ddp)
    loss_scaler = NativeScaler()

    print("Use step level LR & WD scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))

    utils.auto_load_model(
        args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch)
        train_stats = train_one_epoch(
            model, d_vae, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, log_writer=log_writer,
            start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values,
            wd_schedule_values=wd_schedule_values,
        )
        if args.output_dir:
            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                utils.save_model(
                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch, 'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
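
utils.cosine_scheduler above is used to build one learning-rate and weight-decay value per training iteration. A rough sketch of that kind of schedule (linear warmup, then cosine decay) follows, under the assumption that this matches what the helper produces; the exact upstream implementation may differ.

import numpy as np

def cosine_schedule(base_value, final_value, epochs, steps_per_epoch, warmup_epochs=0):
    # Linear warmup from 0 to base_value, then cosine decay down to final_value;
    # returns one value per training step, indexed by the global iteration counter.
    total_iters = int(epochs * steps_per_epoch)
    warmup_iters = min(int(warmup_epochs * steps_per_epoch), total_iters)
    warmup = np.linspace(0.0, base_value, warmup_iters)
    decay_iters = max(total_iters - warmup_iters, 1)
    iters = np.arange(decay_iters)
    cosine = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / decay_iters))
    return np.concatenate([warmup, cosine])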
Ejemplo n.º 21
0
                            del train_dataloader
                            # thread.join()
                            return args, final_loss, train_time_raw, global_step

                del train_dataloader
                # thread.join()
                # Make sure pool has finished and switch train_dataloader
                # NOTE: Will block until complete
                train_dataloader, data_file = dataset_future.result(timeout=None)

            epoch += 1


if __name__ == "__main__":

    now = time.time()
    args, final_loss, train_time_raw, global_step = main()
    gpu_count = args.n_gpu
    global_step += args.phase1_end_step if (args.phase2 and args.resume_step > 0) else 0
    if args.resume_step == -1:
        args.resume_step = 0
    if torch.distributed.is_initialized():
        gpu_count = get_world_size()
    if is_main_process():
        e2e_time = time.time() - now
        training_perf = args.train_batch_size * args.gradient_accumulation_steps * gpu_count\
                        * (global_step - args.resume_step + skipped_steps) / train_time_raw
        dllogger.log(step=tuple(), data={"e2e_train_time": e2e_time, "training_sequences_per_second": training_perf,
                                         "final_loss": final_loss, "raw_train_time": train_time_raw })
    dllogger.flush()
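
The training_perf figure logged above is sequences per second: the per-optimizer-step batch size times the number of steps actually taken, divided by the raw training time. A compact sketch of that formula (names are illustrative):

def training_sequences_per_second(train_batch_size, grad_accum_steps, gpu_count,
                                  steps_taken, raw_train_time):
    # Sequences processed per optimizer step, times the number of steps taken,
    # over the measured raw training time.
    return train_batch_size * grad_accum_steps * gpu_count * steps_taken / raw_train_time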
Ejemplo n.º 22
0
def main():
    # Parse essential args
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        required=True,
                        help="Location of data files (model weights, etc).")
    parser.add_argument("--model_name",
                        required=True,
                        help="The name of the model being fine-tuned.")
    parser.add_argument("--pretrain_tfrecords", type=str)

    parser.add_argument("--seed", type=int)
    parser.add_argument("--num_train_steps", type=int)
    parser.add_argument("--num_warmup_steps", type=int)
    parser.add_argument("--learning_rate", type=float)
    parser.add_argument("--train_batch_size", type=int)
    parser.add_argument("--max_seq_length", type=int)

    parser.add_argument("--mask_prob", type=float)
    parser.add_argument("--disc_weight", type=float)
    parser.add_argument("--generator_hidden_size", type=float)

    parser.add_argument("--save_checkpoints_steps", type=int)
    parser.add_argument("--keep_checkpoint_max", type=int)
    parser.add_argument("--restore_checkpoint", action='store_true')

    parser.add_argument("--optimizer",
                        default="adam",
                        type=str,
                        help="adam or lamb")

    args = parser.parse_args()
    config = PretrainingConfig(**args.__dict__)

    # Set up tensorflow
    hvd.init()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()],
                                                   'GPU')
    tf.config.optimizer.set_jit(config.xla)
    tf.config.optimizer.set_experimental_options(
        {"auto_mixed_precision": config.amp})
    tf.random.set_seed(config.seed)

    # Set up config
    if config.do_train == config.do_eval:
        raise ValueError(
            "Exactly one of `do_train` or `do_eval` must be True.")
    if config.debug and config.do_train:
        utils.rmkdir(config.model_dir)
    utils.heading("Config:")
    log_config(config)

    # Save pretrain configs
    pretrain_config_json = os.path.join(config.checkpoints_dir,
                                        'pretrain_config.json')
    if is_main_process():
        utils.write_json(config.__dict__, pretrain_config_json)
        log("Configuration saved in {}".format(pretrain_config_json))

    # Set up model
    model = PretrainingModel(config)

    # Set up metrics
    perf_metrics = dict()
    perf_metrics["train_perf"] = tf.keras.metrics.Mean(name="train_perf")

    eval_metrics = dict()
    eval_metrics["total_loss"] = tf.keras.metrics.Mean(name="total_loss")
    eval_metrics["masked_lm_accuracy"] = tf.keras.metrics.Accuracy(
        name="masked_lm_accuracy")
    eval_metrics["masked_lm_loss"] = tf.keras.metrics.Mean(
        name="masked_lm_loss")
    if config.electra_objective:
        eval_metrics["sampled_masked_lm_accuracy"] = tf.keras.metrics.Accuracy(
            name="sampled_masked_lm_accuracy")
        if config.disc_weight > 0:
            eval_metrics["disc_loss"] = tf.keras.metrics.Mean(name="disc_loss")
            eval_metrics["disc_auc"] = tf.keras.metrics.AUC(name="disc_auc")
            eval_metrics["disc_accuracy"] = tf.keras.metrics.Accuracy(
                name="disc_accuracy")
            eval_metrics["disc_precision"] = tf.keras.metrics.Accuracy(
                name="disc_precision")
            eval_metrics["disc_recall"] = tf.keras.metrics.Accuracy(
                name="disc_recall")

    # Set up tensorboard
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_log_dir = os.path.join(
        config.log_dir, current_time,
        'train_' + str(get_rank()) + '_of_' + str(get_world_size()))
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)

    # Set up dataset
    dataset = pretrain_utils.get_dataset(config,
                                         config.train_batch_size,
                                         world_size=get_world_size(),
                                         rank=get_rank())
    train_iterator = iter(dataset)

    # Set up optimizer
    optimizer = create_optimizer(init_lr=config.learning_rate,
                                 num_train_steps=config.num_train_steps,
                                 num_warmup_steps=config.num_warmup_steps,
                                 weight_decay_rate=config.weight_decay_rate,
                                 optimizer=config.optimizer)
    if config.amp:
        optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
            optimizer, "dynamic")

    if config.do_train:
        # Set up checkpoint manager
        checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
                                         optimizer=optimizer,
                                         model=model)
        manager = tf.train.CheckpointManager(
            checkpoint,
            config.checkpoints_dir,
            max_to_keep=config.keep_checkpoint_max)
        iter_checkpoint = tf.train.Checkpoint(train_iterator=train_iterator)
        iter_manager = tf.train.CheckpointManager(
            iter_checkpoint,
            os.path.join(config.checkpoints_dir,
                         'iter_ckpt_rank_' + '{:02}'.format(get_rank())),
            checkpoint_name='iter_ckpt_rank_' + '{:02}'.format(get_rank()),
            max_to_keep=config.keep_checkpoint_max)
        if config.restore_checkpoint and manager.latest_checkpoint:
            checkpoint.restore(manager.latest_checkpoint)
            log(" ** Restored model checkpoint from {}".format(
                manager.latest_checkpoint))
            if iter_manager.latest_checkpoint:
                iter_checkpoint.restore(iter_manager.latest_checkpoint)
                log(" ** Restored iterator checkpoint from {}".format(
                    iter_manager.latest_checkpoint),
                    all_rank=True)
        else:
            log(" ** Initializing from scratch.")

        utils.heading("Running training")
        train_start, start_step = time.time(), int(checkpoint.step) - 1
        while int(checkpoint.step) <= config.num_train_steps:
            step = int(checkpoint.step)
            features = next(train_iterator)
            iter_start = time.time()

            # if step == 200: tf.profiler.experimental.start(logdir=train_log_dir)
            total_loss, eval_fn_inputs = train_one_step(
                config, model, optimizer, features, step <= 1)
            # if step == 300: tf.profiler.experimental.stop()

            perf_metrics["train_perf"].update_state(config.train_batch_size *
                                                    get_world_size() /
                                                    (time.time() - iter_start))
            eval_metrics["total_loss"].update_state(values=total_loss)
            metric_fn(config, eval_metrics, eval_fn_inputs)

            if step % 100 == 0:
                log('Step:{:6d}, Loss:{:10.6f}, Gen_loss:{:10.6f}, Disc_loss:{:10.6f}, Gen_acc:{:6.2f}, '
                    'Disc_acc:{:6.2f}, Perf:{:4.0f}, Elapsed: {}, ETA: {}, '.
                    format(
                        step, total_loss,
                        eval_metrics["masked_lm_loss"].result().numpy(),
                        eval_metrics["disc_loss"].result().numpy(),
                        eval_metrics["masked_lm_accuracy"].result().numpy() *
                        100,
                        eval_metrics["disc_accuracy"].result().numpy() * 100,
                        perf_metrics["train_perf"].result().numpy(),
                        utils.get_readable_time(time.time() - train_start),
                        utils.get_readable_time(
                            (time.time() - train_start) / (step - start_step) *
                            (config.num_train_steps - step))),
                    all_rank=True)

                with train_summary_writer.as_default():
                    for key, m in eval_metrics.items():
                        tf.summary.scalar(key, m.result(), step=step)

                for m in eval_metrics.values():
                    m.reset_states()

            checkpoint.step.assign_add(1)
            if step % config.save_checkpoints_steps == 0:
                if is_main_process():
                    save_path = manager.save()
                    log(" ** Saved model checkpoint for step {}: {}".format(
                        step, save_path))
                iter_save_path = iter_manager.save()
                log(" ** Saved iterator checkpoint for step {}: {}".format(
                    step, iter_save_path),
                    all_rank=True)

    if config.do_eval:
        pass
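
The checkpointing above pairs a tf.train.Checkpoint (tracking a step counter, the optimizer, and the model) with a tf.train.CheckpointManager that rotates old files. A minimal sketch of that pairing, with an illustrative helper name:

import tensorflow as tf

def build_checkpointing(model, optimizer, ckpt_dir, keep=5):
    # The checkpoint object tracks everything that has to survive a restart;
    # the manager keeps at most `keep` files and knows which one is the latest.
    checkpoint = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, model=model)
    manager = tf.train.CheckpointManager(checkpoint, ckpt_dir, max_to_keep=keep)
    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
    return checkpoint, manager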
Ejemplo n.º 23
0
def prepare_model_and_optimizer(args, device):

    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)

    if args.disable_weight_tying:
        print("WARNING: disabling weight tying for this run")
        print("BEFORE:", model.cls.predictions.decoder.weight is model.bert.embeddings.word_embeddings.weight)
        model.cls.predictions.decoder.weight = torch.nn.Parameter(model.cls.predictions.decoder.weight.clone().detach())
        print("AFTER:", model.cls.predictions.decoder.weight is model.bert.embeddings.word_embeddings.weight)
        assert model.cls.predictions.decoder.weight is not model.bert.embeddings.word_embeddings.weight

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1 and not args.init_checkpoint:
            model_names = [f for f in os.listdir(args.output_dir) if f.endswith(".pt")]
            args.resume_step = max([int(x.split('.pt')[0].split('_')[1].strip()) for x in model_names])

        global_step = args.resume_step if not args.init_checkpoint else 0

        if not args.init_checkpoint:
            checkpoint = torch.load(os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)), map_location="cpu")
        else:
            checkpoint = torch.load(args.init_checkpoint, map_location="cpu")

        model.load_state_dict(checkpoint['model'], strict=False)
        
        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
    
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    optimizer = FusedAdam(optimizer_grouped_parameters,
                          lr=args.learning_rate)
    lr_scheduler = PolyWarmUpScheduler(optimizer, 
                                       warmup=args.warmup_proportion, 
                                       total_steps=args.max_steps,
                                       degree=1)
    if args.fp16:

        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic", cast_model_outputs=torch.float16)
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale, cast_model_outputs=torch.float16)
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    model.checkpoint_activations(args.checkpoint_activations)

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint['optimizer']['state'].keys())
            #Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint['optimizer']['state'][key]['step'] = global_step
            for iter, item in enumerate(checkpoint['optimizer']['param_groups']):
                checkpoint['optimizer']['param_groups'][iter]['step'] = global_step
                checkpoint['optimizer']['param_groups'][iter]['t_total'] = args.max_steps
                checkpoint['optimizer']['param_groups'][iter]['warmup'] = args.warmup_proportion
                checkpoint['optimizer']['param_groups'][iter]['lr'] = args.learning_rate
        optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)

        # Restore AMP master parameters          
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint['optimizer'])
            for param, saved_param in zip(amp.master_params(optimizer), checkpoint['master params']):
                param.data.copy_(saved_param.data)

    if args.local_rank != -1:
        if not args.allreduce_post_accumulation:
            model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())
        else:
            flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)


    if args.disable_weight_tying:
        # Sanity check that the new (untied) parameter is registered with the optimizer
        print("SANITY CHECK OPTIMIZER:", id(model.module.cls.predictions.decoder.weight) in [id(p) for p in optimizer.param_groups[0]['params']])
        assert id(model.module.cls.predictions.decoder.weight) in [id(p) for p in optimizer.param_groups[0]['params']]

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
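
The disable_weight_tying path above clones the decoder weight to break tying and then checks that the new parameter reached the optimizer. Below is a self-contained toy version of that pattern, with an embedding/decoder pair standing in for the BERT heads; all names are illustrative.

import torch
import torch.nn as nn

def untie_and_verify(vocab=16, dim=8):
    emb = nn.Embedding(vocab, dim)
    decoder = nn.Linear(dim, vocab, bias=False)
    decoder.weight = emb.weight  # tied: both modules share one Parameter object
    decoder.weight = nn.Parameter(decoder.weight.clone().detach())  # untied copy
    assert decoder.weight is not emb.weight
    optimizer = torch.optim.SGD(list(emb.parameters()) + list(decoder.parameters()), lr=0.1)
    param_ids = {id(p) for group in optimizer.param_groups for p in group['params']}
    assert id(decoder.weight) in param_ids  # the untied parameter is being optimized
    return optimizer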
Ejemplo n.º 24
0
def train_dino(args):
    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)
    print("git:\n  {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v))
                    for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    transform = DataAugmentationDINO(
        args.global_crops_scale,
        args.local_crops_scale,
        args.local_crops_number,
    )
    #dataset = datasets.ImageFolder(args.data_path, transform=transform)
    from sen12ms import get_transform
    dataset = AllSen12MSDataset(args.data_path,
                                "train",
                                transform=transform,
                                tansform_coord=None,
                                classes=None,
                                seasons=None,
                                split_by_region=True,
                                download=False)

    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    print(f"Data loaded: there are {len(dataset)} images.")

    # ============ building student and teacher networks ... ============
    # if the network is a vision transformer (i.e. deit_tiny, deit_small, vit_base)
    if args.arch in vits.__dict__.keys():
        student = vits.__dict__[args.arch](
            patch_size=args.patch_size,
            drop_path_rate=0.1,  # stochastic depth
        )
        teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
        embed_dim = student.embed_dim

        student = utils.replace_input_layer(student, inchannels=13)
        teacher = utils.replace_input_layer(teacher, inchannels=13)

    # otherwise, we check if the architecture is in torchvision models
    elif args.arch in torchvision_models.__dict__.keys():
        student = torchvision_models.__dict__[args.arch]()
        teacher = torchvision_models.__dict__[args.arch]()
        embed_dim = student.fc.weight.shape[1]
    else:
        print(f"Unknown architecture: {args.arch}")

    # multi-crop wrapper handles forward with inputs of different resolutions
    student = utils.MultiCropWrapper(
        student,
        DINOHead(
            embed_dim,
            args.out_dim,
            use_bn=args.use_bn_in_head,
            norm_last_layer=args.norm_last_layer,
        ))
    teacher = utils.MultiCropWrapper(
        teacher,
        DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
    )
    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()
    # synchronize batch norms (if any)
    if utils.has_batchnorms(student):
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)

        # we need DDP wrapper to have synchro batch norms working...
        teacher = nn.parallel.DistributedDataParallel(teacher,
                                                      device_ids=[args.gpu])
        teacher_without_ddp = teacher.module
    else:
        # teacher_without_ddp and teacher are the same thing
        teacher_without_ddp = teacher
    student = nn.parallel.DistributedDataParallel(student,
                                                  device_ids=[args.gpu])
    # teacher and student start with the same weights
    teacher_without_ddp.load_state_dict(student.module.state_dict())
    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} network.")

    # ============ preparing loss ... ============
    dino_loss = DINOLoss(
        args.out_dim,
        args.local_crops_number +
        2,  # total number of crops = 2 global crops + local_crops_number
        args.warmup_teacher_temp,
        args.teacher_temp,
        args.warmup_teacher_temp_epochs,
        args.epochs,
    ).cuda()

    # ============ preparing optimizer ... ============
    params_groups = utils.get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(params_groups, lr=0,
                                    momentum=0.9)  # lr is set by scheduler
    elif args.optimizer == "lars":
        optimizer = utils.LARS(
            params_groups)  # to use with convnet and large batches
    # for mixed precision training
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ init schedulers ... ============
    lr_schedule = utils.cosine_scheduler(
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) /
        256.,  # linear scaling rule
        args.min_lr,
        args.epochs,
        len(data_loader),
        warmup_epochs=args.warmup_epochs,
    )
    wd_schedule = utils.cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs,
        len(data_loader),
    )
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
                                               args.epochs, len(data_loader))
    print(f"Loss, optimizer and schedulers ready.")

    # ============ optionally resume training ... ============
    to_restore = {"epoch": 0}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth"),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        dino_loss=dino_loss,
    )
    start_epoch = to_restore["epoch"]

    start_time = time.time()
    print("Starting DINO training !")
    for epoch in range(start_epoch, args.epochs):
        data_loader.sampler.set_epoch(epoch)

        # ============ training one epoch of DINO ... ============
        train_stats = train_one_epoch(student, teacher, teacher_without_ddp,
                                      dino_loss, data_loader, optimizer,
                                      lr_schedule, wd_schedule,
                                      momentum_schedule, epoch, fp16_scaler,
                                      args)

        # ============ writing logs ... ============
        save_dict = {
            'student': student.state_dict(),
            'teacher': teacher.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,
            'args': args,
            'dino_loss': dino_loss.state_dict(),
        }
        if fp16_scaler is not None:
            save_dict['fp16_scaler'] = fp16_scaler.state_dict()
        utils.save_on_master(save_dict,
                             os.path.join(args.output_dir, 'checkpoint.pth'))
        if args.saveckp_freq and epoch % args.saveckp_freq == 0:
            utils.save_on_master(
                save_dict,
                os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth'))
        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()}, 'epoch': epoch
        }
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
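The learning-rate, weight-decay and teacher-momentum schedules above all come from utils.cosine_scheduler, which is not shown here. A minimal sketch of what such a per-iteration scheduler is assumed to return (linear warmup followed by cosine decay, one value per training step, as in the DINO reference code):

import numpy as np


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
                     warmup_epochs=0, start_warmup_value=0):
    """Return one value per training iteration: linear warmup, then cosine decay."""
    warmup_iters = warmup_epochs * niter_per_ep
    warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
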
Ejemplo n.º 25
0
def main(args):

    utils.init_distributed_mode(args)
    update_config_from_file(args.cfg)

    print(args)
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)
    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(dataset_train,
                                      num_replicas=num_tasks,
                                      rank=global_rank,
                                      shuffle=True)
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=True)
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        sampler_train = torch.utils.data.RandomSampler(dataset_train)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  batch_size=int(
                                                      2 * args.batch_size),
                                                  sampler=sampler_val,
                                                  num_workers=args.num_workers,
                                                  pin_memory=args.pin_mem,
                                                  drop_last=False)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating SuperVisionTransformer")
    print(cfg)
    model = Vision_TransformerSuper(
        img_size=args.input_size,
        patch_size=args.patch_size,
        embed_dim=cfg.SUPERNET.EMBED_DIM,
        depth=cfg.SUPERNET.DEPTH,
        num_heads=cfg.SUPERNET.NUM_HEADS,
        mlp_ratio=cfg.SUPERNET.MLP_RATIO,
        qkv_bias=True,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        gp=args.gp,
        num_classes=args.nb_classes,
        max_relative_position=args.max_relative_position,
        relative_position=args.relative_position,
        change_qkv=args.change_qkv,
        abs_pos=not args.no_abs_pos)

    choices = {
        'num_heads': cfg.SEARCH_SPACE.NUM_HEADS,
        'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO,
        'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM,
        'depth': cfg.SEARCH_SPACE.DEPTH
    }

    model.to(device)
    if args.teacher_model:
        teacher_model = create_model(
            args.teacher_model,
            pretrained=True,
            num_classes=args.nb_classes,
        )
        teacher_model.to(device)
        teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        teacher_model = None
        teacher_loss = None

    model_ema = None

    model_without_ddp = model
    if args.distributed:

        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear scaling rule: scale the base LR by the global batch size relative to 512
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    # criterion = LabelSmoothingCrossEntropy()

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    output_dir = Path(args.output_dir)

    if not output_dir.exists():
        output_dir.mkdir(parents=True)
    # save config for later experiments
    with open(output_dir / "config.yaml", 'w') as f:
        f.write(args_text)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema,
                                               checkpoint['model_ema'])

    retrain_config = None
    if args.mode == 'retrain' and "RETRAIN" in cfg:
        retrain_config = {
            'layer_num': cfg.RETRAIN.DEPTH,
            'embed_dim': [cfg.RETRAIN.EMBED_DIM] * cfg.RETRAIN.DEPTH,
            'num_heads': cfg.RETRAIN.NUM_HEADS,
            'mlp_ratio': cfg.RETRAIN.MLP_RATIO
        }
    if args.eval:
        print(retrain_config)
        test_stats = evaluate(data_loader_val,
                              model,
                              device,
                              mode=args.mode,
                              retrain_config=retrain_config)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print("Start training")
    start_time = time.time()
    max_accuracy = 0.0

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            model_ema,
            mixup_fn,
            amp=args.amp,
            teacher_model=teacher_model,
            teach_loss=teacher_loss,
            choices=choices,
            mode=args.mode,
            retrain_config=retrain_config,
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        # 'model_ema': get_state_dict(model_ema),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    },
                    checkpoint_path)

        test_stats = evaluate(data_loader_val,
                              model,
                              device,
                              amp=args.amp,
                              choices=choices,
                              mode=args.mode,
                              retrain_config=retrain_config)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {
            **{f'train_{k}': v
               for k, v in train_stats.items()},
            **{f'test_{k}': v
               for k, v in test_stats.items()}, 'epoch': epoch,
            'n_parameters': n_parameters
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
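The dist_eval warning above exists because DistributedSampler pads the validation set so that every rank receives the same number of samples; the padded duplicates slightly skew the metrics. A small illustration of the arithmetic (hypothetical dataset sizes):

import math


def per_rank_eval_samples(dataset_len, world_size):
    # DistributedSampler (drop_last=False) rounds up to an equal share per rank,
    # duplicating samples when dataset_len is not divisible by world_size.
    per_rank = math.ceil(dataset_len / world_size)
    duplicates = per_rank * world_size - dataset_len
    return per_rank, duplicates


print(per_rank_eval_samples(50000, 8))  # (6250, 0)  -> evenly divisible, no duplicates
print(per_rank_eval_samples(50001, 8))  # (6251, 7)  -> 7 duplicated entries are evaluated twice
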
Ejemplo n.º 26
0
def main(args):
    utils.init_distributed_mode(args)

    # disable any harsh augmentation in case of self-supervised training
    if args.training_mode == 'SSL':
        print("NOTE: Smoothing, Mixup, CutMix, and AutoAugment will be disabled for self-supervised training")
        args.smoothing = args.reprob = args.recount = args.mixup = args.cutmix = 0.0
        args.aa = ''

        if args.SiT_LinearEvaluation == 1:
            print("Warning: Linear Evaluation should be set to 0 during SSL training - changing SiT_LinearEvaluation to 0")
            args.SiT_LinearEvaluation = 0
        
    utils.print_args(args)

    device = torch.device(args.device)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    print("Loading dataset ....")
    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)   
    dataset_val, _ = build_dataset(is_train=False, args=args)
    

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    else:
        sampler_train = torch.utils.data.DistributedSampler(dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)


    data_loader_train = torch.utils.data.DataLoader(dataset_train, sampler=sampler_train,
        batch_size=args.batch_size, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=True, collate_fn=collate_fn)

    data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size), num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=False, collate_fn=collate_fn)

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model, pretrained=False, num_classes=args.nb_classes,
        drop_rate=args.drop, drop_path_rate=args.drop_path, representation_size=args.representation_size,
        drop_block_rate=None, training_mode=args.training_mode)

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        for k in ['rot_head.weight', 'rot_head.bias', 'contrastive_head.weight', 'contrastive_head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = int(num_patches ** 0.5)
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    # Freeze the backbone in case of linear evaluation
    if args.SiT_LinearEvaluation == 1:
        requires_grad(model, False)
        
        model.rot_head.weight.requires_grad = True
        model.rot_head.bias.requires_grad = True
        
        model.contrastive_head.weight.requires_grad = True
        model.contrastive_head.bias.requires_grad = True
        
        if args.representation_size is not None:
            model.pre_logits_rot.fc.weight.requires_grad = True
            model.pre_logits_rot.fc.bias.requires_grad = True
            
            model.pre_logits_contrastive.fc.weight.requires_grad = True
            model.pre_logits_contrastive.fc.bias.requires_grad = True            


    model_ema = None
    if args.model_ema:
        model_ema = ModelEma(model, decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '', resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
        
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.training_mode == 'SSL':
        criterion = MTL_loss(args.device, args.batch_size)
    elif args.training_mode == 'finetune' and args.mixup > 0.:
        criterion = SoftTargetCrossEntropy()
    else:
        criterion = torch.nn.CrossEntropyLoss()



    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate_SSL(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    test_stats = {}  # keeps log_stats below defined on epochs where validation is skipped
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        if args.training_mode == 'SSL':
            train_stats = train_SSL(
                model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn)
        else:
            train_stats = train_finetune(
                model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler,
                args.clip_grad, model_ema, mixup_fn)
            
        lr_scheduler.step(epoch)
            
        if epoch % args.validate_every == 0:
            if args.output_dir:
                checkpoint_paths = [output_dir / 'checkpoint.pth']
                for checkpoint_path in checkpoint_paths:
                    utils.save_on_master({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'model_ema': get_state_dict(model_ema),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)
    
            if args.training_mode == 'SSL':
                test_stats = evaluate_SSL(data_loader_val, model, device, epoch, args.output_dir)
            else:
                test_stats = evaluate_finetune(data_loader_val, model, device)

                print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
                max_accuracy = max(max_accuracy, test_stats["acc1"])
                print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
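The finetuning branch above interpolates the pretrained position embedding to the new patch grid. The same logic, factored into a standalone helper for clarity (a sketch; the function name is illustrative):

import torch


def interpolate_pos_embed(pos_embed_checkpoint: torch.Tensor,
                          num_patches: int,
                          num_extra_tokens: int) -> torch.Tensor:
    """Resize a (1, num_extra_tokens + old_grid**2, dim) position embedding to a new grid."""
    embedding_size = pos_embed_checkpoint.shape[-1]
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)
    # class/dist tokens are kept unchanged; only the patch tokens are interpolated
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    return torch.cat((extra_tokens, pos_tokens), dim=1)
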
Ejemplo n.º 27
0
def main():
    args = parse_args()

    hvd.init()
    set_affinity(hvd.local_rank())

    if is_main_process():
        log("Running total processes: {}".format(get_world_size()))
    log("Starting process: {}".format(get_rank()))

    if is_main_process():
        dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                                           filename=args.json_summary),
                                dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE, step_format=format_step)])
    else:
        dllogger.init(backends=[])

    tf.random.set_seed(args.seed)
    dllogger.log(step="PARAMETER", data={"SEED": args.seed})
    # script parameters
    BATCH_SIZE = args.train_batch_size
    EVAL_BATCH_SIZE = args.predict_batch_size
    USE_XLA = args.xla
    USE_AMP = args.amp
    EPOCHS = args.num_train_epochs

    if not args.do_train:
        EPOCHS = args.num_train_epochs = 1
        log("Since running inference only, setting args.num_train_epochs to 1")

    if not os.path.exists(args.output_dir) and is_main_process():
        os.makedirs(args.output_dir)

    # TensorFlow configuration
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
    tf.config.optimizer.set_jit(USE_XLA)
    #tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})
    
    if args.amp:
        policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16", loss_scale="dynamic")
        tf.keras.mixed_precision.experimental.set_policy(policy)
        print('Compute dtype: %s' % policy.compute_dtype)  # Compute dtype: float16
        print('Variable dtype: %s' % policy.variable_dtype)  # Variable dtype: float32

    if is_main_process():
        log("***** Loading tokenizer and model *****")
    # Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
    electra_model = args.electra_model
    config = ElectraConfig.from_pretrained(electra_model, cache_dir=args.cache_dir)
    config.update({"amp": args.amp})
    if args.vocab_file is None:
        tokenizer = ElectraTokenizer.from_pretrained(electra_model, cache_dir=args.cache_dir)
    else:
        tokenizer = ElectraTokenizer(
            vocab_file=args.vocab_file,
            do_lower_case=args.do_lower_case)

    model = TFElectraForQuestionAnswering.from_pretrained(electra_model, config=config, cache_dir=args.cache_dir, args=args)

    if is_main_process():
        log("***** Loading dataset *****")
    # Load data
    processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
    train_examples = processor.get_train_examples(args.data_dir) if args.do_train else None
    dev_examples = processor.get_dev_examples(args.data_dir) if args.do_predict else None

    if is_main_process():
        log("***** Loading features *****")
    # Load cached features
    squad_version = '2.0' if args.version_2_with_negative else '1.1'
    if args.cache_dir is None:
        args.cache_dir = args.data_dir
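    # Cached feature files are keyed by max_seq_length, doc_stride, max_query_length and the SQuAD version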
    cached_train_features_file = args.cache_dir.rstrip('/') + '/' + 'TF2_train-v{4}.json_{1}_{2}_{3}'.format(
        electra_model.split("/")[1], str(args.max_seq_length), str(args.doc_stride),
        str(args.max_query_length), squad_version)
    cached_dev_features_file = args.cache_dir.rstrip('/') + '/' + 'TF2_dev-v{4}.json_{1}_{2}_{3}'.format(
        electra_model.split("/")[1], str(args.max_seq_length), str(args.doc_stride),
        str(args.max_query_length), squad_version)

    try:
        with open(cached_train_features_file, "rb") as reader:
            train_features = pickle.load(reader) if args.do_train else []
        with open(cached_dev_features_file, "rb") as reader:
            dev_features = pickle.load(reader) if args.do_predict else []
    except Exception:  # cache miss or unreadable cache file; rebuild the features below
        train_features = (  # TODO: (yy) do on rank 0?
            squad_convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True,
                return_dataset="",
            )
            if args.do_train
            else []
        )
        dev_features = (
            squad_convert_examples_to_features(
                examples=dev_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=False,
                return_dataset="",
            )
            if args.do_predict
            else []
        )
        # Dump Cached features
        if not args.skip_cache and is_main_process():
            if args.do_train:
                log("***** Building Cache Files: {} *****".format(cached_train_features_file))
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)
            if args.do_predict:
                log("***** Building Cache Files: {} *****".format(cached_dev_features_file))
                with open(cached_dev_features_file, "wb") as writer:
                    pickle.dump(dev_features, writer)

    len_train_features = len(train_features)
    total_train_steps = int((len_train_features * EPOCHS / BATCH_SIZE) / get_world_size()) + 1
    train_steps_per_epoch = int((len_train_features / BATCH_SIZE) / get_world_size()) + 1
    len_dev_features = len(dev_features)
    total_dev_steps = int((len_dev_features / EVAL_BATCH_SIZE)) + 1

    train_dataset = get_dataset_from_features(train_features, BATCH_SIZE,
                                              v2=args.version_2_with_negative) if args.do_train else []
    dev_dataset = get_dataset_from_features(dev_features, EVAL_BATCH_SIZE, drop_remainder=False, ngpu=1, mode="dev",
                                            v2=args.version_2_with_negative) if args.do_predict else []

    opt = create_optimizer(init_lr=args.learning_rate, num_train_steps=total_train_steps,
                           num_warmup_steps=int(args.warmup_proportion * total_train_steps),
                           weight_decay_rate=args.weight_decay_rate,
                           layerwise_lr_decay=args.layerwise_lr_decay,
                           n_transformer_layers=model.num_hidden_layers)
    if USE_AMP:
        # loss scaling is currently required when using mixed precision
        opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")

    # Define loss function
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    loss_class = tf.keras.losses.BinaryCrossentropy(
        from_logits=True,
        name='binary_crossentropy'
    )
    metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
    model.compile(optimizer=opt, loss=loss, metrics=[metric])
    train_loss_results = []

    if args.do_train and is_main_process():
        log("***** Running training *****")
        log("  Num examples = ", len_train_features)
        log("  Num Epochs = ", args.num_train_epochs)
        log("  Instantaneous batch size per GPU = ", args.train_batch_size)
        log(
            "  Total train batch size (w. parallel, distributed & accumulation) = ",
            args.train_batch_size
            * get_world_size(),
        )
        log("  Total optimization steps =", total_train_steps)

    total_train_time = 0
    latency = []
    for epoch in range(EPOCHS):
        if args.do_train:
            epoch_loss_avg = tf.keras.metrics.Mean()
            epoch_perf_avg = tf.keras.metrics.Mean()
            epoch_start = time.time()

            epoch_iterator = tqdm(train_dataset, total=train_steps_per_epoch, desc="Iteration", mininterval=5,
                                  disable=not is_main_process())
            for iter, inputs in enumerate(epoch_iterator):
                # stop early once max_steps (> 0) has been reached
                if args.max_steps > 0 and (epoch * train_steps_per_epoch + iter) > args.max_steps:
                    break
                iter_start = time.time()
                # Optimize the model
                loss_value = train_step(model, inputs, loss, USE_AMP, opt, (iter == 0 and epoch == 0),
                                        v2=args.version_2_with_negative, loss_class=loss_class, fp16=USE_AMP)
                epoch_perf_avg.update_state(1. * BATCH_SIZE / (time.time() - iter_start))
                if iter % args.log_freq == 0:
                    if is_main_process():
                        log("\nEpoch: {:03d}, Step:{:6d}, Loss:{:12.8f}, Perf:{:5.0f}, loss_scale:{}, opt_step:{}".format(epoch, iter, loss_value,
                                                                                              epoch_perf_avg.result() * get_world_size(), opt.loss_scale if config.amp else 1,
                                                                                              int(opt.iterations)))
                    dllogger.log(step=(epoch, iter,), data={"step_loss": float(loss_value.numpy()),
                                                            "train_perf": float( epoch_perf_avg.result().numpy() * get_world_size())})

                # Track progress
                epoch_loss_avg.update_state(loss_value)  # Add current batch loss

            # End epoch
            train_loss_results.append(epoch_loss_avg.result())
            total_train_time += float(time.time() - epoch_start)
            # Summarize and save checkpoint at the end of each epoch
            if is_main_process():

                dllogger.log(step=tuple(), data={"e2e_train_time": total_train_time,
                                                 "training_sequences_per_second": float(
                                                     epoch_perf_avg.result().numpy() * get_world_size()),
                                                 "final_loss": float(epoch_loss_avg.result().numpy())})

            if not args.skip_checkpoint:
                if args.ci:
                    checkpoint_name = "{}/electra_base_qa_v2_{}_epoch_{}_ckpt".format(args.output_dir, args.version_2_with_negative, epoch + 1)
                else:
                    checkpoint_name = "checkpoints/electra_base_qa_v2_{}_epoch_{}_ckpt".format(args.version_2_with_negative, epoch + 1)
                if is_main_process():
                    model.save_weights(checkpoint_name)


        if args.do_predict and (args.evaluate_during_training or epoch == args.num_train_epochs - 1):
            if not args.do_train:
                log("***** Loading checkpoint: {} *****".format(args.init_checkpoint))
                model.load_weights(args.init_checkpoint).expect_partial()

            current_feature_id = 0
            all_results = []
            if is_main_process():
                log("***** Running evaluation *****")
                log("  Num Batches = ", total_dev_steps)
                log("  Batch size = ", args.predict_batch_size)

            raw_infer_start = time.time()
            if is_main_process():
                infer_perf_avg = tf.keras.metrics.Mean()
                dev_iterator = tqdm(dev_dataset, total=total_dev_steps, desc="Iteration", mininterval=5,
                                    disable=not is_main_process())
                for input_ids, input_mask, segment_ids, start_positions, end_positions, cls_index, p_mask, is_impossible in dev_iterator:
                    # training=False is needed only if there are layers with different
                    # behavior during training versus inference (e.g. Dropout).

                    iter_start = time.time()

                    if not args.joint_head:
                        batch_start_logits, batch_end_logits = infer_step(model, input_ids,
                                                                          attention_mask=input_mask,
                                                                          token_type_ids=segment_ids,
                                                                          )[:2]
                        #Synchronize with GPU to compute time
                        _ = batch_start_logits.numpy()
                                                            
                    else:
                        
                        outputs = infer_step(model, input_ids,
                                             attention_mask=input_mask,
                                             token_type_ids=segment_ids,
                                             cls_index=cls_index,
                                             p_mask=p_mask,
                                             )
                        #Synchronize with GPU to compute time
                        _ = outputs[0].numpy()

                    infer_time = (time.time() - iter_start)
                    infer_perf_avg.update_state(1. * EVAL_BATCH_SIZE / infer_time)
                    latency.append(infer_time)

                    for iter_ in range(input_ids.shape[0]):

                        if not args.joint_head:
                            start_logits = batch_start_logits[iter_].numpy().tolist()
                            end_logits = batch_end_logits[iter_].numpy().tolist()
                            dev_feature = dev_features[current_feature_id]
                            current_feature_id += 1
                            unique_id = int(dev_feature.unique_id)
                            all_results.append(RawResult(unique_id=unique_id,
                                                         start_logits=start_logits,
                                                         end_logits=end_logits))
                        else:
                            dev_feature = dev_features[current_feature_id]
                            current_feature_id += 1
                            unique_id = int(dev_feature.unique_id)
                            output = [output[iter_].numpy().tolist() for output in outputs]

                            start_logits = output[0]
                            start_top_index = output[1]
                            end_logits = output[2]
                            end_top_index = output[3]
                            cls_logits = output[4]
                            result = SquadResult(
                                unique_id,
                                start_logits,
                                end_logits,
                                start_top_index=start_top_index,
                                end_top_index=end_top_index,
                                cls_logits=cls_logits,
                            )

                            all_results.append(result)

                # Compute and save predictions
                answers, nbest_answers = get_answers(dev_examples, dev_features, all_results, args)

                output_prediction_file = os.path.join(args.output_dir, "predictions.json")
                output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
                e2e_infer_time = time.time() - raw_infer_start
                # if args.version_2_with_negative:
                #     output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
                # else:
                #     output_null_log_odds_file = None
                with open(output_prediction_file, "w") as f:
                    f.write(json.dumps(answers, indent=4) + "\n")
                with open(output_nbest_file, "w") as f:
                    f.write(json.dumps(nbest_answers, indent=4) + "\n")

                if args.do_eval:
                    if args.version_2_with_negative:
                        dev_file = "dev-v2.0.json"
                    else:
                        dev_file = "dev-v1.1.json"

                    eval_out = subprocess.check_output([sys.executable, args.eval_script,
                                                        args.data_dir + "/" + dev_file, output_prediction_file])
                    log(eval_out.decode('UTF-8'))
                    scores = str(eval_out).strip()
                    exact_match = float(scores.split(":")[1].split(",")[0])
                    if args.version_2_with_negative:
                        f1 = float(scores.split(":")[2].split(",")[0])
                    else:
                        f1 = float(scores.split(":")[2].split("}")[0])

                    log("Epoch: {:03d} Results: {}".format(epoch, eval_out.decode('UTF-8')))
                    log("**EVAL SUMMARY** - Epoch: {:03d},  EM: {:6.3f}, F1: {:6.3f}, Infer_Perf: {:4.0f} seq/s"
                          .format(epoch, exact_match, f1, infer_perf_avg.result()))

                latency_all = sorted(latency)[:-2]
                log(
                    "**LATENCY SUMMARY** - Epoch: {:03d},  Ave: {:6.3f} ms, 90%: {:6.3f} ms, 95%: {:6.3f} ms, 99%: {:6.3f} ms"
                    .format(epoch, sum(latency_all) / len(latency_all) * 1000,
                            sum(latency_all[:int(len(latency_all) * 0.9)]) / int(len(latency_all) * 0.9) * 1000,
                            sum(latency_all[:int(len(latency_all) * 0.95)]) / int(len(latency_all) * 0.95) * 1000,
                            sum(latency_all[:int(len(latency_all) * 0.99)]) / int(len(latency_all) * 0.99) * 1000,
                            ))
                dllogger.log(step=tuple(),
                             data={"inference_sequences_per_second": float(infer_perf_avg.result().numpy()), 
                                   "e2e_inference_time": e2e_infer_time})

    if is_main_process() and args.do_train and args.do_eval:
        log(
            "**RESULTS SUMMARY** - EM: {:6.3f}, F1: {:6.3f}, Train_Time: {:4.0f} s, Train_Perf: {:4.0f} seq/s, Infer_Perf: {:4.0f} seq/s"
            .format(exact_match, f1, total_train_time, epoch_perf_avg.result() * get_world_size(),
                    infer_perf_avg.result()))
        dllogger.log(step=tuple(), data={"exact_match": exact_match, "F1": f1})
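This example drives distribution through Horovod (hvd.init above) rather than torch.distributed. A minimal sketch of how get_world_size, get_rank and is_main_process are assumed to be defined for it:

import horovod.tensorflow as hvd


def get_world_size():
    # Total number of Horovod processes participating in training.
    return hvd.size()


def get_rank():
    return hvd.rank()


def is_main_process():
    # Rank 0 handles logging, feature caching and checkpoint writing.
    return get_rank() == 0
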
Ejemplo n.º 28
0
def main(args, ds_init):
    utils.init_distributed_mode(args)

    if ds_init is not None:
        utils.create_ds_config(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    if args.disable_eval_during_finetuning:
        dataset_val = None
    else:
        dataset_val, _ = build_dataset(is_train=False, args=args)

    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=True)
        print("Sampler_train = %s" % str(sampler_train))
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    if dataset_val is not None:
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val,
            sampler=sampler_val,
            batch_size=int(1.5 * args.batch_size),
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False)
    else:
        data_loader_val = None

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        print("Mixup is activated!")
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        attn_drop_rate=args.attn_drop_rate,
        drop_block_rate=None,
        use_mean_pooling=args.use_mean_pooling,
        init_scale=args.init_scale,
        use_rel_pos_bias=args.rel_pos_bias,
        use_abs_pos_emb=args.abs_pos_emb,
        init_values=args.layer_scale_init_value,
    )

    patch_size = model.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (args.input_size // patch_size[0],
                        args.input_size // patch_size[1])
    args.patch_size = patch_size

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.finetune,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')

        print("Load ckpt from %s" % args.finetune)
        checkpoint_model = None
        for model_key in args.model_key.split('|'):
            if model_key in checkpoint:
                checkpoint_model = checkpoint[model_key]
                print("Load state_dict by model_key = %s" % model_key)
                break
        if checkpoint_model is None:
            checkpoint_model = checkpoint
        state_dict = model.state_dict()
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[
                    k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model:
            print(
                "Expand the shared relative position embedding to each transformer block. "
            )
            num_layers = model.get_num_layers()
            rel_pos_bias = checkpoint_model[
                "rel_pos_bias.relative_position_bias_table"]
            for i in range(num_layers):
                checkpoint_model["blocks.%d.attn.relative_position_bias_table"
                                 % i] = rel_pos_bias.clone()

            checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")

        all_keys = list(checkpoint_model.keys())
        for key in all_keys:
            if "relative_position_index" in key:
                checkpoint_model.pop(key)

            if "relative_position_bias_table" in key:
                rel_pos_bias = checkpoint_model[key]
                src_num_pos, num_attn_heads = rel_pos_bias.size()
                dst_num_pos, _ = model.state_dict()[key].size()
                dst_patch_shape = model.patch_embed.patch_shape
                if dst_patch_shape[0] != dst_patch_shape[1]:
                    raise NotImplementedError()
                num_extra_tokens = dst_num_pos - (
                    dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
                src_size = int((src_num_pos - num_extra_tokens)**0.5)
                dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
                if src_size != dst_size:
                    print("Position interpolate for %s from %dx%d to %dx%d" %
                          (key, src_size, src_size, dst_size, dst_size))
                    extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                    rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
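                    # Binary-search a geometric ratio q such that src_size // 2 geometrically
                    # spaced points span the target half-width dst_size // 2.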

                    def geometric_progression(a, r, n):
                        return a * (1.0 - r**n) / (1.0 - r)

                    left, right = 1.01, 1.5
                    while right - left > 1e-6:
                        q = (left + right) / 2.0
                        gp = geometric_progression(1, q, src_size // 2)
                        if gp > dst_size // 2:
                            right = q
                        else:
                            left = q

                    # if q > 1.090307:
                    #     q = 1.090307

                    dis = []
                    cur = 1
                    for i in range(src_size // 2):
                        dis.append(cur)
                        cur += q**(i + 1)

                    r_ids = [-_ for _ in reversed(dis)]

                    x = r_ids + [0] + dis
                    y = r_ids + [0] + dis

                    t = dst_size // 2.0
                    dx = np.arange(-t, t + 0.1, 1.0)
                    dy = np.arange(-t, t + 0.1, 1.0)

                    print("Original positions = %s" % str(x))
                    print("Target positions = %s" % str(dx))

                    all_rel_pos_bias = []

                    for i in range(num_attn_heads):
                        z = rel_pos_bias[:, i].view(src_size,
                                                    src_size).float().numpy()
                        f = interpolate.interp2d(x, y, z, kind='cubic')
                        all_rel_pos_bias.append(
                            torch.Tensor(f(dx, dy)).contiguous().view(
                                -1, 1).to(rel_pos_bias.device))

                    rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)

                    new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens),
                                                 dim=0)
                    checkpoint_model[key] = new_rel_pos_bias

        # interpolate position embedding
        if 'pos_embed' in checkpoint_model:
            pos_embed_checkpoint = checkpoint_model['pos_embed']
            embedding_size = pos_embed_checkpoint.shape[-1]
            num_patches = model.patch_embed.num_patches
            num_extra_tokens = model.pos_embed.shape[-2] - num_patches
            # height (== width) for the checkpoint position embedding
            orig_size = int(
                (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
            # height (== width) for the new position embedding
            new_size = int(num_patches**0.5)
            # class_token and dist_token are kept unchanged
            if orig_size != new_size:
                print("Position interpolate from %dx%d to %dx%d" %
                      (orig_size, orig_size, new_size, new_size))
                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
                # only the position tokens are interpolated
                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
                pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
                                                embedding_size).permute(
                                                    0, 3, 1, 2)
                pos_tokens = torch.nn.functional.interpolate(
                    pos_tokens,
                    size=(new_size, new_size),
                    mode='bicubic',
                    align_corners=False)
                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
                checkpoint_model['pos_embed'] = new_pos_embed

        utils.load_state_dict(model,
                              checkpoint_model,
                              prefix=args.model_prefix)
        # model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume='')
        print("Using EMA with decay = %.8f" % args.model_ema_decay)

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params:', n_parameters)

    total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
    num_training_steps_per_epoch = len(dataset_train) // total_batch_size
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Update frequent = %d" % args.update_freq)
    print("Number of training examples = %d" % len(dataset_train))
    print("Number of training training per epoch = %d" %
          num_training_steps_per_epoch)

    num_layers = model_without_ddp.get_num_layers()
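    # Layer-wise LR decay: block i receives multiplier layer_decay**(num_layers + 1 - i),
    # so earlier blocks train with smaller learning rates while the head keeps scale 1.0.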
    if args.layer_decay < 1.0:
        assigner = LayerDecayValueAssigner(
            list(args.layer_decay**(num_layers + 1 - i)
                 for i in range(num_layers + 2)))
    else:
        assigner = None

    if assigner is not None:
        print("Assigned values = %s" % str(assigner.values))

    skip_weight_decay_list = model.no_weight_decay()
    if args.disable_weight_decay_on_rel_pos_bias:
        for i in range(num_layers):
            skip_weight_decay_list.add(
                "blocks.%d.attn.relative_position_bias_table" % i)

    if args.enable_deepspeed:
        loss_scaler = None
        optimizer_params = get_parameter_groups(
            model, args.weight_decay, skip_weight_decay_list,
            assigner.get_layer_id if assigner is not None else None,
            assigner.get_scale if assigner is not None else None)
        model, optimizer, _, _ = ds_init(
            args=args,
            model=model,
            model_parameters=optimizer_params,
            dist_init_required=not args.distributed,
        )

        print("model.gradient_accumulation_steps() = %d" %
              model.gradient_accumulation_steps())
        assert model.gradient_accumulation_steps() == args.update_freq
    else:
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
            model_without_ddp = model.module

        optimizer = create_optimizer(args,
                                     model_without_ddp,
                                     skip_list=skip_weight_decay_list,
                                     get_num_layer=assigner.get_layer_id
                                     if assigner is not None else None,
                                     get_layer_scale=assigner.get_scale
                                     if assigner is not None else None)
        loss_scaler = NativeScaler()

    print("Use step level LR scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr,
        args.min_lr,
        args.epochs,
        num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs,
        warmup_steps=args.warmup_steps,
    )
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(args.weight_decay,
                                                args.weight_decay_end,
                                                args.epochs,
                                                num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" %
          (max(wd_schedule_values), min(wd_schedule_values)))

    if mixup_fn is not None:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print("criterion = %s" % str(criterion))

    utils.auto_load_model(args=args,
                          model=model,
                          model_without_ddp=model_without_ddp,
                          optimizer=optimizer,
                          loss_scaler=loss_scaler,
                          model_ema=model_ema)

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        exit(0)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch *
                                args.update_freq)
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            model_ema,
            mixup_fn,
            log_writer=log_writer,
            start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values,
            wd_schedule_values=wd_schedule_values,
            num_training_steps_per_epoch=num_training_steps_per_epoch,
            update_freq=args.update_freq,
        )
        if args.output_dir and args.save_ckpt:
            if (epoch +
                    1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                utils.save_model(args=args,
                                 model=model,
                                 model_without_ddp=model_without_ddp,
                                 optimizer=optimizer,
                                 loss_scaler=loss_scaler,
                                 epoch=epoch,
                                 model_ema=model_ema)
        if data_loader_val is not None:
            test_stats = evaluate(data_loader_val, model, device)
            print(
                f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
            )
            if max_accuracy < test_stats["acc1"]:
                max_accuracy = test_stats["acc1"]
                if args.output_dir and args.save_ckpt:
                    utils.save_model(args=args,
                                     model=model,
                                     model_without_ddp=model_without_ddp,
                                     optimizer=optimizer,
                                     loss_scaler=loss_scaler,
                                     epoch="best",
                                     model_ema=model_ema)

            print(f'Max accuracy: {max_accuracy:.2f}%')
            if log_writer is not None:
                log_writer.update(test_acc1=test_stats['acc1'],
                                  head="perf",
                                  step=epoch)
                log_writer.update(test_acc5=test_stats['acc5'],
                                  head="perf",
                                  step=epoch)
                log_writer.update(test_loss=test_stats['loss'],
                                  head="perf",
                                  step=epoch)

            log_stats = {
                **{f'train_{k}': v
                   for k, v in train_stats.items()},
                **{f'test_{k}': v
                   for k, v in test_stats.items()}, 'epoch': epoch,
                'n_parameters': n_parameters
            }
        else:
            log_stats = {
                **{f'train_{k}': v
                   for k, v in train_stats.items()},
                # **{f'test_{k}': v for k, v in test_stats.items()},
                'epoch': epoch,
                'n_parameters': n_parameters
            }

        if args.output_dir and utils.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"),
                      mode="a",
                      encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
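
The snippets above resize a ViT position embedding when finetuning at a resolution different from pretraining. A minimal standalone sketch of the same idea (the function name, tensor shapes, and usage example below are illustrative assumptions, not taken from the source):

import torch
import torch.nn.functional as F

def interpolate_pos_embed(pos_embed, num_extra_tokens, new_grid_size):
    # pos_embed: (1, num_extra_tokens + old_grid**2, dim). The class/dist tokens are kept
    # unchanged; only the patch-position tokens are resized on a 2-D grid with bicubic
    # interpolation, then the two parts are concatenated again.
    extra = pos_embed[:, :num_extra_tokens]
    patch = pos_embed[:, num_extra_tokens:]
    dim = pos_embed.shape[-1]
    old_grid = int(patch.shape[1] ** 0.5)
    patch = patch.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch = F.interpolate(patch,
                          size=(new_grid_size, new_grid_size),
                          mode='bicubic',
                          align_corners=False)
    patch = patch.permute(0, 2, 3, 1).flatten(1, 2)
    return torch.cat((extra, patch), dim=1)

# e.g. a 14x14 grid (224 px / patch 16) resized to 24x24 (384 px / patch 16):
pos = torch.randn(1, 1 + 14 * 14, 768)
print(interpolate_pos_embed(pos, num_extra_tokens=1, new_grid_size=24).shape)
# -> torch.Size([1, 577, 768])
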
Example #29
def main(args):
    utils.init_distributed_mode(args)

    print(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError(
            "Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

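    # Shard the datasets across processes; RASampler additionally performs repeated
    # augmentation on top of plain distributed sampling.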
    if True:  # args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        if args.repeated_aug:
            sampler_train = RASampler(dataset_train,
                                      num_replicas=num_tasks,
                                      rank=global_rank,
                                      shuffle=True)
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=True)
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print(
                    'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                    'This will slightly alter validation results as extra duplicate entries are added to achieve '
                    'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  sampler=sampler_val,
                                                  batch_size=int(
                                                      1.5 * args.batch_size),
                                                  num_workers=args.num_workers,
                                                  pin_memory=args.pin_mem,
                                                  drop_last=False)

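    # Mixup / CutMix label transforms are enabled if any of the corresponding
    # probabilities or the cutmix_minmax range is set.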
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup,
                         cutmix_alpha=args.cutmix,
                         cutmix_minmax=args.cutmix_minmax,
                         prob=args.mixup_prob,
                         switch_prob=args.mixup_switch_prob,
                         mode=args.mixup_mode,
                         label_smoothing=args.smoothing,
                         num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.finetune,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        for k in [
                'head.weight', 'head.bias', 'head_dist.weight',
                'head_dist.bias'
        ]:
            if k in checkpoint_model and checkpoint_model[
                    k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
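        # Reshape the patch-position tokens to a 2-D grid, resize them with bicubic
        # interpolation to the new patch grid, then re-attach the class/dist tokens.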
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int(
            (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
                                        embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(pos_tokens,
                                                     size=(new_size, new_size),
                                                     mode='bicubic',
                                                     align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

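    # Linear LR scaling rule: scale the base LR with the global batch size, using 512 as the reference.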
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

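    # Default criterion; immediately overridden below based on the mixup / label-smoothing settings.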
    criterion = LabelSmoothingCrossEntropy()

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    teacher_model = None
    if args.distillation_type != 'none':
        assert args.teacher_path, 'need to specify teacher-path when using distillation'
        print(f"Creating teacher model: {args.teacher_model}")
        teacher_model = create_model(
            args.teacher_model,
            pretrained=False,
            num_classes=args.nb_classes,
            global_pool='avg',
        )
        if args.teacher_path.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.teacher_path,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.teacher_path, map_location='cpu')
        teacher_model.load_state_dict(checkpoint['model'])
        teacher_model.to(device)
        teacher_model.eval()

    # wrap the criterion in our custom DistillationLoss, which
    # just dispatches to the original criterion if args.distillation_type is 'none'
    criterion = DistillationLoss(criterion, teacher_model,
                                 args.distillation_type,
                                 args.distillation_alpha,
                                 args.distillation_tau)

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location='cpu',
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(model_ema,
                                               checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
        )
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            loss_scaler,
            args.clip_grad,
            model_ema,
            mixup_fn,
            set_training_mode=args.finetune == '',  # keep in eval mode during finetuning
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / ('checkpoint_%04d.pth' % (epoch))]
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'model_ema': get_state_dict(model_ema),
                        'scaler': loss_scaler.state_dict(),
                        'args': args,
                    }, checkpoint_path)

        if not args.train_without_eval:
            test_stats = evaluate(data_loader_val, model, device)
            print(
                f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
            )
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')

            log_stats = {
                **{f'train_{k}': v
                   for k, v in train_stats.items()},
                **{f'test_{k}': v
                   for k, v in test_stats.items()}, 'epoch': epoch,
                'n_parameters': n_parameters
            }
        else:
            log_stats = {
                **{f'train_{k}': v
                   for k, v in train_stats.items()}, 'epoch': epoch,
                'n_parameters': n_parameters
            }
        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
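
The example above wraps its classification criterion in a DistillationLoss that simply defers to the base criterion when args.distillation_type is 'none'. A minimal, hypothetical sketch of that dispatch pattern (the class name, constructor signature, and soft-KL formulation are illustrative assumptions, not the source implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleDistillationLoss(nn.Module):
    # Hypothetical wrapper: returns the base loss when distillation is disabled,
    # otherwise blends in a temperature-scaled KL term against a frozen teacher.
    def __init__(self, base_criterion, teacher, distillation_type='none', alpha=0.5, tau=1.0):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher = teacher
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss
        with torch.no_grad():
            teacher_outputs = self.teacher(inputs)
        T = self.tau
        distill_loss = F.kl_div(F.log_softmax(outputs / T, dim=1),
                                F.log_softmax(teacher_outputs / T, dim=1),
                                reduction='batchmean',
                                log_target=True) * (T * T)
        return (1 - self.alpha) * base_loss + self.alpha * distill_loss

Scaling the soft term by tau**2 keeps its gradient magnitude comparable across temperatures, the usual knowledge-distillation convention; the teacher runs under torch.no_grad() and stays in eval mode, as in the example above.
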
def main():
    global timeout_sent

    args = parse_arguments()

    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)

    # Prepare optimizer
    (
        model,
        optimizer,
        lr_scheduler,
        checkpoint,
        global_step,
        criterion,
    ) = prepare_model_and_optimizer(args, device)

    raw_train_start = None
    most_recent_ckpts_paths = []
    average_loss = 0.0  # averaged loss every args.log_freq steps
    epoch = 0
    training_steps = 0
    test_losses = []

    pool = ProcessPoolExecutor(1)

    # Note: We loop infinitely over epochs, termination is handled via iteration count
    while True:
        thread = None
        restored_data_loader = None
        if (not args.resume_from_checkpoint or epoch > 0
                or (args.phase2 and global_step < 1) or args.init_checkpoint):
            files = [
                os.path.join(args.input_dir, f)
                for f in os.listdir(args.input_dir)
                if os.path.isfile(os.path.join(args.input_dir, f))
                and "training" in f
            ]
            files.sort()
            num_files = len(files)
            random.Random(args.seed + epoch).shuffle(files)
            f_start_id = 0
        else:
            f_start_id = checkpoint["files"][0]
            files = checkpoint["files"][1:]
            args.resume_from_checkpoint = False
            num_files = len(files)
            # may not exist in all checkpoints
            epoch = checkpoint.get("epoch", 0)
            restored_data_loader = checkpoint.get("data_loader", None)

        shared_file_list = {}

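        # Determine the data-parallel group size/rank: prefer SageMaker model parallel (smp),
        # fall back to torch.distributed, otherwise assume a single process.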
        if smp.is_initialized():
            dpsize = smp.dp_size()
            dprank = smp.dp_rank()
        elif torch.distributed.is_initialized():
            dpsize = get_world_size()
            dprank = get_rank()
        else:
            dpsize = 1
            dprank = 0
        dparallel = dpsize > 1
        if dparallel and dpsize > num_files:
            remainder = dpsize % num_files
            data_file = files[(f_start_id * dpsize + dprank +
                               remainder * f_start_id) % num_files]
        else:
            data_file = files[(f_start_id * dpsize + dprank) % num_files]

        previous_file = data_file

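        # Build a DataLoader for the first shard unless one was restored from the checkpoint.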
        if restored_data_loader is None:
            train_data = pretraining_dataset(data_file,
                                             args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            train_dataloader = DataLoader(
                train_data,
                sampler=train_sampler,
                batch_size=args.train_batch_size * args.n_gpu,
                num_workers=4,
                worker_init_fn=worker_init,
                pin_memory=True,
                drop_last=True,
            )
            # shared_file_list["0"] = (train_dataloader, data_file)
        else:
            train_dataloader = restored_data_loader
            restored_data_loader = None

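        # Overflow buffer used by the fused multi-tensor all-reduce path (allreduce_post_accumulation).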
        overflow_buf = None
        if args.allreduce_post_accumulation:
            overflow_buf = torch.cuda.IntTensor([0])

        for f_id in range(f_start_id + 1, len(files)):
            if get_world_size() > num_files:
                data_file = files[(f_id * get_world_size() + get_rank() +
                                   remainder * f_id) % num_files]
            else:
                data_file = files[(f_id * get_world_size() + get_rank()) %
                                  num_files]

            previous_file = data_file

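            # Asynchronously prepare the next shard's dataset in the worker pool
            # while training on the current shard (collected via dataset_future.result()).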
            dataset_future = pool.submit(
                create_pretraining_dataset,
                data_file,
                args.max_predictions_per_seq,
                shared_file_list,
                args,
                worker_init,
            )

            train_iter = (tqdm(train_dataloader,
                               desc="Iteration",
                               disable=args.disable_progress_bar)
                          if is_main_process() else train_dataloader)

            if raw_train_start is None:
                raw_train_start = time.time()

            for step, batch in enumerate(train_iter):
                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                if args.do_train:
                    from smdistributed.modelparallel.test.torch.utils import dump_model, verify

                    model.train()
                    if args.smp > 0:
                        loss_mbs = smp_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            optimizer,
                            criterion,
                            step,
                        )
                        loss = loss_mbs.reduce_mean()
                        if smp.rank() == 0:
                            print("Loss:", loss.item())
                    else:
                        loss = train_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            optimizer,
                            criterion,
                            step,
                        )
                    divisor = 1
                    average_loss += loss.item()

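                    # Step the LR scheduler and optimizer only once every gradient_accumulation_steps micro-batches.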
                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(
                            args, optimizer, model, overflow_buf, global_step)

                    if global_step >= args.steps_this_run or timeout_sent:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = (int(
                            training_steps / args.gradient_accumulation_steps)
                                          % args.log_freq)
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(
                            average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps *
                                                       divisor)
                        if torch.distributed.is_initialized():
                            average_loss /= get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        final_loss = average_loss.item()
                    elif training_steps % (
                            args.log_freq *
                            args.gradient_accumulation_steps) == 0:
                        average_loss = 0

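                    # Checkpoint every num_steps_per_checkpoint optimizer steps, and always when
                    # training is about to stop (max steps reached or a cluster timeout was received).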
                    if (global_step >= args.steps_this_run or training_steps %
                        (args.num_steps_per_checkpoint *
                         args.gradient_accumulation_steps) == 0
                            or timeout_sent):
                        if smp.dp_rank() == 0 and not args.skip_checkpoint:
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step +
                                                        args.phase1_end_step),
                                )
                            if args.do_train:
                                save_dict = {
                                    "model": model.local_state_dict(),
                                    "optimizer": optimizer.local_state_dict(),
                                    "files": [f_id] + files,
                                    "epoch": epoch,
                                    "data_loader": (None if global_step >= args.steps_this_run
                                                    else train_dataloader),
                                }
                                if args.fp16:
                                    save_dict["master params"] = list(
                                        amp.master_params(optimizer))
                                # SMP: Checkpoint mp_rank specific state
                                smp.save(save_dict,
                                         output_save_file,
                                         partial=True)

                                most_recent_ckpts_paths.append(
                                    output_save_file)
                                if len(most_recent_ckpts_paths) > 3 and (
                                        args.smp == 0 or smp.dp_rank() == 0):
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                                    os.remove(ckpt_to_be_removed + f"_{smp.mp_rank()}")

                        # Exiting the training due to hitting max steps, or being sent a
                        # timeout from the cluster scheduler
                        if global_step >= args.steps_this_run or timeout_sent:
                            del train_dataloader
                            # thread.join()
                            if smp.dp_rank() == 0 and args.save_full:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                                save_dict = {
                                    "model": model.local_state_dict(),
                                    "optimizer": optimizer.local_state_dict(),
                                    "files": [f_id] + files,
                                    "epoch": epoch,
                                    "data_loader": (None if global_step >= args.steps_this_run
                                                    else train_dataloader),
                                }
                                if args.fp16:
                                    save_dict["master params"] = list(
                                        amp.master_params(optimizer))
                                # SMP: Save a single checkpoint containing entire model parameters
                                smp.save(save_dict,
                                         output_save_file,
                                         partial=False)
                            smp.barrier()
                            if smp.local_rank() == 0:
                                print(f"Start syncing model checkpoints to s3")
                                base_s3_path = os.path.dirname(
                                    os.path.dirname(
                                        os.getenv("SM_MODULE_DIR", "")))
                                curr_host = os.getenv("SM_CURRENT_HOST")
                                full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
                                sync_local_checkpoints_to_s3(
                                    local_path=args.output_dir,
                                    s3_path=full_s3_path)
                                print("Finished syncing model checkpoints to s3")
                            return args, final_loss, train_time_raw, global_step
                else:
                    model.eval()
                    with torch.no_grad():
                        loss = test_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            criterion,
                            step,
                        )
                        print(f"global_step {global_step} Test Loss:", loss)
                        test_losses.append(loss)
                    global_step += 1
                    if global_step >= args.steps_this_run:
                        return sum(test_losses) / len(test_losses)

            del train_dataloader
            # thread.join()
            # Make sure pool has finished and switch train_dataloader
            # NOTE: Will block until complete
            train_dataloader, data_file = dataset_future.result(timeout=None)
        epoch += 1