Example #1
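A causal-language-model training loop built on Hugging Face Accelerate: the loss is scaled for gradient accumulation, model.no_sync() suppresses the gradient all-reduce on non-update steps, and checkpoints are written with accelerator.save_state.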
# The enclosing loop header and resume check are implied by the
# `continue` at the top of the fragment.
for step, batch in enumerate(train_dataloader, start=1):
    if args.resume_from_checkpoint and step < resume_step:
        continue  # we need to skip steps until we reach the resumed step
    loss = model(batch, labels=batch, use_cache=False).loss
    log_metrics(
        step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
    )
    loss = loss / args.gradient_accumulation_steps
    if step % args.gradient_accumulation_steps != 0:
        # Prevent backward from doing gradient all_reduce on accumulation steps
        if accelerator.distributed_type == DistributedType.MULTI_GPU:
            with model.no_sync():
                accelerator.backward(loss)
        else:
            accelerator.backward(loss)
    else:
        # Update step: sync gradients, clip, and step the optimizer
        accelerator.backward(loss)
        accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
        elapsed_time = time.time() - t_start
        tflops = compute_tflops(elapsed_time, accelerator, args)
        log_metrics(step, {"steps": completed_steps, "tflops": tflops, "time_per_iteration": elapsed_time})
        t_start = time.time()
    if step % args.save_checkpoint_steps == 0:
        logger.info("Evaluating and saving model checkpoint")
        eval_loss, perplexity = evaluate(args)
        log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
        accelerator.wait_for_everyone()
        save_dir = os.path.join(args.save_dir, f"step_{step}")
        accelerator.save_state(save_dir)
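The fragment calls two helpers, get_lr and compute_tflops, that it does not define. A minimal sketch of what they might look like, assuming model, optimizer, and the args fields train_batch_size and seq_length are in scope (hypothetical reconstructions, not the script's actual definitions):

def get_lr():
    # current learning rate of the first param group
    return optimizer.param_groups[0]["lr"]

def compute_tflops(elapsed_time, accelerator, args):
    # Rough throughput estimate: ~6 FLOPs per parameter per token for a
    # combined forward + backward pass; ignores the attention term.
    # `args.train_batch_size` and `args.seq_length` are assumed field names.
    num_params = sum(p.numel() for p in model.parameters())
    tokens_per_step = args.train_batch_size * args.seq_length * accelerator.num_processes
    return 6 * num_params * tokens_per_step / elapsed_time / 1e12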
Example #2
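A 3D U-Net segmentation training script that uses Accelerate for multi-GPU training: it scales the learning rate with the world size, converts BatchNorm to SyncBatchNorm, supports resuming from a checkpoint, and evaluates accuracy and mean IoU on the validation set after every epoch, saving the best predictions as images.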
def main():
    args = arg_parser()
    # turn on benchmark mode
    torch.backends.cudnn.benchmark = True

    accelerator = Accelerator(fp16=args.use_fp16)

    if accelerator.is_main_process:
        # setup logger
        os.makedirs(args.log_dir, exist_ok=True)
        time_stamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
        logger = get_root_logger(logger_name='MOD', log_file=os.path.join(
            args.log_dir, f'{time_stamp}.log'))
        writer = SummaryWriter(log_dir=os.path.join(args.log_dir, 'tf_logs'))
        # log env info
        logger.info('--------------------Env info--------------------')
        for key, value in sorted(collect_env().items()):
            logger.info(str(key) + ': ' + str(value))
        # log args
        logger.info('----------------------Args-----------------------')
        for key, value in sorted(vars(args).items()):
            logger.info(str(key) + ': ' + str(value))
        logger.info('---------------------------------------------------')

    # train_dataset = MOD(root=args.root, annfile=args.train_annfile)
    train_dataset = MOD_3d(
        root=args.root, annfile=args.train_annfile, clip_length=args.clip_length)
    train_dataloader = DataLoader(train_dataset, batch_size=args.samples_per_gpu,
                                  shuffle=True, num_workers=args.num_workers, pin_memory=True)
    # val dataloader
    # val_dataset = MOD(root=args.root, annfile=args.val_annfile, val=True)
    val_dataset = MOD_3d(root=args.root, annfile=args.val_annfile,
                         val=True, clip_length=args.clip_length)
    val_dataloader = DataLoader(val_dataset, batch_size=args.samples_per_gpu,
                                shuffle=False, num_workers=args.num_workers, pin_memory=True)

    # define model
    # model = TinyUNet(
    #     n_channels=1, n_classes=train_dataset.num_classes, upsample='bilinear')
    # replace2dwith3d(model=model)
    model = TinyUNet3d(n_channels=1, n_classes=2)
    # optimizer
    # linear LR scaling by total batch size; accelerator.num_processes avoids
    # requiring torch.distributed to be initialized in single-process runs
    init_lr = args.base_lr * accelerator.num_processes * args.samples_per_gpu / 16
    optimizer = optim.SGD(model.parameters(), lr=init_lr,
                          weight_decay=1e-4, momentum=0.9)
    # recover states
    start_epoch = 1
    if args.resume is not None:
        ckpt: dict = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        start_epoch = ckpt['epoch']+1
        if accelerator.is_main_process:
            logger.info(f"Resume from epoch {start_epoch-1}...")
    else:
        if accelerator.is_main_process:
            logger.info("Start training from scratch...")
    # convert BatchNorm to SyncBatchNorm
    model = SyncBatchNorm.convert_sync_batchnorm(model)
    # prepare to be DDP models
    model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader)
    # closed_form lr_scheduler
    total_steps = len(train_dataloader)*args.epochs
    resume_step = len(train_dataloader)*(start_epoch-1)
    lr_scheduler = ClosedFormCosineLRScheduler(
        optimizer, init_lr, total_steps, resume_step)
    # loss criterion
    criterion = CrossEntropyLoss(weight=torch.tensor([1., 10.]),
                                 ignore_index=255).to(accelerator.device)
    # training
    # Best acc
    best_miou = 0.
    for e in range(start_epoch, args.epochs+1):
        model.train()
        for i, batch in enumerate(train_dataloader):
            img, mask = batch
            logits = model(img)
            loss = criterion(logits, mask)
            accelerator.backward(loss)
            # clip grad if true
            if args.clip_grad_norm is not None:
                grad_norm = accelerator.clip_grad_norm_(
                    model.parameters(), args.clip_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            # sync before logging
            accelerator.wait_for_everyone()
            ## log and tensorboard
            if accelerator.is_main_process:
                if i % args.log_interval == 0:
                    writer.add_scalar('loss', loss.item(),
                                      (e-1)*len(train_dataloader)+i)
                    lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('lr', lr,
                                      (e-1)*len(train_dataloader)+i)
                    loss_str = f"loss: {loss.item():.4f}"
                    epoch_iter_str = f"Epoch: [{e}] [{i}/{len(train_dataloader)}], "
                    if args.clip_grad_norm is not None:
                        logger.info(
                            epoch_iter_str+f'lr: {lr}, '+loss_str+f', grad_norm: {grad_norm}')
                    else:
                        logger.info(epoch_iter_str+f'lr: {lr}, '+loss_str)

            lr_scheduler.step()
        if accelerator.is_main_process:
            if e % args.save_interval == 0:
                save_path = os.path.join(args.log_dir, f'epoch_{e}.pth')
                # unwrap the DDP wrapper so the checkpoint also loads in
                # single-process runs (model.module only exists under DDP)
                torch.save(
                    {'state_dict': accelerator.unwrap_model(model).state_dict(),
                     'epoch': e, 'args': args,
                     'optimizer': optimizer.state_dict()}, save_path)
                logger.info(f"Checkpoint has been saved at {save_path}")
        # start to evaluate
        if accelerator.is_main_process:
            logger.info("Evaluate on validation dataset")
            bar = tqdm(total=len(val_dataloader))
        model.eval()
        preds = []
        gts = []
        for batch in val_dataloader:
            img, mask = batch
            with torch.no_grad():
                logits = model(img)
                pred = accelerator.gather(logits)
                gt = accelerator.gather(mask)
            preds.append(pred)
            gts.append(gt)
            if accelerator.is_main_process:
                bar.update(accelerator.num_processes)
        if accelerator.is_main_process:
            bar.close()
            # compute metrics
            # prepare preds
            preds = torch.cat(preds)[:len(val_dataloader.dataset)]
            preds = average_preds(preds, window=args.clip_length)  # NCHW
            preds = F.softmax(preds, dim=1)
            preds = torch.argmax(preds, dim=1)  # NHW
            # prepare gts
            gts = torch.cat(gts)[:len(val_dataloader.dataset)]  # NTHW
            gts = flat_gts(gts, window=args.clip_length)  # NHW
            # accuracy
            acc = accuarcy(preds, gts, ignore_index=0, average='micro')
            # mIoU
            miou = mIoU(preds, gts, ignore_index=0)
            logger.info(f"Accuracy on Val dataset: {acc:.4f}")
            logger.info(f"Mean IoU on Val dataset: {miou:.4f}")
            # save preds
            if miou > best_miou:
                best_miou = miou
                val_results_dir = os.path.join(
                    args.log_dir, 'best_val_results')
                os.makedirs(val_results_dir, exist_ok=True)
                imgpaths = flat_paths(val_dataset.imgpaths)
                assert preds.shape[0] == len(imgpaths)
                preds = preds.cpu().numpy()
                for i in range(preds.shape[0]):
                    imgname = os.path.basename(imgpaths[i])
                    imgpath = os.path.join(val_results_dir, imgname)
                    result = preds[i].astype(np.uint8)
                    result[result == 1] = 255
                    result = Image.fromarray(result)
                    result.save(imgpath)
        # delete unuseful vars
        del preds
        del gts
        accelerator.wait_for_everyone()
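ClosedFormCosineLRScheduler above is project code rather than a torch built-in. A minimal sketch consistent with the call ClosedFormCosineLRScheduler(optimizer, init_lr, total_steps, resume_step), assuming a plain cosine decay computed in closed form from the step index (which is what makes resuming stateless):

import math

class ClosedFormCosineLRScheduler:
    """Cosine decay computed directly from the step index, so resuming
    only requires the resumed step count, not saved scheduler state."""

    def __init__(self, optimizer, base_lr, total_steps, last_step=0):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.total_steps = total_steps
        self.step_num = last_step
        self._set_lr()

    def step(self):
        self.step_num += 1
        self._set_lr()

    def _set_lr(self):
        # clamp so the LR stays at its floor once total_steps is reached
        t = min(self.step_num, self.total_steps)
        lr = 0.5 * self.base_lr * (1 + math.cos(math.pi * t / self.total_steps))
        for group in self.optimizer.param_groups:
            group['lr'] = lr

With resume_step = len(train_dataloader) * (start_epoch - 1), the decay curve picks up exactly where the interrupted run left off.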
Example #3
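Fine-tuning a token-classification model on the ManyTypes4TypeScript dataset with Accelerate: long inputs are split into 510-token windows, only the first sub-token of each word carries a label, and seqeval metrics are tracked overall, for unknown types, and for the top-100 types.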
def train(args):
    dataset = load_dataset('ManyTypes4TypeScript.py', ignore_verifications=True)
    accelerator = Accelerator()
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, add_prefix_space=True, use_fast=True)

    def tokenize_and_align_labels(examples):
        def divide_chunks(l1, l2, n):
            for i in range(0, len(l1), n):
                yield {'input_ids': [0] + l1[i:i + n] + [2], 'labels': [-100] + l2[i:i + n] + [-100]}  # 0/2 are <s>/</s> ids for RoBERTa-style tokenizers

        window_size = 510
        tokenized_inputs = tokenizer(examples['tokens'], is_split_into_words=True, truncation=False,
                                     add_special_tokens=False)
        inputs_ = {'input_ids': [], 'labels': []}
        for encoding, label in zip(tokenized_inputs.encodings, examples['labels']):
            word_ids = encoding.word_ids  # Map tokens to their respective word.
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:  # Set the special tokens to -100.
                if word_idx is None:
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                    l = label[word_idx] if label[word_idx] is not None else -100
                    label_ids.append(l)
                else:
                    label_ids.append(-100)
                previous_word_idx = word_idx

            s_labels = set(label_ids)
            if len(s_labels) == 1 and list(s_labels)[0] == -100:
                continue
            for e in divide_chunks(encoding.ids, label_ids, window_size):
                for k, v in e.items():
                    inputs_[k].append(v)
        return inputs_

    tokenized_hf = dataset.map(tokenize_and_align_labels, batched=True, remove_columns=['id', 'tokens', 'labels'])
    label_list = tokenized_hf["train"].features["labels"].feature.names

    model = AutoModelForTokenClassification.from_pretrained(args.model_name, num_labels=len(label_list))

    train_dataset = tokenized_hf["train"]
    eval_dataset = tokenized_hf["test"]
    valid_dataset = tokenized_hf["validation"]
    logger = logging.getLogger(__name__)

    train_batch_size = args.train_batch_size
    eval_batch_size = args.eval_batch_size
    gradient_accumulation_steps = args.gradient_accumulation_steps
    data_collator = DataCollatorForTokenClassification(
        tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None), padding='max_length', max_length=512
    )

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size
    )
    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=eval_batch_size)

    valid_dataloader = DataLoader(valid_dataset, collate_fn=data_collator, batch_size=eval_batch_size)

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

    # Use the device given by the `accelerator` object.
    device = accelerator.device
    print("Device: {0}".format(device))
    model.to(device)

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader, valid_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, valid_dataloader
    )

    lr_scheduler = get_scheduler(
        name='constant',  # constant because streaming dataset
        optimizer=optimizer,
        # num_warmup_steps=args.warmup_steps,
        # num_training_steps=None if args.max_steps < 0. else args.max_steps,
    )

    # Metrics - more detailed than overall accuracy in evaluator.py
    warnings.filterwarnings('ignore')
    metric = load_metric("seqeval")
    metric_unk = load_metric("seqeval")
    metric_top100 = load_metric("seqeval")

    train_total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
    eval_total_batch_size = eval_batch_size * accelerator.num_processes
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {train_total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")

    # Only show the progress bar once on each machine.
    progress_bar_train = tqdm(range(len(train_dataset) // train_total_batch_size),
                              disable=not accelerator.is_local_main_process)
    progress_bar_eval = tqdm(range(len(eval_dataset) // eval_total_batch_size),
                             disable=not accelerator.is_local_main_process)
    completed_steps = 0

    for epoch in range(args.num_train_epochs):

        if args.do_train:
            model.train()
            for step, batch in enumerate(train_dataloader):
                outputs = model(**batch)
                loss = outputs.loss
                loss = loss / gradient_accumulation_steps
                accelerator.backward(loss)
                # step starts at 0, so test (step + 1): otherwise the optimizer
                # would update after the very first micro-batch
                if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                    # clip once per optimizer step, after accumulation finishes
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    progress_bar_train.update(1)
                    completed_steps += 1
                    if args.max_steps > 0 and completed_steps >= args.max_steps:
                        break

        if args.do_eval:
            export_predictions = []
            model.eval()
            for step, batch in enumerate(eval_dataloader):
                with torch.no_grad():
                    # pass the attention mask so padded positions are ignored
                    outputs = model(input_ids=batch['input_ids'],
                                    attention_mask=batch['attention_mask'], labels=None)
                predictions = outputs.logits.argmax(dim=-1)
                labels = batch["labels"]
                predictions_gathered = accelerator.gather(predictions)
                labels_gathered = accelerator.gather(labels)
                preds, refs = get_labels(predictions_gathered, labels_gathered, label_list)
                export_predictions.extend(flatten(preds))
                preds_unk, refs_unk = get_labels(predictions_gathered, labels_gathered, label_list, score_unk=True)
                preds_100, refs_100 = get_labels(predictions_gathered, labels_gathered, label_list, top100=True)
                progress_bar_eval.update(1)
                metric.add_batch(
                    predictions=preds,
                    references=refs,
                )
                metric_unk.add_batch(
                    predictions=preds_unk,
                    references=refs_unk,
                )
                metric_top100.add_batch(
                    predictions=preds_100,
                    references=refs_100,
                )

            eval_metric = compute_metrics(metric, metric_unk, metric_top100)
            accelerator.print(f"epoch {epoch}:", eval_metric)

            # only the main process writes the gathered predictions to disk
            if accelerator.is_main_process:
                lines = ["{}\t{}".format(i, p) for i, p in enumerate(export_predictions)]
                with open(os.path.join(args.output_dir, "predictions.txt"), 'w') as f:
                    f.write("\n".join(lines))

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
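get_labels is defined elsewhere in the script. A minimal sketch of the usual token-classification post-processing it appears to perform (the score_unk and top100 filtering variants are omitted), assuming -100 marks positions to ignore:

def get_labels(predictions, references, label_list):
    # Move gathered tensors to CPU and map label ids to names, dropping
    # positions labeled -100 (special tokens and non-first sub-tokens).
    y_pred = predictions.detach().cpu().numpy()
    y_true = references.detach().cpu().numpy()
    preds = [
        [label_list[p] for p, l in zip(pred_row, ref_row) if l != -100]
        for pred_row, ref_row in zip(y_pred, y_true)
    ]
    refs = [
        [label_list[l] for p, l in zip(pred_row, ref_row) if l != -100]
        for pred_row, ref_row in zip(y_pred, y_true)
    ]
    return preds, refs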