def create_opt(self) -> torch.optim.Optimizer:
    opt = AdamW(
        self.model.parameters(),
        lr=self.args.lr,
        weight_decay=self.args.weight_decay,
    )
    if os.path.exists(self.opt_path()):
        print("loading optimizer from checkpoint...")
        opt.load_state_dict(torch.load(self.opt_path(), map_location="cpu"))
    return opt
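
For context, a minimal self-contained sketch (not part of the snippet above) of the save/load round-trip that create_opt relies on: the optimizer's state dict is written with torch.save and later restored with load_state_dict. The opt_path() call above is assumed to be the class's own helper for the checkpoint filename.

import torch
from torch.optim import AdamW

model = torch.nn.Linear(4, 2)
opt = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

# save the optimizer state alongside the model checkpoint ...
torch.save(opt.state_dict(), "opt.pt")

# ... and restore it later, e.g. when resuming training
opt.load_state_dict(torch.load("opt.pt", map_location="cpu"))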
Example #2
def main():
    parser = argparse.ArgumentParser(
        description='20bn-jester-v1 Gesture Classification with Backpropamine')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        metavar='N',
                        help='input batch size for training (default: 8)')
    #parser.add_argument('--validation-batch-size', type=int, default=1000, metavar='N',
    #                    help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num-workers',
                        type=int,
                        default=0,
                        metavar='W',
                        help='number of workers for data loading (default: 0)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        metavar='LR',
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run',
                        action='store_true',
                        default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--dataset-dir',
                        type=str,
                        default=r"./dataset",
                        metavar='D',
                        help='dataset place (default: ./dataset)')
    #parser.add_argument('--log-interval', type=int, default=10, metavar='N',
    #                    help='how many batches to wait before logging training status')
    #parser.add_argument('--save-model', action='store_true', default=False,
    #                    help='For Saving the current Model')
    parser.add_argument('--no-resume',
                        action='store_true',
                        default=False,
                        help='switch to disable resume')
    parser.add_argument(
        '--use-lstm',
        action='store_true',
        default=False,
        help='switch to use LSTM module instead of backpropamine')
    parser.add_argument('--frame-step',
                        type=int,
                        default=2,
                        metavar='FS',
                        help='step of video frames extraction (default: 2)')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    torch.manual_seed(args.seed)

    train_data = MyDataset('train',
                           args.dataset_dir,
                           frame_step=args.frame_step)
    validation_data = MyDataset('validation',
                                args.dataset_dir,
                                frame_step=args.frame_step)
    train_dataloader = DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  drop_last=True,
                                  shuffle=True,
                                  collate_fn=collate_fn,
                                  num_workers=args.num_workers)
    validation_dataloader = DataLoader(validation_data,
                                       batch_size=args.batch_size,
                                       drop_last=True,
                                       shuffle=True,
                                       collate_fn=collate_fn,
                                       num_workers=args.num_workers)

    resume = not args.no_resume

    if resume:
        try:
            checkpoint = torch.load("checkpoint.pt")
        except FileNotFoundError:
            resume = False

    mode = 'LSTM' if args.use_lstm else 'backpropamine'
    model = Net(mode=mode).to(device)
    optimizer = AdamW(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    last_epoch, max_epoch = 0, args.epochs

    if resume:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        last_epoch = checkpoint['last_epoch']

    validator = Validator(model, validation_dataloader, device, args.dry_run)
    trainer = Trainer(model, optimizer, train_dataloader, scheduler,
                      last_epoch, max_epoch, device, validator, args.dry_run)

    print(vars(args))
    trainer()
    print("finish.")
class TrainLoop:
    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    ):
        self.model = model
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        if self.use_fp16:
            self._setup_fp16()

        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model

    def _load_and_sync_parameters(self):
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                self.model.load_state_dict(
                    dist_util.load_state_dict(
                        resume_checkpoint, map_location=dist_util.dev()
                    )
                )

        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        ema_params = copy.deepcopy(self.master_params)

        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev()
                )
                ema_params = self._state_dict_to_master_params(state_dict)

        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def _setup_fp16(self):
        self.master_params = make_master_params(self.model_params)
        self.model.convert_to_fp16()

    def run_loop(self):
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond = next(self.data)
            self.run_step(batch, cond)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.save_interval == 0:
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        self.forward_backward(batch, cond)
        if self.use_fp16:
            self.optimize_fp16()
        else:
            self.optimize_normal()
        self.log_step()

    def forward_backward(self, batch, cond):
        zero_grad(self.model_params)
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            if self.use_fp16:
                loss_scale = 2 ** self.lg_loss_scale
                (loss * loss_scale).backward()
            else:
                loss.backward()

    def optimize_fp16(self):
        if any(not th.isfinite(p.grad).all() for p in self.model_params):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            return

        model_grads_to_master_grads(self.model_params, self.master_params)
        self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)
        master_params_to_model_params(self.model_params, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth

    def optimize_normal(self):
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)

    def _log_grad_norm(self):
        sqsum = 0.0
        for p in self.master_params:
            sqsum += (p.grad ** 2).sum().item()
        logger.logkv_mean("grad_norm", np.sqrt(sqsum))

    def _anneal_lr(self):
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
        if self.use_fp16:
            logger.logkv("lg_loss_scale", self.lg_loss_scale)

    def save(self):
        def save_checkpoint(rate, params):
            state_dict = self._master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"model{(self.step+self.resume_step):06d}.pt"
                else:
                    filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        if dist.get_rank() == 0:
            with bf.BlobFile(
                bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
                "wb",
            ) as f:
                th.save(self.opt.state_dict(), f)

        dist.barrier()

    def _master_params_to_state_dict(self, master_params):
        if self.use_fp16:
            master_params = unflatten_master_params(
                self.model.parameters(), master_params
            )
        state_dict = self.model.state_dict()
        for i, (name, _value) in enumerate(self.model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
        return state_dict

    def _state_dict_to_master_params(self, state_dict):
        params = [state_dict[name] for name, _ in self.model.named_parameters()]
        if self.use_fp16:
            return make_master_params(params)
        else:
            return params
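
The _anneal_lr method above implements a simple linear decay: the learning rate falls from lr to 0 over lr_anneal_steps by overwriting each param group's lr. A standalone sketch of the same schedule (not part of the TrainLoop class), with a toy model:

import torch
from torch.optim import AdamW

model = torch.nn.Linear(4, 4)
opt = AdamW(model.parameters(), lr=1e-4, weight_decay=0.0)
base_lr, lr_anneal_steps = 1e-4, 1000

for step in range(lr_anneal_steps):
    frac_done = step / lr_anneal_steps
    for group in opt.param_groups:
        group["lr"] = base_lr * (1 - frac_done)   # linear anneal, as in _anneal_lr
    # forward/backward and opt.step() would go here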
Example #4
    ],
    'weight_decay':
    0.0
}]

optimizer = AdamW(optimizer_grouped_parameters,
                  lr=args.learning_rate,
                  eps=args.adam_epsilon)
#-----------------------------------------

#-----------------------------------------
# Loading the contents of the auxiliary checkpoint and instantiating the contents if not resuming:
if aux_checkpoint:
    global_step = aux_checkpoint['global_step']
    epoch = aux_checkpoint['epoch']
    optimizer.load_state_dict(aux_checkpoint['optimizer'])
    best_acc = aux_checkpoint['best_acc']
    mlp.load_state_dict(aux_checkpoint['mlp_state_dict'])
    best_checkpoint_path = aux_checkpoint['best_checkpoint_path']
else:
    global_step = 0
    best_acc = 0.0
    epoch = 0
    best_checkpoint_path = None
#-----------------------------------------

#-----------------------------------------
# Enabling the use of dataparallel for multiple GPUs:
if args.dataparallel:
    model = nn.DataParallel(model)
#-----------------------------------------
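
The top of this snippet is cut off: optimizer_grouped_parameters is built with the usual two-group split that exempts bias and LayerNorm weights from weight decay, the same pattern used in the later examples on this page. A self-contained sketch of that setup, with a toy model and illustrative hyperparameters:

import torch
from torch.optim import AdamW

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.LayerNorm = torch.nn.LayerNorm(8)

model = TinyModel()
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)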
Example #5
def train(args):
    logger = log.get_logger(__name__)

    with open(Path(args.config_base_path, args.config).with_suffix(".yaml"), 'r') as f:
        config = yaml.safe_load(f)

    train_transforms = transforms.get_train_transforms()
    val_transforms = transforms.get_val_transforms()

    logger.info("Loading the dataset...")
    if config['dataset']['name'] == 'coco_subset':
        # TODO: Look into train_transforms hiding the objects
        # Transform in such a way that this can't be the case
        train_dataset = CocoSubset(config['dataset']['coco_path'],
                                   config['dataset']['target_classes'],
                                   train_transforms,
                                   'train',
                                   config['dataset']['train_val_split'])

        val_dataset = CocoSubset(config['dataset']['coco_path'],
                                 config['dataset']['target_classes'],
                                 val_transforms,
                                 'val',
                                 config['dataset']['train_val_split'])
    else:
        raise ValueError("Dataset name not recognized or implemented")

    train_loader = DataLoader(train_dataset,
                              config['training']['batch_size'],
                              shuffle=True,
                              collate_fn=data_utils.collate_fn)

    val_loader = DataLoader(val_dataset,
                            config['training']['batch_size'],
                            shuffle=True,
                            collate_fn=data_utils.collate_fn)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint_manager = CheckpointManager(args.config, args.save_every)

    logger.info("Loading model...")
    model = models.DETR(config['dataset']['num_classes'],
                        config['model']['dim_model'],
                        config['model']['n_heads'],
                        n_queries=config['model']['n_queries'],
                        head_type=config['model']['head_type'])

    # TODO: implement scheduler
    optim = AdamW(model.parameters(), config['training']['lr'])  # pending

    if args.mode == 'pretrained':
        model.load_demo_state_dict('data/state_dicts/detr_demo.pth')
    elif args.mode == 'checkpoint':
        state_dict, optim_dict = checkpoint_manager.load_checkpoint('latest')
        model.load_state_dict(state_dict)
        optim.load_state_dict(optim_dict)

    if args.train_section == 'head':
        to_train = ['ffn']
    elif args.train_section == 'backbone':
        to_train = ['backbone', 'conv']
    else:
        to_train = ['ffn', 'backbone', 'conv', 'transformer', 'row', 'col', 'object']

    # Freeze everything but the modules that are in to_train
    for name, param in model.named_parameters():
        if not any(map(name.startswith, to_train)):
            param.requires_grad = False

    model.to(device)

    matcher = models.HungarianMatcher(config['losses']['lambda_matcher_classes'],
                                      config['losses']['lambda_matcher_giou'],
                                      config['losses']['lambda_matcher_l1'])

    loss_fn = models.DETRLoss(config['losses']['lambda_loss_classes'],
                              config['losses']['lambda_loss_giou'],
                              config['losses']['lambda_loss_l1'],
                              config['dataset']['num_classes'],
                              config['losses']['no_class_weight'])

    # writer = SummaryWriter(log_dir=Path(__file__)/'logs/tensorboard')
    # maybe image with boxes every now and then
    # maybe look into add_hparams

    logger.info("Starting training...")
    loss_hist = deque()
    loss_desc = "Loss: n/a"

    update_every_n_steps = config['training']['effective_batch_size'] // config['training']['batch_size']
    steps = 1

    starting_epoch = checkpoint_manager.current_epoch

    for epoch in range(starting_epoch, config['training']['epochs']):
        epoch_desc = f"Epoch [{epoch}/{config['training']['epochs']}]"

        for images, labels in tqdm(train_loader, f"{epoch_desc} | {loss_desc}"):
            images = images.to(device)
            labels = data_utils.labels_to_device(labels, device)

            output = model(images)
            matching_indices = matcher(output, labels)
            matching_indices = data_utils.indices_to_device(matching_indices, device)

            loss = loss_fn(output, labels, matching_indices) / update_every_n_steps
            loss_hist.append(loss.item() * update_every_n_steps)
            loss.backward()

            if steps % update_every_n_steps == 0:
                optim.step()
                optim.zero_grad()

            steps += 1

        epoch_loss = sum(loss_hist) / len(loss_hist)
        checkpoint_manager.step(model, optim, epoch_loss)

        loss_desc = f"Loss: {epoch_loss}"
        loss_hist.clear()

        if (epoch % args.eval_every == 0) and epoch != 0:
            validation_loop(model, matcher, val_loader, loss_fn, device)

    # final checkpoint uses the last epoch's average loss (loss_hist was just cleared)
    checkpoint_manager.save_checkpoint(model, optim, epoch_loss)
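
The inner loop above accumulates gradients: each loss is divided by update_every_n_steps and optim.step()/zero_grad() run only every update_every_n_steps batches, so the effective batch size is batch_size * update_every_n_steps. A standalone sketch of the same pattern with a dummy model and loss:

import torch
from torch.optim import AdamW

model = torch.nn.Linear(8, 1)
optim = AdamW(model.parameters(), lr=1e-4)
effective_batch_size, batch_size = 32, 8
update_every_n_steps = effective_batch_size // batch_size  # 4 micro-batches per update

for step in range(1, 13):
    x = torch.randn(batch_size, 8)
    loss = model(x).pow(2).mean() / update_every_n_steps  # scale so accumulated grads average out
    loss.backward()                                       # gradients accumulate across calls
    if step % update_every_n_steps == 0:
        optim.step()
        optim.zero_grad()

Example #6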
    def fit(self):
        config = self.config

        logging.debug(json.dumps(config, indent=4, sort_keys=True))

        include_passage_masks = self.config["fusion_strategy"] == "passages"
        if self.config["dataset"] in ["nq", "trivia"]:
            fields = FusionInDecoderDataset.prepare_fields(
                pad_t=self.tokenizer.pad_token_id)
            if not config["test_only"]:
                # trivia is too large, create lightweight training dataset for it instead
                training_dataset = FusionInDecoderDatasetLight if self.config \
                    .get("use_lightweight_dataset", False) else FusionInDecoderDataset
                train = training_dataset(config["train_data"], fields=fields, tokenizer=self.tokenizer,
                                         database=self.db,
                                         transformer=config["reader_transformer_type"],
                                         cache_dir=self.config["data_cache_dir"],
                                         max_len=self.config.get("reader_max_input_length", None),
                                         context_length=self.config["context_length"],
                                         include_golden_passage=self.config["include_golden_passage_in_training"],
                                         include_passage_masks=include_passage_masks,
                                         preprocessing_truncation=self.config["preprocessing_truncation"],
                                         one_answer_per_question=self.config.get("one_question_per_epoch", False),
                                         use_only_human_answer=self.config.get("use_human_answer_only", False),
                                         is_training=True)

                val = FusionInDecoderDataset(config["val_data"], fields=fields, tokenizer=self.tokenizer,
                                             database=self.db,
                                             transformer=config["reader_transformer_type"],
                                             cache_dir=config["data_cache_dir"],
                                             max_len=self.config.get("reader_max_input_length", None),
                                             context_length=self.config["context_length"],
                                             include_passage_masks=include_passage_masks,
                                             preprocessing_truncation=self.config["preprocessing_truncation"],
                                             use_only_human_answer=self.config.get("use_human_answer_only", False),
                                             is_training=False)
            test = FusionInDecoderDataset(config["test_data"], fields=fields, tokenizer=self.tokenizer,
                                          database=self.db,
                                          transformer=config["reader_transformer_type"],
                                          cache_dir=config["data_cache_dir"],
                                          max_len=self.config.get("reader_max_input_length", None),
                                          context_length=self.config["context_length"],
                                          include_passage_masks=include_passage_masks,
                                          preprocessing_truncation=self.config["preprocessing_truncation"],
                                          is_training=False)

        else:
            raise NotImplementedError(f"Unknown dataset {self.config['dataset']}")

        if not config["test_only"]:
            logging.info(f"Training data examples:{len(train)}")
            logging.info(f"Validation data examples:{len(val)}")
        logging.info(f"Test data examples {len(test)}")

        if not config["test_only"]:
            train_iter = Iterator(train,
                                  shuffle=training_dataset != FusionInDecoderDatasetLight,
                                  sort=False,  # do not sort!
                                  batch_size=1, train=True,
                                  repeat=False, device=self.device)
            val_iter = Iterator(val,
                                sort=False, shuffle=False,
                                batch_size=1,
                                repeat=False, device=self.device)
        test_iter = Iterator(test,
                             sort=False, shuffle=False,
                             batch_size=1,
                             repeat=False, device=self.device)
        logging.info("Loading model...")
        if config.get("resume_training", False) or config.get("pre_initialize", False):
            if config.get("resume_training", False):
                logging.info("Resuming training...")
            if not "resume_checkpoint" in config:
                config["resume_checkpoint"] = config["pretrained_reader_model"]
            model = torch.load(config["resume_checkpoint"], map_location=self.device)
        else:
            model = torch.load(config["model"], map_location=self.device) \
                if self.config["test_only"] and "model" in config else \
                T5FusionInDecoder.from_pretrained(config).to(self.device)

        logging.info(f"Resizing token embeddings to length {len(self.tokenizer)}")
        model.resize_token_embeddings(len(self.tokenizer))

        logging.info(f"Model has {count_parameters(model)} trainable parameters")
        logging.info(f"Trainable parameter checksum: {sum_parameters(model)}")
        param_sizes, param_shapes = report_parameters(model)
        param_sizes = "\n'".join(str(param_sizes).split(", '"))
        param_shapes = "\n'".join(str(param_shapes).split(", '"))
        logging.debug(f"Model structure:\n{param_sizes}\n{param_shapes}\n")

        if not config["test_only"]:
            # Init optimizer
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.config["weight_decay"],
                },
                {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                 "weight_decay": 0.0},
            ]
            if config["optimizer"] == "adamw":
                optimizer = AdamW(optimizer_grouped_parameters,
                                  lr=self.config["learning_rate"],
                                  eps=self.config["adam_eps"])
            elif config["optimizer"] == "adam":
                optimizer = Adam(optimizer_grouped_parameters,
                                 lr=self.config["learning_rate"],
                                 eps=self.config["adam_eps"])
            else:
                raise ValueError("Unsupported optimizer")

            if config.get("resume_checkpoint", False):
                optimizer.load_state_dict(model.optimizer_state_dict)

            # Init scheduler
            if "scheduler_warmup_steps" in self.config or "warmup_proportion" in self.config:
                t_total = self.config["max_steps"]
                warmup_steps = round(
                    self.config[
                        "scheduler_warmup_proportion"] * t_total) if "scheduler_warmup_proportion" in self.config else \
                    self.config["scheduler_warmup_steps"]
                scheduler = self.init_scheduler(
                    optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=t_total,
                    last_step=get_model(model).training_steps - 1
                )
                logging.info(f"Scheduler: warmup steps: {warmup_steps}, total_steps: {t_total}")
            else:
                scheduler = None

            if config["lookahead_optimizer"]:
                optimizer = Lookahead(optimizer, k=10, alpha=0.5)

        if not config["test_only"]:
            start_time = time.time()
            try:
                it = 0
                while get_model(model).training_steps < self.config["max_steps"]:
                    logging.info(f"Epoch {it}")
                    train_loss = self.train_epoch(model=model,
                                                  data_iter=train_iter,
                                                  val_iter=val_iter,
                                                  optimizer=optimizer,
                                                  scheduler=scheduler)
                    logging.info(f"Training loss: {train_loss:.5f}")
                    it += 1

            except KeyboardInterrupt:
                logging.info('-' * 120)
                logging.info('Exit from training early.')
            finally:
                logging.info(f'Finished after {(time.time() - start_time) / 60} minutes.')
                if hasattr(self, "best_ckpt_name"):
                    logging.info(f"Loading best checkpoint {self.best_ckpt_name}")
                    model = torch.load(self.best_ckpt_name, map_location=self.device)
        logging.info("#" * 50)
        logging.info("Validating on the test data")
        self.validate(model, test_iter)
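
On the adamw/adam switch above: AdamW applies weight_decay directly to the weights (decoupled decay), while Adam folds it into the gradient as L2 regularisation; both accept the same lr/eps keyword arguments. A tiny illustration with a single parameter and zero gradient, where only the decay term acts (values are illustrative):

import torch
from torch.optim import AdamW

p = torch.nn.Parameter(torch.ones(3))
opt = AdamW([p], lr=1e-3, weight_decay=0.01, eps=1e-8)
p.grad = torch.zeros(3)   # no gradient signal, only the decoupled decay acts
opt.step()
print(p)                  # slightly below 1.0: p was multiplied by (1 - lr * weight_decay)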
Example #7
def pretrain(data, stats=None):
    # fine tuning
    dataloader = DataLoader(data, batch_size=1, shuffle=True)
    del data
    ## optimizer and scheduler ##
    t_total = len(
        dataloader) // opts.gradient_accumulation_steps * opts.num_train_epochs

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [{
        "params": [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay":
        opts.weight_decay
    }, {
        "params": [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay":
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=opts.lr, eps=opts.eps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=opts.warmup_steps,
        num_training_steps=t_total)
    # loading optimizer settings
    if (opts.model_name_or_path and os.path.isfile(
            os.path.join(opts.model_name_or_path, "pretrain_optimizer.pt"))
            and os.path.isfile(
                os.path.join(opts.model_name_or_path, "pretrain_scheduler.pt"))):
        # load optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(
                os.path.join(opts.model_name_or_path,
                             "pretrain_optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(
                os.path.join(opts.model_name_or_path,
                             "pretrain_scheduler.pt")))
    # track stats
    if stats is not None:
        global_step = max(stats.keys())
        epochs_trained = global_step // (len(dataloader) //
                                         opts.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(dataloader) // opts.gradient_accumulation_steps)
        print("Resuming Training ... ")
    else:
        stats = {}
        global_step, epochs_trained, steps_trained_in_current_epoch = 0, 0, 0
    tr_loss, logging_loss = 0.0, 0.0
    # very important: set model to TRAINING mode
    model.zero_grad()
    model.train()
    print("Re-sizing model ... ")
    model.resize_token_embeddings(len(tokenizer))
    start_time = time.time()
    for epoch in range(epochs_trained, opts.num_train_epochs):
        data_iter = iter(dataloader)
        for step in range(len(dataloader)):
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                batch = next(data_iter)
                continue
            ### step ###
            batch = next(data_iter)
            loss = fit_on_batch(batch)
            del batch
            # logging (new data only)
            tr_loss += loss.item()

            # gradient accumulation
            if (step + 1) % opts.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               opts.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # reporting
                if global_step % opts.logging_steps == 0:
                    stats[global_step] = {
                        'pretrain_loss':
                        (tr_loss - logging_loss) / opts.logging_steps,
                        'pretrain_lr': scheduler.get_last_lr()[-1]
                    }
                    logging_loss = tr_loss

                    elapsed_time = time.strftime(
                        "%M:%S", time.gmtime(time.time() - start_time))
                    print(
                        'Epoch: %d | Iter: [%d/%d] | loss: %.3f | lr: %s | time: %s'
                        %
                        (epoch, global_step, t_total,
                         stats[global_step]['pretrain_loss'],
                         str(stats[global_step]['pretrain_lr']), elapsed_time))
                    start_time = time.time()

                if global_step % opts.save_steps == 0:
                    print("Saving stuff ... ")
                    checkpoint(model,
                               tokenizer,
                               optimizer,
                               scheduler,
                               stats,
                               title="pretrain_")
                    plot_losses(stats, title='pretrain_loss')
                    plot_losses(stats, title='pretrain_lr')
                    print("Done.")

    return stats
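
A small worked example of the resume bookkeeping near the top of pretrain(): with gradient accumulation, one optimizer update ("global step") covers gradient_accumulation_steps batches, so the epoch count and the offset inside the current epoch are recovered from the saved global_step by integer division and modulo (the numbers below are illustrative):

batches_per_epoch = 100
gradient_accumulation_steps = 4
updates_per_epoch = batches_per_epoch // gradient_accumulation_steps  # 25

global_step = 60                                                  # restored from stats
epochs_trained = global_step // updates_per_epoch                 # 2 full epochs already done
steps_trained_in_current_epoch = global_step % updates_per_epoch  # 10 updates into the next epoch
print(epochs_trained, steps_trained_in_current_epoch)             # 2 10

Example #8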
	def train(self):
		# get dataloader
		train_dataloader, _ = self.data2loader(self.args['train'], mode='train', batch_size=self.args['batch_size'])
		# optimizer and scheduler

		param_optimizer = list(self.model.named_parameters())
		other_parameters = [(n, p) for n, p in param_optimizer if 'crf' not in n]
		no_decay = ['bias', 'gamma', 'beta']
		optimizer_grouped_parameters = [
			{'params': [p for n, p in other_parameters if not any(nd in n for nd in no_decay)],
				'weight_decay_rate': 0.01},
			{'params': [p for n, p in other_parameters if any(nd in n for nd in no_decay)],
				'weight_decay_rate': 0.0},
			{'params':[p for n, p in param_optimizer if 'crf.transitions' == n], 'lr':3e-2}
		]

		optimizer = AdamW(optimizer_grouped_parameters, lr=self.args['learning_rate'], eps=1e-8)

		if self.args['load_model'] > 0:
			optimizer.load_state_dict(torch.load('models/Opt' + str(self.args['load_model'])))
			print('load optimizer success')
		
		total_steps = 1000#len(train_dataloader) * num_epoches
		if self.args['load_model'] <= 0:
			last_epoch = -1
		else:
			last_epoch = self.args['load_model']
		scheduler = get_linear_schedule_with_warmup(
		    optimizer,
		    num_warmup_steps=0,
		    num_training_steps=total_steps,
		    last_epoch=last_epoch
		)

		# training
		top = 0
		stop = 0
		best_model = None
		start_time = time()
		for i in range(self.args['num_epoches']):
			self.model.train()

			losses = 0
			for idx, batch_data in enumerate(train_dataloader):
				batch_data = tuple(i.to(device) for i in batch_data)
				ids, masks, labels = batch_data

				self.model.zero_grad()
				loss = self.model(ids, masks=masks, labels=labels)

				# process loss
				loss.backward()
				losses += loss.item()

				# tackle exploding gradients
				torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=self.args['max_grad_norm'])

				optimizer.step()

			scheduler.step()

			F0 = None
			if (i+1+self.args['load_model']) % 20 == 0:
				F0, _ = self.evaluate(self.args['train'])
			F1, loss = self.evaluate(self.args['valid'])
			F2, loss2 = self.evaluate(self.args['test'])

			if F1+F2 > top:
				top = F1 + F2
				# torch.save(self.model, 'models/Mod' + str(self.args['fold']) + '_' + str(i+self.args['load_model']+1))
				best_model = copy.deepcopy(self.model)
				print('save new top', top)
				stop = 0
			else:
				if stop > 7:
					torch.save(best_model, 'models/Mod' + str(self.args['fold']) + '_' + str(i+self.args['load_model']+1))
					return
				stop += 1
				
			print('Epoch', i+self.args['load_model']+1, losses/len(train_dataloader), loss, 'F1', F1, F2, F0, time()-start_time)

			if (i+1+self.args['load_model']) % self.args['save_epoch'] == 0:
				torch.save(self.model, 'models/Mod' + str(self.args['fold']) + '_' + str(i+self.args['load_model']+1))
				# torch.save(optimizer.state_dict(), 'models/Opt' + str(self.args['fold']) + '_' + str(i+self.args['load_model']+1))
			start_time = time()
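
Note the third parameter group above, which sets its own 'lr': 3e-2 for crf.transitions; AdamW (like all torch optimizers) lets a group override the constructor defaults, so only groups without an explicit lr use the lr passed to AdamW. A minimal illustration:

import torch
from torch.optim import AdamW

body = torch.nn.Linear(8, 8)
crf_transitions = torch.nn.Parameter(torch.zeros(5, 5))

optimizer = AdamW([
    {"params": body.parameters()},                 # uses the default lr below
    {"params": [crf_transitions], "lr": 3e-2},     # group-specific lr override
], lr=5e-5, eps=1e-8)

print([g["lr"] for g in optimizer.param_groups])   # [5e-05, 0.03]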
Example #9
class Trainer:
    """
    Handles model training and evaluation.
    
    Arguments:
    ----------
    config: A dictionary of training parameters, likely from a .yaml
    file
    
    model: A pytorch segmentation model (e.g. DeepLabV3)
    
    trn_data: A pytorch dataloader object that will return pairs of images and
    segmentation masks from a training dataset
    
    val_data: A pytorch dataloader object that will return pairs of images and
    segmentation masks from a validation dataset.
    
    """
    def __init__(self, config, model, trn_data, val_data=None):
        self.config = config
        self.model = model.cuda()
        self.trn_data = DataFetcher(trn_data)
        self.val_data = val_data

        #create the optimizer
        if config['optim'] == 'SGD':
            self.optimizer = SGD(model.parameters(),
                                 lr=config['lr'],
                                 momentum=config['momentum'],
                                 weight_decay=config['wd'])
        elif config['optim'] == 'AdamW':
            self.optimizer = AdamW(
                model.parameters(), lr=config['lr'],
                weight_decay=config['wd'])  #betas/eps are left at their defaults
        else:
            optim = config['optim']
            raise Exception(
                f'Optimizer {optim} is not supported! Must be SGD or AdamW')

        #create the learning rate scheduler
        schedule = config['lr_policy']
        if schedule == 'OneCycle':
            self.scheduler = OneCycleLR(self.optimizer,
                                        config['lr'],
                                        total_steps=config['iters'])
        elif schedule == 'MultiStep':
            self.scheduler = MultiStepLR(self.optimizer,
                                         milestones=config['lr_decay_epochs'])
        elif schedule == 'Poly':
            func = lambda iteration: (1 - iteration / config['iters']) ** config['power']
            self.scheduler = LambdaLR(self.optimizer, func)
        else:
            lr_policy = config['lr_policy']
            raise Exception(
                f'Policy {lr_policy} is not supported! Must be OneCycle, MultiStep or Poly'
            )

        #create the loss criterion
        if config['num_classes'] > 1:
            #load class weights if they were given in the config file
            if 'class_weights' in config:
                weight = torch.Tensor(config['class_weights']).float().cuda()
            else:
                weight = None

            self.criterion = nn.CrossEntropyLoss(weight=weight).cuda()
        else:
            self.criterion = nn.BCEWithLogitsLoss().cuda()

        #define train and validation metrics and class names
        class_names = config['class_names']

        #make training metrics using the EMAMeter. this meter gives extra
        #weight to the most recent metric values calculated during training
        #this gives a better reflection of how well the model is performing
        #when the metrics are printed
        trn_md = {
            name: metric_lookup[name](EMAMeter())
            for name in config['metrics']
        }
        self.trn_metrics = ComposeMetrics(trn_md, class_names)
        self.trn_loss_meter = EMAMeter()

        #the only difference between train and validation metrics
        #is that we use the AverageMeter. this is because there are
        #no weight updates during evaluation, so all batches should
        #count equally
        val_md = {
            name: metric_lookup[name](AverageMeter())
            for name in config['metrics']
        }
        self.val_metrics = ComposeMetrics(val_md, class_names)
        self.val_loss_meter = AverageMeter()

        self.logging = config['logging']

        #now, if we're resuming from a previous run we need to load
        #the state for the model, optimizer, and schedule and resume
        #the mlflow run (if there is one and we're using logging)
        if config['resume']:
            self.resume(config['resume'])
        elif self.logging:
            #if we're not resuming, but are logging, then we
            #need to setup mlflow with a new experiment
            #everytime that Trainer is instantiated we want to
            #end the current active run and let a new one begin
            mlflow.end_run()

            #extract the experiment name from config so that
            #we know where to save our files, if experiment name
            #already exists, we'll use it, otherwise we create a
            #new experiment
            mlflow.set_experiment(self.config['experiment_name'])

            #add the config file as an artifact
            mlflow.log_artifact(config['config_file'])

            #we don't want to add everything in the config
            #to mlflow parameters, we'll just add the most
            #likely to change parameters
            mlflow.log_param('lr_policy', config['lr_policy'])
            mlflow.log_param('optim', config['optim'])
            mlflow.log_param('lr', config['lr'])
            mlflow.log_param('wd', config['wd'])
            mlflow.log_param('bsz', config['bsz'])
            mlflow.log_param('momentum', config['momentum'])
            mlflow.log_param('iters', config['iters'])
            mlflow.log_param('epochs', config['epochs'])
            mlflow.log_param('encoder', config['encoder'])
            mlflow.log_param('finetune_layer', config['finetune_layer'])
            mlflow.log_param('pretraining', config['pretraining'])

    def resume(self, checkpoint_fpath):
        """
        Sets model parameters, scheduler and optimizer states to the
        last recorded values in the given checkpoint file.
        """
        checkpoint = torch.load(checkpoint_fpath, map_location='cpu')
        self.model.load_state_dict(checkpoint['state_dict'])

        if not self.config['restart_training']:
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        if self.logging and 'run_id' in checkpoint:
            mlflow.start_run(run_id=checkpoint['run_id'])

        print(f'Loaded state from {checkpoint_fpath}')
        print(f'Resuming from epoch {self.scheduler.last_epoch}...')

    def log_metrics(self, step, dataset):
        #get the corresponding losses and metrics dict for
        #either train or validation sets
        if dataset == 'train':
            losses = self.trn_loss_meter
            metric_dict = self.trn_metrics.metrics_dict
        elif dataset == 'valid':
            losses = self.val_loss_meter
            metric_dict = self.val_metrics.metrics_dict

        #log the last loss, using the dataset name as a prefix
        mlflow.log_metric(dataset + '_loss', losses.avg, step=step)

        #log all the metrics in our dict, using dataset as a prefix
        metrics = {}
        for k, v in metric_dict.items():
            values = v.meter.avg
            for class_name, val in zip(self.trn_metrics.class_names, values):
                metrics[dataset + '_' + class_name + '_' + k] = float(
                    val.item())

        mlflow.log_metrics(metrics, step=step)

    def train(self):
        """
        Defines a pytorch-style training loop for the model with a tqdm progress bar
        for each epoch and handles printing loss/metrics at the end of each epoch.
        
        epochs: Number of epochs to train the model
        train_iters_per_epoch: Number of training iterations in each epoch. Reducing this
        number will give more frequent updates but result in slower overall training.
        
        Results:
        ----------
        
        After train_iters_per_epoch iterations are completed, it will evaluate the model
        on val_data if there is any, then print loss and metrics for the train and validation
        datasets.
        """

        #set the inner and outer training loop as either
        #iterations or epochs depending on our scheduler
        if self.config['lr_policy'] != 'MultiStep':
            last_epoch = self.scheduler.last_epoch + 1
            total_epochs = self.config['iters']
            iters_per_epoch = 1
            outer_loop = tqdm(range(last_epoch, total_epochs + 1),
                              file=sys.stdout,
                              initial=last_epoch,
                              total=total_epochs)
            inner_loop = range(iters_per_epoch)
        else:
            last_epoch = self.scheduler.last_epoch + 1
            total_epochs = self.config['epochs']
            iters_per_epoch = len(self.trn_data)
            outer_loop = range(last_epoch, total_epochs + 1)
            inner_loop = tqdm(range(iters_per_epoch), file=sys.stdout)

        #determine the epochs at which to print results
        eval_epochs = total_epochs // self.config['num_prints']
        save_epochs = total_epochs // self.config['num_save_checkpoints']

        #the cudnn.benchmark flag speeds up performance
        #when the model input size is constant. See:
        #https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        cudnn.benchmark = True

        #perform training over the outer and inner loops
        for epoch in outer_loop:
            for iteration in inner_loop:
                #load the next batch of training data
                images, masks = self.trn_data.load()

                #run the training iteration
                loss, output = self._train_1_iteration(images, masks)

                #record the loss and evaluate metrics
                self.trn_loss_meter.update(loss)
                self.trn_metrics.evaluate(output, masks)

            #when we're at an eval_epoch we want to print
            #the training results so far and then evaluate
            #the model on the validation data
            if epoch % eval_epochs == 0:
                #before printing results let's record everything in mlflow
                #(if we're using logging)
                if self.logging:
                    self.log_metrics(epoch, dataset='train')

                print('\n')  #print a new line to give space from progress bar
                print(f'train_loss: {self.trn_loss_meter.avg:.3f}')
                self.trn_loss_meter.reset()
                #prints and automatically resets the metric averages to 0
                self.trn_metrics.print()

                #run evaluation if we have validation data
                if self.val_data is not None:
                    #before evaluation we want to turn off cudnn
                    #benchmark because the input sizes of validation
                    #images are not necessarily constant
                    cudnn.benchmark = False
                    self.evaluate()

                    if self.logging:
                        self.log_metrics(epoch, dataset='valid')

                    print(
                        '\n')  #print a new line to give space from progress bar
                    print(f'valid_loss: {self.val_loss_meter.avg:.3f}')
                    self.val_loss_meter.reset()
                    #prints and automatically resets the metric averages to 0
                    self.val_metrics.print()

                    #turn cudnn.benchmark back on before returning to training
                    cudnn.benchmark = True

            #update the optimizer schedule
            self.scheduler.step()

            #the last step is to save the training state if
            #at a checkpoint
            if epoch % save_epochs == 0:
                self.save_state(epoch)

    def _train_1_iteration(self, images, masks):
        #run a training step
        self.model.train()
        self.optimizer.zero_grad()

        #forward pass
        output = self.model(images)
        loss = self.criterion(output, masks)

        #backward pass
        loss.backward()
        self.optimizer.step()

        #return the loss value and the output
        return loss.item(), output.detach()

    def evaluate(self):
        """
        Evaluation method used at the end of each epoch. Not intended to
        generate predictions for the validation dataset; it only updates the
        running loss and metrics for the validation dataset.
        
        Use Validator class for generating masks on a dataset.
        """
        #set the model into eval mode
        self.model.eval()

        val_iter = DataFetcher(self.val_data)
        for _ in range(len(val_iter)):
            with torch.no_grad():
                #load batch of data
                images, masks = val_iter.load()
                output = self.model(images)  # model is already in eval mode
                loss = self.criterion(output, masks)
                self.val_loss_meter.update(loss.item())
                self.val_metrics.evaluate(output.detach(), masks)

        #loss and metrics are updated inplace, so there's nothing to return
        return None

    def save_state(self, epoch):
        """
        Saves the model, optimizer and scheduler state dicts as a checkpoint
        in self.config['model_dir'].
        
        Arguments:
        ------------
        
        epoch: Current epoch number, used to build the checkpoint filename.
        
        Example:
        ----------
        
        trainer = Trainer(...)
        trainer.save_state(epoch=10)
        """

        #save the state together with the norms that we're using
        state = {
            'state_dict': self.model.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'norms': self.config['training_norms']
        }

        if self.logging:
            state['run_id'] = mlflow.active_run().info.run_id

        #the last step is to create the name of the file to save
        #the format is: name-of-experiment_pretraining_epoch.pth
        model_dir = self.config['model_dir']
        exp_name = self.config['experiment_name']
        pretraining = self.config['pretraining']
        ft_layer = self.config['finetune_layer']

        if self.config['lr_policy'] != 'MultiStep':
            total_epochs = self.config['iters']
        else:
            total_epochs = self.config['epochs']

        if os.path.isfile(pretraining):
            #this is slightly clunky, but it handles the case
            #of using custom pretrained weights from a file
            #usually there aren't any '.'s other than the file
            #extension
            pretraining = pretraining.split('/')[-2]  #.split('.')[0]

        save_path = os.path.join(
            model_dir,
            f'{exp_name}-{pretraining}_ft_{ft_layer}_epoch{epoch}_of_{total_epochs}.pth'
        )
        torch.save(state, save_path)
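
The 'Poly' policy above builds a LambdaLR whose multiplier is (1 - iteration / iters) ** power; each scheduler.step() advances the iteration counter, so the learning rate decays polynomially from the base lr towards zero. A self-contained sketch with toy values:

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

iters, power = 100, 0.9
model = torch.nn.Linear(4, 4)
opt = AdamW(model.parameters(), lr=0.01)
scheduler = LambdaLR(opt, lambda it: (1 - it / iters) ** power)

for _ in range(5):
    opt.step()         # optimizer step first, then the scheduler
    scheduler.step()

print(opt.param_groups[0]["lr"])   # base lr scaled by (1 - 5/100) ** 0.9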
Example #10
class Model:
    def __init__(self, epochs=50, fc=FC_62):
        self.model = CNN(fc)
        self.model.to(device)
        self.num_epochs = epochs  # total number of epochs to train for
        self.epochs = 0           # epochs completed so far (restored by load())
        self.loss = 0
        self.optimizer = AdamW(params=self.model.parameters())
        self.loss_fn = nn.CrossEntropyLoss()
        self.transform2 = [
            # transforms.CenterCrop(256),
            # Crop(28),
            # transforms.Resize(256),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(device)
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epochs = checkpoint['epoch']
        self.loss = checkpoint['loss']
        print(f'\nmodel loaded from path : {path}')

    def save(self, epoch, model, optimizer, loss, path):
        save_path = root_dir + '/models/'
        if not os.path.isdir(save_path):
            os.makedirs(save_path)
        path = save_path + path
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
            }, path)
        print(f'\nsaved model to path : {path}')

    def test(self, testloader, progress, type='validation'):
        print(f'Starting testing on {type} dataset')
        print('-------------------------------')
        correct, total = 0, 0
        with torch.no_grad():
            for i, data in enumerate(testloader, 0):
                inputs, targets = data
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = self.model(inputs)

                _, predicted = torch.max(outputs.data, 1)
                # print(predicted)
                # print(targets)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
                progress.update(self.batch_size)

            print(
                f'\nAccuracy on {type} dataset : {correct} / {total} = {100.0 * correct / total}'
            )
            print('--------------------------------')

            return 100.0 * correct / total

    def train(self, trainloader, epoch, progress):
        print(f'\nStarting epoch {epoch+1}')
        current_loss = 0.0

        for i, data in enumerate(trainloader, 0):
            inputs, targets = data
            inputs, targets = inputs.to(device), targets.to(device)
            self.optimizer.zero_grad()

            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, targets)
            loss.backward()

            self.optimizer.step()
            current_loss += loss.item()
            progress.update(self.batch_size)

        print(f'\nloss at epoch {epoch + 1} : {current_loss}')
        return current_loss

    def train_validate(self,
                       name,
                       mnist=False,
                       batch_size=64,
                       validation_split=0,
                       save_name=None):
        self.batch_size = batch_size

        if save_name is None:
            save_name = name

        progress = None
        np.random.seed(42)

        epochs_plot = []
        accuracy_plot = []
        loss_plot = []

        for epoch in range(0, self.num_epochs):
            if mnist:
                self.transform1 = [
                    transforms.RandomRotation(degrees=10),
                ]
                train_data = torchvision.datasets.MNIST(
                    'mnist',
                    download=True,
                    transform=transforms.Compose(self.transform1 +
                                                 self.transform2))
                trainloader = torch.utils.data.DataLoader(
                    train_data, batch_size=self.batch_size, num_workers=2)
                dataset_size = len(trainloader.dataset)
            else:
                data = get_data_set(name)
                dataset_size = len(data)
                ids = list(range(dataset_size))
                split = int(np.floor(validation_split * dataset_size))
                np.random.shuffle(ids)
                train_ids, val_ids = ids[split:], ids[:split]

                train_subsampler = torch.utils.data.SubsetRandomSampler(
                    train_ids)
                test_subsampler = torch.utils.data.SubsetRandomSampler(val_ids)

                trainloader = torch.utils.data.DataLoader(
                    data,
                    batch_size=batch_size,
                    sampler=train_subsampler,
                    num_workers=2)
                testloader = torch.utils.data.DataLoader(
                    data,
                    batch_size=batch_size,
                    sampler=test_subsampler,
                    num_workers=2)
            if progress is None:
                progress = tqdm.tqdm(total=(2 + validation_split) *
                                     dataset_size * self.num_epochs,
                                     position=0,
                                     leave=True)
            current_loss = self.train(trainloader, epoch, progress)
            accuracy = self.test(trainloader, progress, 'train')
            if validation_split:
                self.test(testloader, progress, 'validation')
            epochs_plot.append(epoch)
            accuracy_plot.append(accuracy)
            loss_plot.append(current_loss)
            self.save(epoch, self.model, self.optimizer, current_loss,
                      f'{save_name}-{epoch}.pth')
        return epochs_plot, accuracy_plot, loss_plot

    def test_mnist(self):
        test_data = torchvision.datasets.MNIST('mnist',
                                               False,
                                               download=True,
                                               transform=transforms.Compose(
                                                   self.transform2))
        testloader = torch.utils.data.DataLoader(test_data,
                                                 batch_size=self.batch_size,
                                                 num_workers=2)
        progress = tqdm.tqdm(total=len(testloader.dataset),
                             position=0,
                             leave=True)
        self.test(testloader, progress, 'test')
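
A minimal usage sketch for the Model class above (not part of the original example); it assumes the surrounding globals such as device, root_dir, CNN and FC_62 are defined as in the script, and trains on MNIST with the default validation_split of 0:

model = Model(epochs=5)
# Train on MNIST for self.num_epochs epochs, saving a checkpoint after every epoch.
epochs_plot, accuracy_plot, loss_plot = model.train_validate('mnist', mnist=True, batch_size=64)
# Evaluate on the held-out MNIST test split.
model.test_mnist()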
Exemple #11
class Detector(object):
    def __init__(self, cfg):
        self.device = cfg["device"]
        self.model = Models().get_model(cfg["network"]) # cfg.network
        self.model.to(self.device)
        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(params, lr=0.00001)
        self.lr_scheduler = OneCycleLR(self.optimizer,
                                       max_lr=1e-4,
                                       epochs=cfg["nepochs"],
                                       steps_per_epoch=169,  # len(dataloader)/accumulations
                                       div_factor=25,  # for initial lr, default: 25
                                       final_div_factor=1e3,  # for final lr, default: 1e4
                                       )

    def fit(self, data_loader, accumulation_steps=4, wandb=None):
        self.model.train()
        #     metric_logger = utils.MetricLogger(delimiter="  ")
        #     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        avg_loss = MetricLogger('scalar')
        total_loss = MetricLogger('dict')
        lr_log = MetricLogger('list')

        self.optimizer.zero_grad()
        device = self.device

        for i, (images, targets) in enumerate(data_loader):
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = self.model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.detach().item()
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            losses.backward()
            if (i+1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                    lr_log.update(self.lr_scheduler.get_last_lr())


            print(f"\rTrain iteration: [{i+1}/{len(data_loader)}]", end="")
            avg_loss.update(loss_value)
            total_loss.update(loss_dict)

            # metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
            # metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        print()
        #print(loss_dict)
        return {"train_avg_loss": avg_loss.avg}, total_loss.avg


    def mixup_fit(self, data_loader, accumulation_steps=4, wandb=None):
        self.model.train()
        torch.cuda.empty_cache()
        #     metric_logger = utils.MetricLogger(delimiter="  ")
        #     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        avg_loss = MetricLogger('scalar')
        total_loss = MetricLogger('dict')
        #lr_log = MetricLogger('list')

        self.optimizer.zero_grad()
        device = self.device

        for i, (batch1, batch2) in enumerate(data_loader):
            images1, targets1 = batch1
            images2, targets2 = batch2
            images = mixup_images(images1, images2)
            targets = merge_targets(targets1, targets2)
            del images1, images2, targets1, targets2, batch1, batch2

            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = self.model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.detach().item()
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            losses.backward()
            if (i+1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                    #lr_log.update(self.lr_scheduler.get_last_lr())


            print(f"Train iteration: [{i+1}/{674}]\r", end="")
            avg_loss.update(loss_value)
            total_loss.update(loss_dict)

            # metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
            # metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        print()
        #print(loss_dict)
        return {"train_avg_loss": avg_loss.avg}, total_loss.avg


    def evaluate(self, val_dataloader):
        device = self.device
        torch.cuda.empty_cache()
        # self.model.to(device)
        self.model.eval()
        mAp_logger = MetricLogger('list')
        with torch.no_grad():
            for (j, batch) in enumerate(val_dataloader):
                print(f"\rValidation: [{j+1}/{len(val_dataloader)}]", end="")
                images, targets = batch
                del batch
                images = [img.to(device) for img in images]
                # targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
                predictions = self.model(images)#, targets)
                for i, pred in enumerate(predictions):
                    probas = pred["scores"].detach().cpu().numpy()
                    mask = probas > 0.6
                    preds = pred["boxes"].detach().cpu().numpy()[mask]
                    gts = targets[i]["boxes"].detach().cpu().numpy()
                    score, scores = map_score(gts, preds, thresholds=[.5, .55, .6, .65, .7, .75])
                    mAp_logger.update(scores)
            print()
        return {"validation_mAP_score": mAp_logger.avg}

    def get_checkpoint(self):
        self.model.eval()
        model_state = self.model.state_dict()
        optimizer_state = self.optimizer.state_dict()
        checkpoint = {'model_state_dict': model_state,
                      'optimizer_state_dict': optimizer_state
                      }
        # if self.lr_scheduler:
        #     scheduler_state = self.lr_scheduler.state_dict()
        #     checkpoint['lr_scheduler_state_dict'] = scheduler_state

        return checkpoint

    def load_checkpoint(self, checkpoint):
        self.model.eval()
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
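
For context, a short driver sketch for the Detector class (illustrative only; the cfg values, the network name and the train_loader / val_loader objects are assumptions, not from the original code):

cfg = {"device": "cuda", "network": "faster_rcnn", "nepochs": 10}
detector = Detector(cfg)
for epoch in range(cfg["nepochs"]):
    train_stats, loss_breakdown = detector.fit(train_loader, accumulation_steps=4)
    val_stats = detector.evaluate(val_loader)
    print(train_stats, val_stats)
torch.save(detector.get_checkpoint(), "detector_checkpoint.pth")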
Exemple #12
def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_classifier_and_diffusion(
        **args_to_dict(args,
                       classifier_and_diffusion_defaults().keys()))
    model.to(dist_util.dev())
    if args.noised:
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion)

    resume_step = 0
    if args.resume_checkpoint:
        resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
        if dist.get_rank() == 0:
            logger.log(
                f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
            )
            model.load_state_dict(
                dist_util.load_state_dict(args.resume_checkpoint,
                                          map_location=dist_util.dev()))

    # Needed for creating correct EMAs and fp16 parameters.
    dist_util.sync_params(model.parameters())

    mp_trainer = MixedPrecisionTrainer(model=model,
                                       use_fp16=args.classifier_use_fp16,
                                       initial_lg_loss_scale=16.0)

    model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=True,
        random_crop=True,
    )
    if args.val_data_dir:
        val_data = load_data(
            data_dir=args.val_data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=True,
        )
    else:
        val_data = None

    logger.log(f"creating optimizer...")
    opt = AdamW(mp_trainer.master_params,
                lr=args.lr,
                weight_decay=args.weight_decay)
    if args.resume_checkpoint:
        opt_checkpoint = bf.join(bf.dirname(args.resume_checkpoint),
                                 f"opt{resume_step:06}.pt")
        logger.log(
            f"loading optimizer state from checkpoint: {opt_checkpoint}")
        opt.load_state_dict(
            dist_util.load_state_dict(opt_checkpoint,
                                      map_location=dist_util.dev()))

    logger.log("training classifier model...")

    def forward_backward_log(data_loader, prefix="train"):
        batch, extra = next(data_loader)
        labels = extra["y"].to(dist_util.dev())

        batch = batch.to(dist_util.dev())
        # Noisy images
        if args.noised:
            t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
            batch = diffusion.q_sample(batch, t)
        else:
            t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())

        for i, (sub_batch, sub_labels, sub_t) in enumerate(
                split_microbatches(args.microbatch, batch, labels, t)):
            logits = model(sub_batch, timesteps=sub_t)
            loss = F.cross_entropy(logits, sub_labels, reduction="none")

            losses = {}
            losses[f"{prefix}_loss"] = loss.detach()
            losses[f"{prefix}_acc@1"] = compute_top_k(logits,
                                                      sub_labels,
                                                      k=1,
                                                      reduction="none")
            losses[f"{prefix}_acc@5"] = compute_top_k(logits,
                                                      sub_labels,
                                                      k=5,
                                                      reduction="none")
            log_loss_dict(diffusion, sub_t, losses)
            del losses
            loss = loss.mean()
            if loss.requires_grad:
                if i == 0:
                    mp_trainer.zero_grad()
                mp_trainer.backward(loss * len(sub_batch) / len(batch))

    for step in range(args.iterations - resume_step):
        logger.logkv("step", step + resume_step)
        logger.logkv(
            "samples",
            (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
        )
        if args.anneal_lr:
            set_annealed_lr(opt, args.lr,
                            (step + resume_step) / args.iterations)
        forward_backward_log(data)
        mp_trainer.optimize(opt)
        if val_data is not None and not step % args.eval_interval:
            with th.no_grad():
                with model.no_sync():
                    model.eval()
                    forward_backward_log(val_data, prefix="val")
                    model.train()
        if not step % args.log_interval:
            logger.dumpkvs()
        if (step and dist.get_rank() == 0
                and not (step + resume_step) % args.save_interval):
            logger.log("saving model...")
            save_model(mp_trainer, opt, step + resume_step)

    if dist.get_rank() == 0:
        logger.log("saving model...")
        save_model(mp_trainer, opt, step + resume_step)
    dist.barrier()
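
set_annealed_lr is referenced above but not shown in this excerpt; a plausible sketch of a linear anneal consistent with how it is called (the real helper may differ):

def set_annealed_lr(opt, base_lr, frac_done):
    # Linearly decay the learning rate from base_lr towards 0 over training.
    lr = base_lr * (1 - frac_done)
    for param_group in opt.param_groups:
        param_group["lr"] = lr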
def main() -> None:
    global best_loss

    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    start_epoch = 0

    vcf_reader = VCFReader(args.train_data, args.classification_map,
                           args.chromosome, args.class_hierarchy)
    vcf_writer = vcf_reader.get_vcf_writer()
    train_dataset, validation_dataset = vcf_reader.get_datasets(
        args.validation_split)
    train_sampler = BatchByLabelRandomSampler(args.batch_size,
                                              train_dataset.labels)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler)

    if args.validation_split != 0:
        validation_sampler = BatchByLabelRandomSampler(
            args.batch_size, validation_dataset.labels)
        validation_loader = DataLoader(validation_dataset,
                                       batch_sampler=validation_sampler)

    kwargs = {
        'total_size': vcf_reader.positions.shape[0],
        'window_size': args.window_size,
        'num_layers': args.layers,
        'num_classes': len(vcf_reader.label_encoder.classes_),
        'num_super_classes': len(vcf_reader.super_label_encoder.classes_)
    }
    model = WindowedMLP(**kwargs)
    model.to(get_device(args))

    optimizer = AdamW(model.parameters(), lr=args.learning_rate)

    #######
    if args.resume_path is not None:
        if os.path.isfile(args.resume_path):
            print("=> loading checkpoint '{}'".format(args.resume_path))
            checkpoint = torch.load(args.resume_path)
            if kwargs != checkpoint['model_kwargs']:
                raise ValueError(
                    'The checkpoint\'s kwargs don\'t match the ones used to initialize the model'
                )
            if vcf_reader.snps.shape[0] != checkpoint['vcf_writer'].snps.shape[0]:
                raise ValueError(
                    'The data on which the checkpoint was trained had a different number of snp positions'
                )
            start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume_path, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    #############

    if args.validate:
        validate(validation_loader, model,
                 nn.functional.binary_cross_entropy_with_logits,
                 len(vcf_reader.label_encoder.classes_),
                 len(vcf_reader.super_label_encoder.classes_), vcf_reader.maf,
                 args)
        return

    for epoch in range(start_epoch, args.epochs + start_epoch):
        loss = train(train_loader, model,
                     nn.functional.binary_cross_entropy_with_logits, optimizer,
                     len(vcf_reader.label_encoder.classes_),
                     len(vcf_reader.super_label_encoder.classes_),
                     vcf_reader.maf, epoch, args)

        if epoch % args.save_freq == 0 or epoch == args.epochs + start_epoch - 1:
            if args.validation_split != 0:
                validation_loss = validate(
                    validation_loader, model,
                    nn.functional.binary_cross_entropy_with_logits,
                    len(vcf_reader.label_encoder.classes_),
                    len(vcf_reader.super_label_encoder.classes_),
                    vcf_reader.maf, args)
                is_best = validation_loss < best_loss
                best_loss = min(validation_loss, best_loss)
            else:
                is_best = loss < best_loss
                best_loss = min(loss, best_loss)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'model_kwargs': kwargs,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                    'vcf_writer': vcf_writer,
                    'label_encoder': vcf_reader.label_encoder,
                    'super_label_encoder': vcf_reader.super_label_encoder,
                    'maf': vcf_reader.maf
                }, is_best, args.chromosome, args.model_name, args.model_dir)
def train(args, train_dataset, model, tokenizer, writer):

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)

    train_total = len(
        train_dataloader
    ) // args.gradient_accumulation_steps * args.num_train_epochs
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=train_total)

    if os.path.isfile(os.path.join(
            args.pretrain_model_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.pretrain_model_path, "scheduler.pt")):
        optimizer.load_state_dict(
            torch.load(os.path.join(args.pretrain_model_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.pretrain_model_path, "scheduler.pt")))
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    print("***** Running training *****")

    global_step = 0
    steps_trained_in_current_epoch = 0

    if os.path.exists(args.pretrain_model_path
                      ) and "checkpoint" in args.pretrain_model_path:
        global_step = int(
            args.pretrain_model_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

    train_loss, logging_loss = 0.0, 0.0
    model.zero_grad()

    for _ in range(int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        for step, batch in enumerate(train_dataloader):
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "start_positions": batch[3],
                "end_positions": batch[4]
            }

            inputs["token_type_ids"] = (batch[2] if args.model_type
                                        in ["bert"] else None)
            outputs = model(**inputs)
            loss = outputs[0]

            writer.add_scalar("Train_loss", loss.item(), step)

            if args.n_gpu > 1:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            pbar(step, {'loss': loss.item()})
            train_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # step the LR scheduler after the optimizer update
                model.zero_grad()
                global_step += 1
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    if args.local_rank == -1:
                        evaluate(args, model, tokenizer, writer)
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir): os.makedirs(output_dir)
                    model_to_save = (model.module
                                     if hasattr(model, "module") else model)
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    tokenizer.save_vocabulary(output_dir)
                    print("Saving model checkpoint to %s", output_dir)
                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))

        print(" ")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
    return global_step, train_loss / global_step
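
The "no_decay" parameter grouping above also appears in the pretraining example below; as an illustrative (non-original) refactoring, it can be captured in a small helper:

def build_param_groups(model, weight_decay, no_decay=("bias", "LayerNorm.weight")):
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        # Biases and LayerNorm weights are excluded from weight decay.
        (no_decay_params if any(nd in name for nd in no_decay) else decay_params).append(param)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

With this helper, AdamW(build_param_groups(model, args.weight_decay), lr=args.learning_rate, eps=args.adam_epsilon) builds the same two groups as the explicit construction above.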
Exemple #15
def run_pretraining(args):
    if args.parallel and args.local_rank == -1:
        run_parallel_pretraining(args)
        return

    if args.local_rank == -1:
        if args.cpu:
            print("CPU!!!")
            device = torch.device("cpu")
        else:
            device = torch.device("cuda")
        num_workers = 1
        worker_index = 0
    else:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        device = torch.device("cuda", args.local_rank)
        num_workers = torch.distributed.get_world_size()
        worker_index = torch.distributed.get_rank()

    if args.local_rank not in (-1, 0):
        logging.getLogger().setLevel(logging.WARN)

    logger.info(
        "Starting pretraining with the following arguments: %s", json.dumps(vars(args), indent=2, sort_keys=True)
    )

    # if args.multilingual:
    #     dataset_dir_list = args.dataset_dir.split(",")
    #     dataset_list = [MedMentionsPretrainingDataset(d) for d in dataset_dir_list]
    # else:
    dataset_list = [MedMentionsPretrainingDataset(args.dataset_dir)]

    bert_config = AutoConfig.from_pretrained(args.bert_model_name)

    dataset_size = sum([len(d) for d in dataset_list])
    num_train_steps_per_epoch = math.ceil(dataset_size / args.batch_size)
    num_train_steps = math.ceil(dataset_size / args.batch_size * args.num_epochs)
    print("The Number of Training Steps is: ", num_train_steps)
    train_batch_size = int(args.batch_size / args.gradient_accumulation_steps / num_workers)

    entity_vocab = dataset_list[0].entity_vocab
    config = LukeConfig(
        entity_vocab_size=entity_vocab.size,
        bert_model_name=args.bert_model_name,
        entity_emb_size=args.entity_emb_size,
        **bert_config.to_dict(),
    )
    model = LukePretrainingModel(config)

    global_step = args.global_step

    batch_generator_args = dict(
        batch_size=train_batch_size,
        masked_lm_prob=args.masked_lm_prob,
        masked_entity_prob=args.masked_entity_prob,
        whole_word_masking=args.whole_word_masking,
        unmasked_word_prob=args.unmasked_word_prob,
        random_word_prob=args.random_word_prob,
        unmasked_entity_prob=args.unmasked_entity_prob,
        random_entity_prob=args.random_entity_prob,
        mask_words_in_entity_span=args.mask_words_in_entity_span,
        num_workers=num_workers,
        worker_index=worker_index,
        skip=global_step * args.batch_size,
    )

    # if args.multilingual:
    #     data_size_list = [len(d) for d in dataset_list]
    #     batch_generator = MultilingualBatchGenerator(
    #         dataset_dir_list, data_size_list, args.sampling_smoothing, **batch_generator_args,
    #     )

    # else:
    batch_generator = LukePretrainingBatchGenerator(args.dataset_dir, **batch_generator_args)

    logger.info("Model configuration: %s", config)

    if args.fix_bert_weights:
        for param in model.parameters():
            param.requires_grad = False
        for param in model.entity_embeddings.parameters():
            param.requires_grad = True
        for param in model.entity_predictions.parameters():
            param.requires_grad = True

    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

    if args.original_adam:
        optimizer = AdamW(
            optimizer_parameters,
            lr=args.learning_rate,
            betas=(args.adam_b1, args.adam_b2),
            eps=args.adam_eps,
        )
    else:
        optimizer = LukeAdamW(
            optimizer_parameters,
            lr=args.learning_rate,
            betas=(args.adam_b1, args.adam_b2),
            eps=args.adam_eps,
            grad_avg_device=torch.device("cpu") if args.grad_avg_on_cpu else device,
        )

    if args.fp16:
        from apex import amp

        if args.fp16_opt_level == "O2":
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level=args.fp16_opt_level,
                master_weights=args.fp16_master_weights,
                min_loss_scale=args.fp16_min_loss_scale,
                max_loss_scale=args.fp16_max_loss_scale,
            )
        else:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level=args.fp16_opt_level,
                min_loss_scale=args.fp16_min_loss_scale,
                max_loss_scale=args.fp16_max_loss_scale,
            )

    if args.model_file is None:
        bert_model = AutoModelForPreTraining.from_pretrained(args.bert_model_name)
        bert_state_dict = bert_model.state_dict()
        model.load_bert_weights(bert_state_dict)

    else:
        model_state_dict = torch.load(args.model_file, map_location="cpu")
        model.load_state_dict(model_state_dict, strict=False)

    if args.optimizer_file is not None:
        optimizer.load_state_dict(torch.load(args.optimizer_file, map_location="cpu"))

    if args.amp_file is not None:
        amp.load_state_dict(torch.load(args.amp_file, map_location="cpu"))

    if args.lr_schedule == "warmup_constant":
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
    elif args.lr_schedule == "warmup_linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=num_train_steps
        )
        print(f"Scheduler data: Warmup steps: {args.warmup_steps}; total training steps: {num_train_steps}")
    else:
        raise RuntimeError(f"Invalid scheduler: {args.lr_schedule}")

    if args.scheduler_file is not None:
        scheduler.load_state_dict(torch.load(args.scheduler_file, map_location="cpu"))

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

    model.train()

    if args.local_rank == -1 or worker_index == 0:
        entity_vocab.save(os.path.join(args.output_dir, ENTITY_VOCAB_FILE))
        metadata = dict(
            model_config=config.to_dict(),
            max_seq_length=dataset_list[0].max_seq_length,
            max_entity_length=dataset_list[0].max_entity_length,
            max_mention_length=dataset_list[0].max_mention_length,
            arguments=vars(args),
        )
        with open(os.path.join(args.output_dir, "metadata.json"), "w") as metadata_file:
            json.dump(metadata, metadata_file, indent=2, sort_keys=True)

    def save_model(model, suffix):
        if args.local_rank != -1:
            model = model.module

        model_file = f"model_{suffix}.bin"
        torch.save(model.state_dict(), os.path.join(args.output_dir, model_file))
        optimizer_file = f"optimizer_{suffix}.bin"
        torch.save(optimizer.state_dict(), os.path.join(args.output_dir, optimizer_file))
        scheduler_file = f"scheduler_{suffix}.bin"
        torch.save(scheduler.state_dict(), os.path.join(args.output_dir, scheduler_file))
        metadata = dict(
            global_step=global_step, model_file=model_file, optimizer_file=optimizer_file, scheduler_file=scheduler_file
        )
        if args.fp16:
            amp_file = f"amp_{suffix}.bin"
            torch.save(amp.state_dict(), os.path.join(args.output_dir, amp_file))
            metadata["amp_file"] = amp_file
        with open(os.path.join(args.output_dir, f"metadata_{suffix}.json"), "w") as f:
            json.dump(metadata, f, indent=2, sort_keys=True)

    if args.local_rank == -1 or worker_index == 0:
        summary_writer = SummaryWriter(args.log_dir)
        pbar = tqdm(total=num_train_steps, initial=global_step)

    tr_loss = 0
    accumulation_count = 0
    results = []
    prev_error = False
    prev_step_time = time.time()
    prev_save_time = time.time()

    for batch in batch_generator.generate_batches():
        try:
            batch = {k: torch.from_numpy(v).to(device) for k, v in batch.items()}
            result = model(**batch)
            loss = result["loss"]
            result = {k: v.to("cpu").detach().numpy() for k, v in result.items()}

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            def maybe_no_sync():
                if (
                    hasattr(model, "no_sync")
                    and num_workers > 1
                    and accumulation_count + 1 != args.gradient_accumulation_steps
                ):
                    return model.no_sync()
                else:
                    return contextlib.ExitStack()

            with maybe_no_sync():
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

        except RuntimeError:
            if prev_error:
                logger.exception("Consecutive errors have been observed. Exiting...")
                raise
            logger.exception("An unexpected error has occurred. Skipping a batch...")
            prev_error = True
            loss = None
            torch.cuda.empty_cache()
            continue

        accumulation_count += 1
        prev_error = False
        tr_loss += loss.item()
        loss = None
        results.append(result)

        if accumulation_count == args.gradient_accumulation_steps:
            if args.max_grad_norm != 0.0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            accumulation_count = 0

            summary = {}
            # scheduler.get_lr() is deprecated in recent PyTorch releases in favour of
            # get_last_lr(); it is kept here for backwards compatibility with older versions.
            summary["learning_rate"] = max(scheduler.get_lr())
            summary["loss"] = tr_loss
            tr_loss = 0

            current_time = time.time()
            summary["batch_run_time"] = current_time - prev_step_time
            prev_step_time = current_time

            for name in ("masked_lm", "masked_entity"):
                try:
                    summary[name + "_loss"] = np.concatenate([r[name + "_loss"].flatten() for r in results]).mean()
                    correct = np.concatenate([r[name + "_correct"].flatten() for r in results]).sum()
                    total = np.concatenate([r[name + "_total"].flatten() for r in results]).sum()
                    if total > 0:
                        summary[name + "_acc"] = correct / total
                except KeyError:
                    continue

            results = []

            if args.local_rank == -1 or worker_index == 0:
                for (name, value) in summary.items():
                    summary_writer.add_scalar(name, value, global_step)
                desc = (
                    f"epoch: {int(global_step / num_train_steps_per_epoch)} "
                    f'loss: {summary["loss"]:.4f} '
                    f'time: {datetime.datetime.now().strftime("%H:%M:%S")}'
                )
                pbar.set_description(desc)
                pbar.update()

            global_step += 1

            if args.local_rank == -1 or worker_index == 0:
                if global_step == num_train_steps:
                    # save the final model
                    save_model(model, f"epoch{args.num_epochs}")
                    time.sleep(60)
                elif global_step % num_train_steps_per_epoch == 0:
                    # save the model at each epoch
                    epoch = int(global_step / num_train_steps_per_epoch)
                    save_model(model, f"epoch{epoch}")
                if args.save_interval_sec and time.time() - prev_save_time > args.save_interval_sec:
                    save_model(model, f"step{global_step:07}")
                    prev_save_time = time.time()
                if args.save_interval_steps and global_step % args.save_interval_steps == 0:
                    save_model(model, f"step{global_step}")

            if global_step == num_train_steps:
                break

    if args.local_rank == -1 or worker_index == 0:
        summary_writer.close()
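
Stripped of logging and error handling, the accumulation logic above reduces to the skeleton below (all names are illustrative; compute_loss stands in for the model's forward pass and loss reduction):

import contextlib

def accumulate_and_step(model, optimizer, scheduler, batches, compute_loss,
                        accumulation_steps=4, max_grad_norm=1.0):
    model.zero_grad()
    for i, batch in enumerate(batches):
        loss = compute_loss(model, batch) / accumulation_steps
        # Skip DDP gradient all-reduce on intermediate micro-batches.
        sync_ctx = (model.no_sync()
                    if hasattr(model, "no_sync") and (i + 1) % accumulation_steps != 0
                    else contextlib.ExitStack())
        with sync_ctx:
            loss.backward()
        if (i + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            model.zero_grad()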
Exemple #16
def run_training(args, ls):
    ls.print('Training started: ' + datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    # Misc setup
    os.makedirs(args.model_dir, exist_ok=True)
    assert len(args.cnn_filters)%2 == 0
    args.cnn_filters = list(zip(args.cnn_filters[:-1:2], args.cnn_filters[1::2]))
    # Load the vocabs
    vocabs = get_vocabs(os.path.join(args.model_dir, args.vocab_dir))
    bert_tokenizer = None
    if args.with_bert:
        bert_tokenizer = BertEncoderTokenizer.from_pretrained(args.bert_path, do_lower_case=False)
        vocabs['bert_tokenizer'] = bert_tokenizer
    for name in vocabs:
        if name == 'bert_tokenizer':
            continue
        ls.print('Vocab %-20s  size %5d  coverage %.3f' % (name, vocabs[name].size, vocabs[name].coverage))
    # Setup BERT encoder
    bert_encoder = None
    if args.with_bert:
        bert_encoder = BertEncoder.from_pretrained(args.bert_path)
        for p in bert_encoder.parameters():
            p.requires_grad = False
    # Device and random setup
    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)
    device = torch.device(args.device)
    # Create the model
    ls.print('Setting up the model')
    model = Parser(vocabs,
            args.word_char_dim, args.word_dim, args.pos_dim, args.ner_dim,
            args.concept_char_dim, args.concept_dim,
            args.cnn_filters, args.char2word_dim, args.char2concept_dim,
            args.embed_dim, args.ff_embed_dim, args.num_heads, args.dropout,
            args.snt_layers, args.graph_layers, args.inference_layers, args.rel_dim,
            device, args.pretrained_file, bert_encoder,)
    model = model.to(device)
    # Optimizer and weight decay params
    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{'params':weight_decay_params, 'weight_decay':1e-4},
                        {'params':no_weight_decay_params, 'weight_decay':0.}]
    optimizer = AdamW(grouped_params, 1., betas=(0.9, 0.999), eps=1e-6)
    # Re-load an existing model if requested
    used_batches = 0
    batches_acm = 0
    if args.resume_ckpt:
        ls.print('Resuming from checkpoint', args.resume_ckpt)
        ckpt = torch.load(args.resume_ckpt)
        model.load_state_dict(ckpt['model'])
        if ckpt.get('optimizer', {}):
            optimizer.load_state_dict(ckpt['optimizer'])
        else:
            ls.print('No optimizer state saved in checkpoint, using default initial optimizer')
        batches_acm = ckpt['batches_acm']
        start_epoch = ckpt['epoch'] + 1
        del ckpt
    else:
        start_epoch = 1     # don't start at 0
    # Load data
    ls.print('Loading training data')
    train_data = DataLoader(vocabs, args.train_data, args.train_batch_size, for_train=True)
    train_data.set_unk_rate(args.unk_rate)
    # Train
    ls.print('Training')
    epoch, loss_avg, concept_loss_avg, arc_loss_avg, rel_loss_avg = 0, 0, 0, 0, 0
    for epoch in range(start_epoch, args.epochs+1):
        st = time.time()
        for batch in train_data:
            model.train()
            batch = move_to_device(batch, model.device)
            concept_loss, arc_loss, rel_loss, graph_arc_loss = model(batch)
            loss = (concept_loss + arc_loss + rel_loss) / args.batches_per_update
            loss_value = loss.item()
            concept_loss_value = concept_loss.item()
            arc_loss_value = arc_loss.item()
            rel_loss_value = rel_loss.item()
            loss_avg = loss_avg * args.batches_per_update * 0.8 + 0.2 * loss_value
            concept_loss_avg = concept_loss_avg * 0.8 + 0.2 * concept_loss_value
            arc_loss_avg = arc_loss_avg * 0.8 + 0.2 * arc_loss_value
            rel_loss_avg = rel_loss_avg * 0.8 + 0.2 * rel_loss_value
            loss.backward()
            used_batches += 1
            if used_batches % args.batches_per_update != args.batches_per_update - 1:
                continue
            batches_acm += 1
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            lr = update_lr(optimizer, args.lr_scale, args.embed_dim, batches_acm, args.warmup_steps)
            optimizer.step()
            optimizer.zero_grad()
        # Summary at the end of the epoch
        dur = time.time() - st
        ls.print('Epoch %4d, Batch %5d, LR %.6f, conc_loss %.3f, arc_loss %.3f, rel_loss %.3f, duration %.1f seconds' %
                    (epoch, batches_acm, lr, concept_loss_avg, arc_loss_avg, rel_loss_avg, dur))
        # Evaluate and save the data every so often
        if (epoch>args.skip_evals or args.resume_ckpt is not None) and epoch % args.eval_every == 0:
            model.eval()
            ls.print('Evaluating and saving the model')
            fname = '%s/epoch%d.pt'%(args.model_dir, epoch)
            optim = optimizer.state_dict() if args.save_optimizer else {}
            torch.save({'args':vars(args), 'model':model.state_dict(), 'batches_acm': batches_acm,
                        'optimizer': optim, 'epoch':epoch}, fname)
            try:
                out_fn = 'epoch%d.pt.dev_generated' % (epoch)
                inference = Inference.build_from_model(model, vocabs)
                f_score, ctr = inference.reparse_annotated_file('.', args.dev_data, args.model_dir, out_fn,
                        print_summary=False)
                ls.print('Smatch F: %.3f.  Wrote %d AMR graphs to %s' % \
                        (f_score, ctr, os.path.join(args.model_dir, out_fn)))
            except Exception:
                ls.print('Exception during generation')
                traceback.print_exc()
            model.train()
    # End time-stamp
    ls.print('Training finished: ' + datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
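
update_lr is not shown in this excerpt; given the embed_dim and warmup_steps arguments, it is presumably the inverse-square-root ("Noam") schedule. A sketch under that assumption:

def update_lr(optimizer, lr_scale, embed_dim, steps, warmup_steps):
    # steps is assumed to be >= 1 (batches_acm is incremented before the call above).
    lr = lr_scale * embed_dim ** -0.5 * min(steps ** -0.5, steps * warmup_steps ** -1.5)
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr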
Exemple #17
class Trainer():
    def __init__(self, alphabets_, list_ngram):

        self.vocab = Vocab(alphabets_)
        self.synthesizer = SynthesizeData(vocab_path="")
        self.list_ngrams_train, self.list_ngrams_valid = self.train_test_split(
            list_ngram, test_size=0.1)
        print("Loaded data!!!")
        print("Total training samples: ", len(self.list_ngrams_train))
        print("Total valid samples: ", len(self.list_ngrams_valid))

        INPUT_DIM = self.vocab.__len__()
        OUTPUT_DIM = self.vocab.__len__()

        self.device = DEVICE
        self.num_iters = NUM_ITERS
        self.beamsearch = BEAM_SEARCH

        self.batch_size = BATCH_SIZE
        self.print_every = PRINT_PER_ITER
        self.valid_every = VALID_PER_ITER

        self.checkpoint = CHECKPOINT
        self.export_weights = EXPORT
        self.metrics = MAX_SAMPLE_VALID
        logger = LOG

        if logger:
            self.logger = Logger(logger)

        self.iter = 0

        self.model = Seq2Seq(input_dim=INPUT_DIM,
                             output_dim=OUTPUT_DIM,
                             encoder_embbeded=ENC_EMB_DIM,
                             decoder_embedded=DEC_EMB_DIM,
                             encoder_hidden=ENC_HID_DIM,
                             decoder_hidden=DEC_HID_DIM,
                             encoder_dropout=ENC_DROPOUT,
                             decoder_dropout=DEC_DROPOUT)

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer,
                                    total_steps=self.num_iters,
                                    pct_start=PCT_START,
                                    max_lr=MAX_LR)

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        self.train_gen = self.data_gen(self.list_ngrams_train,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=True)
        self.valid_gen = self.data_gen(self.list_ngrams_valid,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=False)

        self.train_losses = []

        # to device
        self.model.to(self.device)
        self.criterion.to(self.device)

    def train_test_split(self, list_phrases, test_size=0.1):
        train_idx = int(len(list_phrases) * (1 - test_size))
        list_phrases_train = list_phrases[:train_idx]
        list_phrases_valid = list_phrases[train_idx:]
        return list_phrases_train, list_phrases_valid

    def data_gen(self, list_ngrams_np, synthesizer, vocab, is_train=True):
        dataset = AutoCorrectDataset(list_ngrams_np,
                                     transform_noise=synthesizer,
                                     vocab=vocab,
                                     maxlen=MAXLEN)

        shuffle = True if is_train else False
        gen = DataLoader(dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=shuffle,
                         drop_last=False)

        return gen

    def step(self, batch):
        self.model.train()

        batch = self.batch_to_device(batch)
        src, tgt = batch['src'], batch['tgt']
        src, tgt = src.transpose(1, 0), tgt.transpose(
            1, 0)  # batch x src_len -> src_len x batch

        outputs = self.model(
            src, tgt)  # src : src_len x B, outpus : B x tgt_len x vocab

        #        loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        outputs = outputs.view(-1, outputs.size(2))  # (B, tgt_len, vocab) -> (B * tgt_len, vocab)

        tgt_output = tgt.transpose(0, 1).reshape(-1)  # tgt: tgt_len x B -> B x tgt_len, then flatten

        loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.optimizer.step()
        self.scheduler.step()

        loss_item = loss.item()

        return loss_item

    def train(self):
        print("Begin training from iter: ", self.iter)
        total_loss = 0

        total_loader_time = 0
        total_gpu_time = 0
        best_acc = -1

        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()

            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start

            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)

                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)

            if self.iter % self.valid_every == 0:
                val_loss, preds, actuals, inp_sents = self.validate()
                acc_full_seq, acc_per_char, cer = self.precision(self.metrics)

                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f} - CER: {:.4f} '.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, cer)
                print(info)
                print("--- Sentence predict ---")
                for pred, inp, label in zip(preds, inp_sents, actuals):
                    infor_predict = 'Pred: {} - Inp: {} - Label: {}'.format(
                        pred, inp, label)
                    print(infor_predict)
                    self.logger.log(infor_predict)
                self.logger.log(info)

                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq
                self.save_checkpoint(self.checkpoint)

    def validate(self):
        self.model.eval()

        total_loss = []
        max_step = self.metrics / self.batch_size
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                src, tgt = batch['src'], batch['tgt']
                src, tgt = src.transpose(1, 0), tgt.transpose(1, 0)

                outputs = self.model(src, tgt, 0)  # turn off teaching force

                outputs = outputs.flatten(0, 1)
                tgt_output = tgt.flatten()
                loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())

                preds, actuals, inp_sents, probs = self.predict(5)

                del outputs
                del loss
                if step > max_step:
                    break

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss, preds[:3], actuals[:3], inp_sents[:3]

    def predict(self, sample=None):
        pred_sents = []
        actual_sents = []
        inp_sents = []

        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)

            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['src'], self.model)
                prob = None
            else:
                translated_sentence, prob = translate(batch['src'], self.model)

            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt'].tolist())
            inp_sent = self.vocab.batch_decode(batch['src'].tolist())

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            inp_sents.extend(inp_sent)

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, inp_sents, prob

    def precision(self, sample=None):

        pred_sents, actual_sents, _, _ = self.predict(sample=sample)

        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')
        cer = compute_accuracy(actual_sents, pred_sents, mode='CER')

        return acc_full_seq, acc_per_char, cer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):

        pred_sents, actual_sents, img_files, probs = self.predict(sample)

        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}

    def visualize_dataset(self, sample=16, fontname='serif'):
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())

                n += 1
                if n >= sample:
                    return

    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']

        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'scheduler': self.scheduler.state_dict()
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} mismatching shape, required {} but found {}'.format(
                    name, param.shape, state_dict[name].shape))
                del state_dict[name]

        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):

        src = batch['src'].to(self.device, non_blocking=True)
        tgt = batch['tgt'].to(self.device, non_blocking=True)

        batch = {'src': src, 'tgt': tgt}

        return batch
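The partial-loading pattern above (load_weights) drops tensors whose shapes no longer match the current model before calling load_state_dict(strict=False). A minimal standalone sketch of that pattern follows; the function name load_compatible_weights and the weights_path argument are illustrative, not part of the example.

import torch


def load_compatible_weights(model: torch.nn.Module, weights_path: str) -> None:
    # Load weights on CPU, report missing tensors, drop any tensor whose
    # shape conflicts with the current model, then load non-strictly.
    state_dict = torch.load(weights_path, map_location='cpu')
    for name, param in model.named_parameters():
        if name not in state_dict:
            print('{} not found'.format(name))
        elif state_dict[name].shape != param.shape:
            del state_dict[name]
    model.load_state_dict(state_dict, strict=False)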
Exemple #18
0
def train_loop(new_data, old_data, stats=None):
    ## prep dataloaders ##
    X, y = new_data['X'], new_data['y']
    dataloader_new = DataLoader(list(zip(X, y)), batch_size=1, shuffle=True)
    dataloader_old = DataLoader(old_data, batch_size=1, shuffle=True)
    del X, y

    ## optimizer and scheduler ##
    # calculate total steps
    opts.gradient_accumulation_steps, opts.num_train_epochs = 64, 1
    t_total = len(dataloader_old
                  ) // opts.gradient_accumulation_steps * opts.num_train_epochs

    ## set up optimizers and schedulers ##
    with torch.no_grad():
        fast_group = flatten([[p[act_tok], p[start_tok], p[p1_tok], p[p2_tok]]
                              for n, p in model.named_parameters()
                              if n == 'transformer.wte.weight'
                              ])  #['transformer.wte.weight']
        freeze_group = [
            p[:start_tok] for n, p in model.named_parameters()
            if n == 'transformer.wte.weight'
        ]  #['transformer.wte.weight']
        slow_group = [
            p for n, p in model.named_parameters()
            if n == 'transformer.wpe.weight'
        ]
        normal_group = [
            p for n, p in model.named_parameters()
            if n not in ('transformer.wte.weight', 'transformer.wpe.weight')
        ]
    # different learn rates for different param groups
    optimizer_grouped_parameters = [{
        "params": fast_group,
        'lr': 5e-4
    }, {
        "params": freeze_group,
        'lr': 1e-8
    }, {
        "params": slow_group,
        'lr': 1e-6
    }, {
        "params": normal_group,
        'lr': opts.lr
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=opts.lr, eps=opts.eps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=opts.warmup_steps,
        num_training_steps=t_total)
    # loading optimizer settings
    if (opts.model_name_or_path and os.path.isfile(
            os.path.join(opts.model_name_or_path, "train_optimizer.pt"))
            and os.path.isfile(
                os.path.join(opts.model_name_or_path, "train_scheduler.pt"))):
        # load optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(
                os.path.join(opts.model_name_or_path, "train_optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(
                os.path.join(opts.model_name_or_path, "train_scheduler.pt")))
    # track stats
    if stats is not None:
        global_step = max(stats.keys())
        epochs_trained = global_step // (len(dataloader_old) //
                                         opts.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(dataloader_old) // opts.gradient_accumulation_steps)
        print("Resuming Training ... ")
    else:
        stats = {}
        global_step, epochs_trained, steps_trained_in_current_epoch = 0, 0, 0
    tr_loss, logging_loss = 0.0, 0.0
    tr_loss_old, logging_loss_old = 0.0, 0.0
    model.zero_grad()
    print("Re-sizing model ... ")
    model.resize_token_embeddings(len(tokenizer))
    # training mode
    model.train()
    data_iter_new = iter(dataloader_new)
    data_iter_old = iter(dataloader_old)
    for epoch in range(epochs_trained, opts.num_train_epochs):
        for step in range(len(dataloader_old)):
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                batch = next(data_iter_old)
                continue
            ### new data step ###
            try:
                batch = next(data_iter_new)
            except StopIteration:
                # new-data iterator exhausted: rebuild its dataloader and restart it
                X, y = new_data['X'], new_data['y']
                dataloader_new = DataLoader(list(zip(X, y)),
                                            batch_size=1,
                                            shuffle=True)
                del X, y
                data_iter_new = iter(dataloader_new)
                batch = next(data_iter_new)
            new_loss = fit_on_batch(batch)
            del batch
            tr_loss += new_loss.item()

            ### old data step ###
            try:
                batch = next(data_iter_old)
            except StopIteration:
                # old-data iterator exhausted: restart it
                data_iter_old = iter(dataloader_old)
                batch = next(data_iter_old)
            old_loss = fit_on_batch(batch)
            del batch
            tr_loss_old += old_loss.item()

            # gradient accumulation
            if (step + 1) % opts.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               opts.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # reporting
                if global_step % opts.logging_steps == 0:
                    stats[global_step] = {
                        'persona_loss':
                        (tr_loss - logging_loss) / opts.logging_steps,
                        'ctrl_loss':
                        (tr_loss_old - logging_loss_old) / opts.logging_steps,
                        'train_lr': scheduler.get_last_lr()[-1]
                    }
                    logging_loss = tr_loss
                    logging_loss_old = tr_loss_old

                    print(
                        'Epoch: %d | Iter: [%d/%d] | new_loss: %.3f | old_loss: %.3f | lr: %s '
                        % (epoch, step, len(dataloader_old),
                           stats[global_step]['persona_loss'],
                           stats[global_step]['ctrl_loss'],
                           str(stats[global_step]['train_lr'])))

                if global_step % opts.save_steps == 0:
                    print("Saving stuff ... ")
                    checkpoint(model,
                               tokenizer,
                               optimizer,
                               scheduler,
                               stats,
                               title="train_")
                    plot_losses(stats, title='persona_loss')
                    plot_losses(stats, title='ctrl_loss')
                    plot_losses(stats, title='train_lr')
                    print("Done.")

    return stats
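train_loop above relies on AdamW parameter groups to give the token embeddings, the position embeddings, and the remaining weights different learning rates. A minimal runnable sketch of that per-group learning-rate pattern, using a toy linear model instead of the example's GPT-2 weights:

import torch
from torch.optim import AdamW

toy_model = torch.nn.Linear(8, 2)
param_groups = [
    {'params': [toy_model.weight], 'lr': 5e-4},  # "fast" group
    {'params': [toy_model.bias], 'lr': 1e-6},    # "slow" group
]
# the lr passed here only applies to groups that do not set their own
optimizer = AdamW(param_groups, lr=1e-4)
for group in optimizer.param_groups:
    print(group['lr'])  # 0.0005, then 1e-06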
Exemple #19
0
def optim_config(args: dict, model):
    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    bert_param_optimizer = list(model.bert.named_parameters())
    crf_param_optimizer = list(model.crf.named_parameters())
    linear_param_optimizer = list(model.classifier.named_parameters())
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in bert_param_optimizer
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args['weight_decay'],
        'lr':
        args['learning_rate']
    }, {
        'params': [
            p for n, p in bert_param_optimizer
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0,
        'lr':
        args['learning_rate']
    }, {
        'params': [
            p for n, p in crf_param_optimizer
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args['weight_decay'],
        'lr':
        args['crf_learning_rate']
    }, {
        'params':
        [p for n, p in crf_param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0,
        'lr':
        args['crf_learning_rate']
    }, {
        'params': [
            p for n, p in linear_param_optimizer
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args['weight_decay'],
        'lr':
        args['crf_learning_rate']
    }, {
        'params': [
            p for n, p in linear_param_optimizer
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0,
        'lr':
        args['crf_learning_rate']
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args['learning_rate'],
                      eps=args['adam_epsilon'])
    if os.path.isfile(os.path.join(args['model_name_or_path'],
                                   "optimizer.pt")):
        # Load in optimizer states
        optimizer.load_state_dict(
            torch.load(os.path.join(args['model_name_or_path'],
                                    "optimizer.pt")))
    return optimizer
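A hypothetical call site for optim_config; the values below are placeholders, and model is assumed to expose the .bert, .crf, and .classifier sub-modules referenced above:

args = {
    'weight_decay': 0.01,
    'learning_rate': 3e-5,
    'crf_learning_rate': 1e-3,
    'adam_epsilon': 1e-8,
    'model_name_or_path': './checkpoint',  # optimizer.pt is resumed from here if present
}
optimizer = optim_config(args, model)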
Exemple #20
0
class Seq2seqKpGen(object):
    """High level model that handles intializing the underlying network
    architecture, saving, updating examples, and predicting examples.
    """

    # --------------------------------------------------------------------------
    # Initialization
    # --------------------------------------------------------------------------

    def __init__(self, args, word_dict, state_dict=None):
        # Book-keeping.
        self.args = args
        self.word_dict = word_dict
        self.args.vocab_size = len(word_dict)
        self.updates = 0

        self.network = Sequence2Sequence(self.args, self.word_dict)
        if state_dict:
            self.network.load_state_dict(state_dict)

    def activate_fp16(self):
        if not hasattr(self, 'optimizer'):
            self.network.half()  # for testing only
            return
        try:
            global amp
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        # https://github.com/NVIDIA/apex/issues/227
        assert self.optimizer is not None
        self.network, self.optimizer = amp.initialize(self.network,
                                                      self.optimizer,
                                                      opt_level=self.args.fp16_opt_level)

    def init_optimizer(self, optim_state=None, sched_state=None):
        def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
            def lr_lambda(current_step: int):
                if current_step < num_warmup_steps:
                    return float(current_step) / float(max(1.0, num_warmup_steps))
                return 1.0

            return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.network.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {"params": [p for n, p in self.network.named_parameters()
                        if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0},
        ]

        self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate)
        self.scheduler = get_constant_schedule_with_warmup(self.optimizer, self.args.warmup_steps)

        if optim_state:
            self.optimizer.load_state_dict(optim_state)
            if self.args.device.type == 'cuda':
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.to(self.args.device)
        if sched_state:
            self.scheduler.load_state_dict(sched_state)

    # --------------------------------------------------------------------------
    # Learning
    # --------------------------------------------------------------------------

    def update(self, ex):
        """Forward a batch of examples; step the optimizer to update weights."""
        if getattr(self, 'optimizer', None) is None:
            raise RuntimeError('No optimizer set.')

        # Train mode
        self.network.train()

        source_map, alignment = None, None
        if self.args.copy_attn:
            source_map = make_src_map(ex['src_map']).to(self.args.device)
            alignment = align(ex['alignment']).to(self.args.device)

        source_rep = ex['source_rep'].to(self.args.device)
        source_len = ex['source_len'].to(self.args.device)
        target_rep = ex['target_rep'].to(self.args.device)
        target_len = ex['target_len'].to(self.args.device)

        # Run forward
        ml_loss, loss_per_token = self.network(source=source_rep,
                                               source_len=source_len,
                                               target=target_rep,
                                               target_len=target_len,
                                               src_map=source_map,
                                               alignment=alignment)

        loss = ml_loss.mean() if self.args.n_gpu > 1 else ml_loss
        if self.args.fp16:
            global amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            clip_grad_norm_(amp.master_params(self.optimizer), self.args.grad_clipping)
        else:
            loss.backward()
            clip_grad_norm_(self.network.parameters(), self.args.grad_clipping)

        self.updates += 1
        self.optimizer.step()
        self.scheduler.step()  # Update learning rate schedule
        self.optimizer.zero_grad()

        loss_per_token = loss_per_token.mean() if self.args.n_gpu > 1 else loss_per_token
        loss_per_token = loss_per_token.item()
        loss_per_token = 10 if loss_per_token > 10 else loss_per_token
        perplexity = math.exp(loss_per_token)

        return {
            'ml_loss': loss.item(),
            'perplexity': perplexity
        }

    # --------------------------------------------------------------------------
    # Prediction
    # --------------------------------------------------------------------------

    def predict(self, ex, replace_unk=False):
        """Forward a batch of examples only to get predictions.
        Args:
            ex: the batch examples
            replace_unk: replace `unk` tokens while generating predictions
            src_raw: raw source (passage); required to replace `unk` term
        Output:
            predictions: #batch predicted sequences
        """

        def convert_text_to_string(text):
            """ Converts a sequence of tokens (string) in a single string. """
            out_string = text.replace(" ##", "").strip()
            return out_string

        self.network.eval()

        source_map, alignment = None, None
        blank, fill = None, None
        if self.args.copy_attn:
            source_map = make_src_map(ex['src_map']).to(self.args.device)
            alignment = align(ex['alignment']).to(self.args.device)
            blank, fill = collapse_copy_scores(self.word_dict, ex['src_vocab'])

        source_rep = ex['source_rep'].to(self.args.device)
        source_len = ex['source_len'].to(self.args.device)

        decoder_out = self.network(source=source_rep,
                                   source_len=source_len,
                                   target=None,
                                   target_len=None,
                                   src_map=source_map,
                                   alignment=alignment,
                                   max_len=self.args.max_tgt_len,
                                   tgt_dict=self.word_dict,
                                   blank=blank, fill=fill,
                                   source_vocab=ex['src_vocab'])

        dec_probs = torch.exp(decoder_out['dec_log_probs'])
        predictions, scores = tens2sen_score(decoder_out['predictions'], dec_probs,
                                             self.word_dict, ex['src_vocab'])
        if replace_unk:
            for i in range(len(predictions)):
                enc_dec_attn = decoder_out['attentions'][i]
                if self.args.model_type == 'transformer':
                    # tgt_len x num_heads x src_len
                    assert enc_dec_attn.dim() == 3
                    enc_dec_attn = enc_dec_attn.mean(1)
                predictions[i] = replace_unknown(predictions[i], enc_dec_attn,
                                                 src_raw=ex['source'][i].tokens)

        for bidx in range(ex['batch_size']):
            for i in range(len(predictions[bidx])):
                if predictions[bidx][i] == constants.KP_SEP:
                    scores[bidx][i] = constants.KP_SEP
                elif predictions[bidx][i] == constants.PRESENT_EOS:
                    scores[bidx][i] = constants.PRESENT_EOS
                else:
                    assert isinstance(scores[bidx][i], float)
                    scores[bidx][i] = str(scores[bidx][i])

        predictions = [' '.join(item) for item in predictions]
        scores = [' '.join(item) for item in scores]

        present_kps = []
        absent_kps = []
        present_kp_scores = []
        absent_kp_scores = []
        for bidx in range(ex['batch_size']):
            keyphrases = predictions[bidx].split(constants.PRESENT_EOS)
            kp_scores = scores[bidx].split(constants.PRESENT_EOS)
            pkps = (' %s ' % constants.KP_SEP).join(keyphrases[:-1])
            pkp_scores = (' %s ' % constants.KP_SEP).join(kp_scores[:-1])
            akps = keyphrases[-1]
            akp_scores = kp_scores[-1]

            pre_kps = []
            pre_kp_scores = []
            for pkp, pkp_s in zip(pkps.split(constants.KP_SEP),
                                  pkp_scores.split(constants.KP_SEP)):
                pkp = pkp.strip()
                if pkp:
                    pre_kps.append(convert_text_to_string(pkp))
                    t_scores = [float(i) for i in pkp_s.strip().split()]
                    _score = np.prod(t_scores) / len(t_scores)
                    pre_kp_scores.append(_score)

            present_kps.append(pre_kps)
            present_kp_scores.append(pre_kp_scores)

            abs_kps = []
            abs_kp_scores = []
            for akp, akp_s in zip(akps.split(constants.KP_SEP),
                                  akp_scores.split(constants.KP_SEP)):
                akp = akp.strip()
                if akp:
                    abs_kps.append(convert_text_to_string(akp))
                    t_scores = [float(i) for i in akp_s.strip().split()]
                    _score = np.prod(t_scores) / len(t_scores)
                    abs_kp_scores.append(_score)

            absent_kps.append(abs_kps)
            absent_kp_scores.append(abs_kp_scores)

        return {
            'present_kps': present_kps,
            'absent_kps': absent_kps,
            'present_kp_scores': present_kp_scores,
            'absent_kp_scores': absent_kp_scores
        }

    # --------------------------------------------------------------------------
    # Saving and loading
    # --------------------------------------------------------------------------

    def save(self, filename):
        network = self.network.module if hasattr(self.network, "module") \
            else self.network
        state_dict = copy.copy(network.state_dict())
        params = {
            'state_dict': state_dict,
            'word_dict': self.word_dict,
            'args': self.args,
        }
        try:
            torch.save(params, filename)
        except BaseException:
            logger.warning('WARN: Saving failed... continuing anyway.')

    def checkpoint(self, filename, epoch):
        network = self.network.module if hasattr(self.network, "module") \
            else self.network
        params = {
            'state_dict': network.state_dict(),
            'word_dict': self.word_dict,
            'args': self.args,
            'epoch': epoch,
            'updates': self.updates,
            'optim_dict': self.optimizer.state_dict(),
            'sched_dict': self.scheduler.state_dict(),
        }
        try:
            torch.save(params, filename)
        except BaseException:
            logger.warning('WARN: Saving failed... continuing anyway.')

    @staticmethod
    def load(filename, new_args=None):
        logger.info('Loading model %s' % filename)
        saved_params = torch.load(
            filename, map_location=lambda storage, loc: storage
        )
        word_dict = saved_params['word_dict']
        state_dict = saved_params['state_dict']
        args = saved_params['args']
        if new_args:
            args = override_model_args(args, new_args)
        return Seq2seqKpGen(args, word_dict, state_dict)

    @staticmethod
    def load_checkpoint(filename):
        logger.info('Loading model %s' % filename)
        saved_params = torch.load(
            filename, map_location=lambda storage, loc: storage
        )
        word_dict = saved_params['word_dict']
        state_dict = saved_params['state_dict']
        epoch = saved_params['epoch']
        updates = saved_params['updates']
        optim_dict = saved_params['optim_dict']
        sched_dict = saved_params['sched_dict']
        args = saved_params['args']
        model = Seq2seqKpGen(args, word_dict, state_dict)
        model.updates = updates
        model.init_optimizer(optim_dict, sched_dict)
        return model, epoch

    # --------------------------------------------------------------------------
    # Runtime
    # --------------------------------------------------------------------------

    def to(self, device):
        self.network = self.network.to(device)

    def parallelize(self):
        self.network = torch.nn.DataParallel(self.network)
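A hedged sketch of how Seq2seqKpGen might be driven end to end; the checkpoint path and the train_batches iterable of prepared example dicts are assumptions, not part of the example:

# resume from a checkpoint, then keep training
model, start_epoch = Seq2seqKpGen.load_checkpoint('kp_gen.ckpt')
model.to(model.args.device)  # args is assumed to carry the target device
for ex in train_batches:     # each ex is a prepared batch dict
    metrics = model.update(ex)
    print(metrics['ml_loss'], metrics['perplexity'])
model.save('kp_gen_final.pt')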
Exemple #21
0
class Trainer():
    def __init__(self, config, pretrained=True, augmentor=ImgAugTransform()):

        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.train_lmdb = config['dataset']['train_lmdb']
        self.valid_lmdb = config['dataset']['valid_lmdb']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.image_aug = config['aug']['image_aug']
        self.masked_language_model = config['aug']['masked_language_model']
        self.metrics = config['trainer']['metrics']
        self.is_padding = config['dataset']['is_padding']

        self.tensorboard_dir = config['monitor']['log_dir']
        if not os.path.exists(self.tensorboard_dir):
            os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.writer = SummaryWriter(self.tensorboard_dir)

        # LOGGER
        self.logger = Logger(config['monitor']['log_dir'])
        self.logger.info(config)

        self.iter = 0
        self.best_acc = 0
        self.scheduler = None
        self.is_finetuning = config['trainer']['is_finetuning']

        if self.is_finetuning:
            self.logger.info("Finetuning model ---->")
            if self.model.seq_modeling == 'crnn':
                self.optimizer = Adam(lr=0.0001,
                                      params=self.model.parameters(),
                                      betas=(0.5, 0.999))
            else:
                self.optimizer = AdamW(lr=0.0001,
                                       params=self.model.parameters(),
                                       betas=(0.9, 0.98),
                                       eps=1e-09)

        else:

            self.optimizer = AdamW(self.model.parameters(),
                                   betas=(0.9, 0.98),
                                   eps=1e-09)
            self.scheduler = OneCycleLR(self.optimizer,
                                        total_steps=self.num_iters,
                                        **config['optimizer'])

        if self.model.seq_modeling == 'crnn':
            self.criterion = torch.nn.CTCLoss(self.vocab.pad,
                                              zero_infinity=True)
        else:
            self.criterion = LabelSmoothingLoss(len(self.vocab),
                                                padding_idx=self.vocab.pad,
                                                smoothing=0.1)

        # Pretrained model
        if config['trainer']['pretrained']:
            self.load_weights(config['trainer']['pretrained'])
            self.logger.info("Loaded trained model from: {}".format(
                config['trainer']['pretrained']))

        # Resume
        elif config['trainer']['resume_from']:
            self.load_checkpoint(config['trainer']['resume_from'])
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(torch.device(self.device))

            self.logger.info("Resume training from {}".format(
                config['trainer']['resume_from']))

        # DATASET
        transforms = None
        if self.image_aug:
            transforms = augmentor

        train_lmdb_paths = [
            os.path.join(self.data_root, lmdb_path)
            for lmdb_path in self.train_lmdb
        ]

        self.train_gen = self.data_gen(
            lmdb_paths=train_lmdb_paths,
            data_root=self.data_root,
            annotation=self.train_annotation,
            masked_language_model=self.masked_language_model,
            transform=transforms,
            is_train=True)

        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                lmdb_paths=[os.path.join(self.data_root, self.valid_lmdb)],
                data_root=self.data_root,
                annotation=self.valid_annotation,
                masked_language_model=False)

        self.train_losses = []
        self.logger.info("Number batch samples of training: %d" %
                         len(self.train_gen))
        self.logger.info("Number batch samples of valid: %d" %
                         len(self.valid_gen))

        config_savepath = os.path.join(self.tensorboard_dir, "config.yml")
        if not os.path.exists(config_savepath):
            self.logger.info("Saving config file at: %s" % config_savepath)
            Cfg(config).save(config_savepath)

    def train(self):
        total_loss = 0
        latest_loss = 0

        total_loader_time = 0
        total_gpu_time = 0
        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1
            start = time.time()

            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start
            start = time.time()

            # LOSS
            loss = self.step(batch)
            total_loss += loss
            self.train_losses.append((self.iter, loss))

            total_gpu_time += time.time() - start

            if self.iter % self.print_every == 0:

                info = 'Iter: {:06d} - Train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)
                latest_loss = total_loss / self.print_every
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                self.logger.info(info)

            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_time = time.time()
                val_loss = self.validate()
                acc_full_seq, acc_per_char, wer = self.precision(self.metrics)

                self.logger.info("Iter: {:06d}, start validating".format(
                    self.iter))
                info = 'Iter: {:06d} - Valid loss: {:.3f} - Acc full seq: {:.4f} - Acc per char: {:.4f} - WER: {:.4f} - Time: {:.4f}'.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, wer,
                    time.time() - val_time)
                self.logger.info(info)

                if acc_full_seq > self.best_acc:
                    self.save_weights(self.tensorboard_dir + "/best.pt")
                    self.best_acc = acc_full_seq

                self.logger.info("Iter: {:06d} - Best acc: {:.4f}".format(
                    self.iter, self.best_acc))

                filename = 'last.pt'
                filepath = os.path.join(self.tensorboard_dir, filename)
                self.logger.info("Save checkpoint %s" % filename)
                self.save_checkpoint(filepath)

                log_loss = {'train loss': latest_loss, 'val loss': val_loss}
                self.writer.add_scalars('Loss', log_loss, self.iter)
                self.writer.add_scalar('WER', wer, self.iter)

    def validate(self):
        self.model.eval()

        total_loss = []

        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = batch[
                    'img'], batch['tgt_input'], batch['tgt_output'], batch[
                        'tgt_padding_mask']

                outputs = self.model(img, tgt_input, tgt_padding_mask)
                #                loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))

                if self.model.seq_modeling == 'crnn':
                    length = batch['labels_len']
                    preds_size = torch.IntTensor([outputs.size(0)] *
                                                 self.batch_size)
                    loss = self.criterion(outputs, tgt_output, preds_size,
                                          length)
                else:
                    outputs = outputs.flatten(0, 1)
                    tgt_output = tgt_output.flatten()
                    loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())

                del outputs
                del loss

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss

    def predict(self, sample=None):
        pred_sents = []
        actual_sents = []
        img_files = []
        probs_sents = []
        imgs_sents = []

        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())
            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)

            imgs_sents.extend(batch['img'])
            img_files.extend(batch['filenames'])
            probs_sents.extend(prob)

            # Visualize in tensorboard
            if idx == 0:
                try:
                    num_samples = self.config['monitor']['num_samples']
                    fig = plt.figure(figsize=(12, 15))
                    imgs_samples = imgs_sents[:num_samples]
                    preds_samples = pred_sents[:num_samples]
                    actuals_samples = actual_sents[:num_samples]
                    probs_samples = probs_sents[:num_samples]
                    for id_img in range(len(imgs_samples)):
                        img = imgs_samples[id_img]
                        img = img.permute(1, 2, 0)
                        img = img.cpu().detach().numpy()
                        ax = fig.add_subplot(num_samples,
                                             1,
                                             id_img + 1,
                                             xticks=[],
                                             yticks=[])
                        plt.imshow(img)
                        ax.set_title(
                            "LB: {} \n Pred: {:.4f}-{}".format(
                                actuals_samples[id_img], probs_samples[id_img],
                                preds_samples[id_img]),
                            color=('green' if actuals_samples[id_img]
                                   == preds_samples[id_img] else 'red'),
                            fontdict={
                                'fontsize': 18,
                                'fontweight': 'medium'
                            })

                    self.writer.add_figure('predictions vs. actuals',
                                           fig,
                                           global_step=self.iter)
                except Exception as error:
                    print(error)
                    continue

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files, probs_sents, imgs_sents

    def precision(self, sample=None, measure_time=True):
        t1 = time.time()
        pred_sents, actual_sents, _, _, _ = self.predict(sample=sample)
        time_predict = time.time() - t1

        sensitive_case = self.config['predictor']['sensitive_case']
        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        sensitive_case,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        sensitive_case,
                                        mode='per_char')
        wer = compute_accuracy(actual_sents,
                               pred_sents,
                               sensitive_case,
                               mode='wer')

        if measure_time:
            print("Time: {:.4f}".format(time_predict / len(actual_sents)))
        return acc_full_seq, acc_per_char, wer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16,
                             save_fig=False):

        pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample)

        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]
            imgs = [imgs[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}
        ncols = 5
        nrows = int(math.ceil(len(img_files) / ncols))
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 15))

        for vis_idx in range(0, len(img_files)):
            row = vis_idx // ncols
            col = vis_idx % ncols

            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]
            prob = probs[vis_idx]
            img = imgs[vis_idx].permute(1, 2, 0).cpu().detach().numpy()

            ax[row, col].imshow(img)
            ax[row, col].set_title(
                "Pred: {: <2} \n Actual: {} \n prob: {:.2f}".format(
                    pred_sent, actual_sent, prob),
                fontname=fontname,
                color='r' if pred_sent != actual_sent else 'g')
            ax[row, col].get_xaxis().set_ticks([])
            ax[row, col].get_yaxis().set_ticks([])

        plt.subplots_adjust()
        if save_fig:
            fig.savefig('vis_prediction.png')
        plt.show()

    def log_prediction(self, sample=16, csv_file='model.csv'):
        pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample)
        save_predictions(csv_file, pred_sents, actual_sents, img_files)

    def vis_data(self, sample=20):

        ncols = 5
        nrows = int(math.ceil(sample / ncols))
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 12))

        num_plots = 0
        for idx, batch in enumerate(self.train_gen):
            for vis_idx in range(self.batch_size):
                row = num_plots // ncols
                col = num_plots % ncols

                img = batch['img'][vis_idx].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(
                    batch['tgt_input'].T[vis_idx].tolist())

                ax[row, col].imshow(img)
                ax[row, col].set_title("Label: {: <2}".format(sent),
                                       fontsize=16,
                                       color='g')

                ax[row, col].get_xaxis().set_ticks([])
                ax[row, col].get_yaxis().set_ticks([])

                num_plots += 1
                if num_plots >= sample:
                    plt.subplots_adjust()
                    fig.savefig('vis_dataset.png')
                    return

    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']
        if self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler'])

        self.best_acc = checkpoint['best_acc']

    def save_checkpoint(self, filename):
        state = {
            'iter':
            self.iter,
            'state_dict':
            self.model.state_dict(),
            'optimizer':
            self.optimizer.state_dict(),
            'train_losses':
            self.train_losses,
            'scheduler':
            None if self.scheduler is None else self.scheduler.state_dict(),
            'best_acc':
            self.best_acc
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))
        if self.is_checkpoint(state_dict):
            self.model.load_state_dict(state_dict['state_dict'])
        else:

            for name, param in self.model.named_parameters():
                if name not in state_dict:
                    print('{} not found'.format(name))
                elif state_dict[name].shape != param.shape:
                    print('{} has a mismatched shape: expected {} but found {}'.
                          format(name, param.shape, state_dict[name].shape))
                    del state_dict[name]
            self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def is_checkpoint(self, checkpoint):
        # a full checkpoint dict wraps the model weights under 'state_dict'
        return isinstance(checkpoint, dict) and 'state_dict' in checkpoint

    def batch_to_device(self, batch):
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device,
                                                        non_blocking=True)

        batch = {
            'img': img,
            'tgt_input': tgt_input,
            'tgt_output': tgt_output,
            'tgt_padding_mask': tgt_padding_mask,
            'filenames': batch['filenames'],
            'labels_len': batch['labels_len']
        }

        return batch

    def data_gen(self,
                 lmdb_paths,
                 data_root,
                 annotation,
                 masked_language_model=True,
                 transform=None,
                 is_train=False):
        datasets = []
        for lmdb_path in lmdb_paths:
            dataset = OCRDataset(
                lmdb_path=lmdb_path,
                root_dir=data_root,
                annotation_path=annotation,
                vocab=self.vocab,
                transform=transform,
                image_height=self.config['dataset']['image_height'],
                image_min_width=self.config['dataset']['image_min_width'],
                image_max_width=self.config['dataset']['image_max_width'],
                separate=self.config['dataset']['separate'],
                batch_size=self.batch_size,
                is_padding=self.is_padding)
            datasets.append(dataset)
        if len(datasets) > 1:
            dataset = torch.utils.data.ConcatDataset(datasets)

        if self.is_padding:
            sampler = None
        else:
            sampler = ClusterRandomSampler(dataset, self.batch_size, True)

        collate_fn = Collator(masked_language_model)

        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=is_train and sampler is None,
                         drop_last=self.model.seq_modeling == 'crnn',
                         **self.config['dataloader'])

        return gen

    def step(self, batch):
        self.model.train()

        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[
            'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']

        outputs = self.model(img,
                             tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)
        #        loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))

        if self.model.seq_modeling == 'crnn':
            length = batch['labels_len']
            preds_size = torch.IntTensor([outputs.size(0)] * self.batch_size)
            loss = self.criterion(outputs, tgt_output, preds_size, length)
        else:
            outputs = outputs.view(
                -1, outputs.size(2))  # flatten(0, 1)    # B*S x N_class
            tgt_output = tgt_output.view(-1)  # flatten()    # B*S
            loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.optimizer.step()

        if not self.is_finetuning:
            self.scheduler.step()

        loss_item = loss.item()

        return loss_item

    def count_parameters(self, model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def gen_pseudo_labels(self, outfile=None):
        pred_sents = []
        img_files = []
        probs_sents = []

        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            pred_sents.extend(pred_sent)
            img_files.extend(batch['filenames'])
            probs_sents.extend(prob)
        assert len(pred_sents) == len(img_files) and len(img_files) == len(
            probs_sents)
        with open(outfile, 'w', encoding='utf-8') as f:
            for anno in zip(img_files, pred_sents, probs_sents):
                f.write('||||'.join([anno[0], anno[1],
                                     str(float(anno[2]))]) + '\n')
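gen_pseudo_labels above writes one annotation per line in the form image_path||||prediction||||probability. A short sketch for reading such a file back; the file name is an assumption:

with open('pseudo_labels.txt', encoding='utf-8') as f:
    for line in f:
        img_file, text, prob = line.rstrip('\n').split('||||')
        print(img_file, text, float(prob))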
def main():
    args = parseArguments()

    os.makedirs(args.modelDir, exist_ok=True)
    checkpointDir = os.path.join(args.modelDir, 'checkpoints')
    os.makedirs(checkpointDir, exist_ok=True)

    os.makedirs(args.ensembleDir, exist_ok=True)

    with EventTimer('Preparing for dataset / dataloader'):
        trainDataset = ProductDataset(os.path.join(args.dataDir, 'train'),
                                      os.path.join(args.trainImages),
                                      transform=trainingPreprocessing)
        validDataset = ProductDataset(os.path.join(args.dataDir, 'train'),
                                      os.path.join(args.validImages),
                                      transform=inferencePreprocessing)

        trainDataloader = DataLoader(trainDataset,
                                     batch_size=args.batchSize,
                                     num_workers=args.numWorkers,
                                     shuffle=True)
        validDataloader = DataLoader(validDataset,
                                     batch_size=args.batchSize,
                                     num_workers=args.numWorkers,
                                     shuffle=False)

        print(f'> Training dataset:\t{len(trainDataset)}')
        print(f'> Validation dataset:\t{len(validDataset)}')

    with EventTimer(f'Load pretrained model - {args.pretrainModel}'):
        model = models.GetPretrainedModel(args.pretrainModel,
                                          fcDims=args.fcDims + [42])
        print(model)
        #torchsummary will crash under densenet, skip the summary.
        #torchsummary.summary(model, (3, 224, 224), device='cpu')

    with EventTimer(f'Train model'):
        model.cuda()

        criterion = CrossEntropyLoss()
        optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.l2)
        scheduler = CosineAnnealingLR(optimizer,
                                      T_max=args.epochs,
                                      eta_min=1e-6)
        history = []

        if args.retrain != 0:
            checkpoint = torch.load(
                os.path.join(checkpointDir,
                             f'checkpoint-{args.retrain:03d}.pt'))
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            history = checkpoint['history']

        def runEpoch(dataloader, train=False, name=''):
            # For empty validation dataloader
            if len(dataloader) == 0:
                return 0, 0

            # Enable grad
            with (torch.enable_grad() if train else torch.no_grad()):
                if train: model.train()
                else: model.eval()

                losses = []
                for img, label, imgPath in tqdm(dataloader,
                                                desc=name,
                                                ncols=80):
                    if train:
                        optimizer.zero_grad()

                    output = model(img.cuda()).cpu()
                    loss = criterion(output, label)

                    if train:
                        loss.backward()
                        optimizer.step()

                    accu = accuracy(output.data.numpy(), label.numpy())
                    losses.append((loss.item(), accu))

            return map(np.mean, zip(*losses))

        def cleanUp():
            model.eval()
            train_pred = np.zeros(len(trainDataloader) * args.batchSize)
            cnt = 0
            with torch.no_grad():
                for i, (data, label, path) in enumerate(trainDataloader):
                    test_pred = model(data.cuda())
                    pred = np.max(test_pred.cpu().data.numpy(), axis=1)
                    train_pred[cnt:cnt + len(pred)] = pred
                    cnt += len(pred)

            sorted_pred = train_pred
            sorted_pred.sort()
            threshold = sorted_pred[(len(sorted_pred) // 20)]
            data_set = [[], []]

            with torch.no_grad():
                for i, (data, label, path) in enumerate(trainDataloader):
                    test_pred = model(data.cuda())
                    pred = np.max(test_pred.cpu().data.numpy(), axis=1)
                    for j in range(len(pred)):
                        if pred[j] >= threshold:
                            data_set[0].append(path[j])
                            data_set[1].append(label[j])

            newDataset = ProductDataset(os.path.join(args.dataDir, 'train'),
                                        os.path.join(args.trainImages),
                                        transform=trainingPreprocessing,
                                        data=data_set)
            newDataloader = DataLoader(newDataset,
                                       batch_size=args.batchSize,
                                       num_workers=args.numWorkers,
                                       shuffle=True)

            print(
                f"{len(newDataloader) * args.batchSize} images remain after cleanup"
            )
            return newDataloader

        for epoch in range(args.retrain + 1, args.epochs + 1):
            with EventTimer(verbose=False) as et:
                print(f'====== Epoch {epoch:3d} / {args.epochs:3d} ======')
                trainLoss, trainAccu = runEpoch(trainDataloader,
                                                train=True,
                                                name='training  ')
                validLoss, validAccu = runEpoch(validDataloader,
                                                name='validation')

                history.append(
                    ((trainLoss, trainAccu), (validLoss, validAccu)))

                scheduler.step()
                print(
                    f'[{et.gettime():.4f}s] Training: {trainLoss:.6f} / {trainAccu:.4f} ; Validation {validLoss:.6f} / {validAccu:.4f}'
                )

            if args.cleanup and epoch % args.cleanup_epoch == 0:
                with EventTimer('Cleaning Training Set'):
                    trainDataloader = cleanUp()

            if epoch % 5 == 0:
                torch.save(
                    {
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'history': history,
                    }, os.path.join(checkpointDir,
                                    f'checkpoint-{epoch:03d}.pt'))

        # save the model under its corresponding name
        torch.save(model.state_dict(),
                   os.path.join(args.modelDir, 'model-weights.pt'))
        utils.pickleSave(history, os.path.join(args.modelDir, 'history.pkl'))
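The history saved at the end of main can be inspected offline. A hedged sketch, assuming utils.pickleSave writes a standard pickle and reusing the ((trainLoss, trainAccu), (validLoss, validAccu)) layout appended in the loop above:

import pickle

with open('history.pkl', 'rb') as f:
    history = pickle.load(f)
for epoch, ((tr_loss, tr_acc), (va_loss, va_acc)) in enumerate(history, 1):
    print(f'epoch {epoch}: train {tr_loss:.6f}/{tr_acc:.4f} '
          f'- valid {va_loss:.6f}/{va_acc:.4f}')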
class Training:
    def __init__(self, model, device, config, name, fold_num, imsize):
        self.config = config
        self.epoch = 0
        self.base_dir = './models/'
        os.makedirs('./models', exist_ok=True)
        self.model = model
        self.best_loss = 10**5
        self.device = device
        self.name = name
        self.fold_num = fold_num
        self.imsize = imsize
        # optimize
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.001
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.00
        }]
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr)
        self.scheduler = config.SchedulerClass(self.optimizer,
                                               **config.scheduler_params)
        # Earlystopping
        self.patience = config.patience
        # GradScaler
        self.scaler = GradScaler()

    def train_one_epoch(self, train_loader):
        self.model.train()
        showloss = Showloss()

        for step, (images, targets) in tqdm(enumerate(train_loader),
                                            total=len(train_loader)):
            self.optimizer.zero_grad()

            with autocast():
                images = torch.stack(
                    images)  # stack the images into one batch (default: dim=0) [B,C,H,W]
                images = images.to(self.device).float()
                batch_size = images.shape[0]
                boxes = [
                    target['bbox'].to(self.device).float()
                    for target in targets
                ]
                labels = [
                    target['cls'].to(self.device).float() for target in targets
                ]
                img_scale = torch.tensor([
                    target['img_scale'].to(self.device).float()
                    for target in targets
                ])
                img_size = torch.tensor([
                    (self.imsize, self.imsize) for target in targets
                ]).to(self.device).float()

                # after the update, forward takes the images and a target dict as arguments
                target_res = {}
                target_res['bbox'] = boxes
                target_res['cls'] = labels
                target_res['img_scale'] = img_scale
                target_res['img_size'] = img_size

                # pred
                output = self.model(images, target_res)
                loss = output['loss']
                showloss.update(loss.detach().item(), batch_size)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

        return showloss

    def val_one_epoch(self, val_loader):
        self.model.eval()
        showloss = Showloss()
        for step, (images, targets) in tqdm(enumerate(val_loader),
                                            total=len(val_loader)):
            with torch.no_grad():
                images = torch.stack(images)
                batch_size = images.shape[0]
                images = images.to(self.device).float()
                boxes = [
                    target['bbox'].to(self.device).float()
                    for target in targets
                ]
                labels = [
                    target['cls'].to(self.device).float() for target in targets
                ]
                img_scale = torch.tensor([
                    target['img_scale'].to(self.device).float()
                    for target in targets
                ])
                img_size = torch.tensor([
                    (self.imsize, self.imsize) for target in targets
                ]).to(self.device).float()

                target_res = {}
                target_res['bbox'] = boxes
                target_res['cls'] = labels
                target_res['img_scale'] = img_scale
                target_res['img_size'] = img_size

                # loss, _, _ = self.model(images, boxes, labels)
                output = self.model(images, target_res)
                loss = output['loss']
                showloss.update(loss.detach().item(), batch_size)

        return showloss

    def save(self, path):  # save the model and training state
        self.model.eval()
        torch.save(
            {
                'model_state_dict': self.model.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'best_loss': self.best_loss,  # best validation loss so far
                'epoch': self.epoch,
            },
            path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_loss = checkpoint['best_loss']  # val
        self.epoch = checkpoint['epoch'] + 1

    def fit(self, train_loader, val_loader):
        early_stopping = EarlyStopping(self.patience)
        for epoch in range(self.config.n_epochs):
            print('{} / {} Epoch'.format(epoch, self.config.n_epochs))
            train_loss = self.train_one_epoch(train_loader)
            print('[Train] loss: {}'.format(train_loss.avg))
            self.save(self.base_dir +
                      '{}_{}_last.pt'.format(self.name, self.fold_num))

            val_loss = self.val_one_epoch(val_loader)
            print('[Valid] loss: {}'.format(val_loss.avg))

            if val_loss.avg < self.best_loss:
                self.best_loss = val_loss.avg
                self.save(self.base_dir +
                          '{}_{}_best.pt'.format(self.name, self.fold_num))

            # Early stopping
            early_stopping(val_loss.avg, self.best_loss)
            if early_stopping.early_stop:
                break

            if self.config.val_scheduler:
                self.scheduler.step(metrics=val_loss.avg)

            self.epoch += 1
Exemple #24
0
class Trainer():
    def __init__(self, config, pretrained=True):

        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.checkpoint = config['trainer']['checkpoint']
        self.export_weights = config['trainer']['export']
        self.metrics = config['trainer']['metrics']
        logger = config['trainer']['log']

        if logger:
            self.logger = Logger(logger)

        if pretrained:
            weight_file = download_weights(**config['pretrain'],
                                           quiet=config['quiet'])
            self.load_weights(weight_file)

        self.iter = 0

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
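        # one-cycle LR schedule; it is advanced once per training iteration in step()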
        self.scheduler = OneCycleLR(self.optimizer, **config['optimizer'])
        #        self.optimizer = ScheduledOptim(
        #            Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        #            #config['transformer']['d_model'],
        #            512,
        #            **config['optimizer'])

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        transforms = ImgAugTransform()

        self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
                                       self.data_root,
                                       self.train_annotation,
                                       transform=transforms)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                'valid_{}'.format(self.dataset_name), self.data_root,
                self.valid_annotation)

        self.train_losses = []

    def train(self):
        total_loss = 0

        total_loader_time = 0
        total_gpu_time = 0
        best_acc = 0

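        # pull batches from the training generator, restarting it whenever it is exhausted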
        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()

            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start

            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)

                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)

            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_loss = self.validate()
                acc_full_seq, acc_per_char = self.precision(self.metrics)

                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}'.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char)
                print(info)
                self.logger.log(info)

                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq

    def validate(self):
        self.model.eval()

        total_loss = []

        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = batch[
                    'img'], batch['tgt_input'], batch['tgt_output'], batch[
                        'tgt_padding_mask']

                outputs = self.model(img, tgt_input, tgt_padding_mask)
                #                loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))

                outputs = outputs.flatten(0, 1)
                tgt_output = tgt_output.flatten()
                loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())

                del outputs
                del loss

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss

    def predict(self, sample=None):
        pred_sents = []
        actual_sents = []
        img_files = []

        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)

            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['img'], self.model)
            else:
                translated_sentence = translate(batch['img'], self.model)

            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())

            img_files.extend(batch['filenames'])

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files

    def precision(self, sample=None):

        pred_sents, actual_sents, _ = self.predict(sample=sample)

        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')

        return acc_full_seq, acc_per_char

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):

        pred_sents, actual_sents, img_files = self.predict(sample)

        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}

        for vis_idx in range(0, len(img_files)):
            img_path = img_files[vis_idx]
            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]

            img = Image.open(open(img_path, 'rb'))
            plt.figure()
            plt.imshow(img)
            plt.title('pred: {} - actual: {}'.format(pred_sent, actual_sent),
                      loc='left',
                      fontdict=fontdict)
            plt.axis('off')

        plt.show()

    def visualize_dataset(self, sample=16, fontname='serif'):
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())

                plt.figure()
                plt.title('sent: {}'.format(sent),
                          loc='center',
                          fontname=fontname)
                plt.imshow(img)
                plt.axis('off')

                n += 1
                if n >= sample:
                    plt.show()
                    return

    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)

        optim = ScheduledOptim(
            Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
            self.config['transformer']['d_model'], **self.config['optimizer'])

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']

        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} mismatched shape'.format(name))
                del state_dict[name]

        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device,
                                                        non_blocking=True)

        batch = {
            'img': img,
            'tgt_input': tgt_input,
            'tgt_output': tgt_output,
            'tgt_padding_mask': tgt_padding_mask,
            'filenames': batch['filenames']
        }

        return batch

    def data_gen(self, lmdb_path, data_root, annotation, transform=None):
        dataset = OCRDataset(
            lmdb_path=lmdb_path,
            root_dir=data_root,
            annotation_path=annotation,
            vocab=self.vocab,
            transform=transform,
            image_height=self.config['dataset']['image_height'],
            image_min_width=self.config['dataset']['image_min_width'],
            image_max_width=self.config['dataset']['image_max_width'])

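        # the cluster sampler handles shuffling and batch grouping itself, hence shuffle=False below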
        sampler = ClusterRandomSampler(dataset, self.batch_size, True)
        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=False,
                         drop_last=False,
                         **self.config['dataloader'])

        return gen

    def data_gen_v1(self, lmdb_path, data_root, annotation):
        data_gen = DataGen(
            data_root,
            annotation,
            self.vocab,
            'cpu',
            image_height=self.config['dataset']['image_height'],
            image_min_width=self.config['dataset']['image_min_width'],
            image_max_width=self.config['dataset']['image_max_width'])

        return data_gen

    def step(self, batch):
        self.model.train()

        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[
            'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']

        outputs = self.model(img,
                             tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)
        #        loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        outputs = outputs.view(-1, outputs.size(2))  #flatten(0, 1)
        tgt_output = tgt_output.view(-1)  #flatten()

        loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.optimizer.step()
        self.scheduler.step()

        loss_item = loss.item()

        return loss_item
        },
    ]
    lr = args.lr
    query_optimizer = AdamW(optimizer_grouped_parameters1, lr=lr, eps=1e-8)
    t_total = epoch_len * args.epochs
    num_warmup_steps = int(args.warmup * t_total)
    query_scheduler = get_linear_schedule_with_warmup(
        query_optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=t_total)
    if (ckpt_dir
            and os.path.isfile(os.path.join(ckpt_dir, "query_optimizer.pt"))
            and os.path.isfile(os.path.join(ckpt_dir, "query_scheduler.pt"))):
        # Load in optimizer and scheduler states
        query_optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_dir, "query_optimizer.pt"),
                       map_location='cpu'))
        query_scheduler.load_state_dict(
            torch.load(os.path.join(ckpt_dir, "query_scheduler.pt"),
                       map_location='cpu'))
        logger.info(
            f'Load query optimizer states from {os.path.join(ckpt_dir, "query_optimizer.pt")}'
        )

    if not args.share:
        optimizer_grouped_parameters2 = [
            {
                "params": [
                    p for n, p in doc_bert.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
Exemple #26
0
def train(args, train_dataset, model, tokenizer):
    """ Train the model """

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            inputs["token_type_ids"] = (
                batch[2]
                if args.model_type in ["bert", "xlnet", "albert"] else None
            )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if step % 10 == 0:
                print(step, loss.item())

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= args.gradient_accumulation_steps and
                (step + 1) == len(epoch_iterator)):
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    return global_step, tr_loss / global_step
Exemple #27
0
class Trainer():
    def __init__(self,
                 train_dataloader,
                 test_dataloader,
                 lr,
                 betas,
                 weight_decay,
                 log_freq,
                 with_cuda,
                 model=None):

        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda" if cuda_condition else "cpu")
        print("Use:", "cuda:0" if cuda_condition else "cpu")

        self.model = Classifier_M3().to(self.device)
        self.optim = AdamW(self.model.parameters(),
                           lr=lr,
                           betas=betas,
                           weight_decay=weight_decay)
        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optim, 5)
        self.criterion = nn.BCEWithLogitsLoss()

        if model is not None:
            checkpoint = torch.load(model)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epoch = checkpoint['epoch']
            self.criterion = checkpoint['criterion']

        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
        print("Using %d GPUS for Converter" % torch.cuda.device_count())

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.log_freq = log_freq
        print("Total Parameters:",
              sum([p.nelement() for p in self.model.parameters()]))

        self.test_loss = []
        self.train_loss = []
        self.train_f1_score = []
        self.test_f1_score = []

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        :param epoch: 現在のepoch
        :param data_loader: torch.utils.data.DataLoader
        :param train: trainかtestかのbool値
        """
        str_code = "train" if train else "test"

        data_iter = tqdm(enumerate(data_loader),
                         desc="EP_%s:%d" % (str_code, epoch),
                         total=len(data_loader),
                         bar_format="{l_bar}{r_bar}")

        total_element = 0
        loss_store = 0.0
        f1_score_store = 0.0
        total_correct = 0

        for i, data in data_iter:
            specgram = data[0].to(self.device)
            label = data[2].to(self.device)
            one_hot_label = data[1].to(self.device)
            predict_label = self.model(specgram, train)

            #
            predict_f1_score = get_F1_score(
                label.cpu().detach().numpy(),
                convert_label(predict_label.cpu().detach().numpy()),
                average='micro')

            loss = self.criterion(predict_label, one_hot_label)

            #
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                self.scheduler.step()

            loss_store += loss.item()
            f1_score_store += predict_f1_score
            self.avg_loss = loss_store / (i + 1)
            self.avg_f1_score = f1_score_store / (i + 1)

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": round(self.avg_loss, 5),
                "loss": round(loss.item(), 5),
                "avg_f1_score": round(self.avg_f1_score, 5)
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))
        if train:
            self.train_loss.append(self.avg_loss)
            self.train_f1_score.append(self.avg_f1_score)
        else:
            self.test_loss.append(self.avg_loss)
            self.test_f1_score.append(self.avg_f1_score)

    def save(self, epoch, file_path="../models/2k/"):
        """
        """
        output_path = file_path + f"crnn_ep{epoch}.model"
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.cpu().state_dict(),
                'optimizer_state_dict': self.optim.state_dict(),
                'criterion': self.criterion
            }, output_path)
        self.model.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path

    def export_log(self, epoch, file_path="../../logs/2k/"):
        df = pd.DataFrame({
            "train_loss": self.train_loss,
            "test_loss": self.test_loss,
            "train_F1_score": self.train_f1_score,
            "test_F1_score": self.test_f1_score
        })
        output_path = file_path + f"loss_timestrech.log"
        print("EP:%d logs Saved on:" % epoch, output_path)
        df.to_csv(output_path)
Exemple #28
0
def run_training(opt):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    work_dir, epochs, train_batch, valid_batch, weights = \
        opt.work_dir, opt.epochs, opt.train_bs, opt.valid_bs, opt.weights

    # Directories
    last = os.path.join(work_dir, 'last.pt')
    best = os.path.join(work_dir, 'best.pt')

    # --------------------------------------
    # Setup train and validation set
    # --------------------------------------
    data = pd.read_csv(opt.train_csv)
    images_path = opt.data_dir

    n_classes = 6  # number of classes is hard-coded for this dataset

    data['class'] = data.apply(lambda row: categ[row["class"]], axis=1)

    train_loader, val_loader = prepare_dataloader(data,
                                                  opt.fold,
                                                  train_batch,
                                                  valid_batch,
                                                  opt.img_size,
                                                  opt.num_workers,
                                                  data_root=images_path)

    # if not opt.ovr_val:
    #     handwritten_data = pd.read_csv(opt.handwritten_csv)
    #     printed_data = pd.read_csv(opt.printed_csv)
    #     handwritten_data['class'] = handwritten_data.apply(lambda row: categ[row["class"]], axis =1)
    #     printed_data['class'] = printed_data.apply(lambda row: categ[row["class"]], axis =1)
    #     _, handwritten_val_loader = prepare_dataloader(
    #         handwritten_data, opt.fold, train_batch, valid_batch, opt.img_size, opt.num_workers, data_root=images_path)

    #     _, printed_val_loader = prepare_dataloader(
    #         printed_data, opt.fold, train_batch, valid_batch, opt.img_size, opt.num_workers, data_root=images_path)

    # --------------------------------------
    # Models
    # --------------------------------------

    model = Classifier(model_name=opt.model_name,
                       n_classes=n_classes,
                       pretrained=True).to(device)

    if opt.weights is not None:
        cp = torch.load(opt.weights)
        model.load_state_dict(cp['model'])

    # -------------------------------------------
    # Setup optimizer, scheduler, criterion loss
    # -------------------------------------------

    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
    scheduler = CosineAnnealingWarmRestarts(optimizer,
                                            T_0=10,
                                            T_mult=1,
                                            eta_min=1e-6,
                                            last_epoch=-1)
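    # gradient scaler for the mixed-precision training step (passed into train_one_epoch below)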
    scaler = GradScaler()

    loss_tr = nn.CrossEntropyLoss().to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)

    # --------------------------------------
    # Setup training
    # --------------------------------------
    if not os.path.exists(work_dir):
        os.mkdir(work_dir)

    best_loss = 1e5
    start_epoch = 0
    best_epoch = 0  # for early stopping

    if opt.resume:
        checkpoint = torch.load(last)

        start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint["scheduler"])
        best_loss = checkpoint["best_loss"]

    # --------------------------------------
    # Start training
    # --------------------------------------
    print("[INFO] Start training...")
    for epoch in range(start_epoch, epochs):
        train_one_epoch(epoch,
                        model,
                        loss_tr,
                        optimizer,
                        train_loader,
                        device,
                        scheduler=scheduler,
                        scaler=scaler)
        with torch.no_grad():
            if opt.ovr_val:
                val_loss = valid_one_epoch_overall(epoch,
                                                   model,
                                                   loss_fn,
                                                   val_loader,
                                                   device,
                                                   scheduler=None)
            else:
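                # NOTE: handwritten_val_loader / printed_val_loader are only created by the
                # commented-out per-domain block above; restore it before using this branch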
                val_loss = valid_one_epoch(epoch,
                                           model,
                                           loss_fn,
                                           handwritten_val_loader,
                                           printed_val_loader,
                                           device,
                                           scheduler=None)

            if val_loss < best_loss:
                best_loss = val_loss
                best_epoch = epoch
                torch.save(
                    {
                        'epoch': epoch,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'best_loss': best_loss
                    }, os.path.join(best))

                print('best model found for epoch {}'.format(epoch + 1))

        torch.save(
            {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_loss': best_loss
            }, os.path.join(last))

        if epoch - best_epoch > opt.patience:
            print("Early stop achieved at", epoch + 1)
            break

    del model, optimizer, train_loader, val_loader, scheduler, scaler
    torch.cuda.empty_cache()
Exemple #29
0
break_factor = False

# Measure the total training time for the whole run.
total_t0 = time.time()

print("starting...")
# For each epoch...
for epoch_i in range(0, EPOCHS):
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 5, EPOCHS + 4))
    print('Training...')
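    # resume model/optimizer state and stats for this epoch from the previously saved checkpoint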
    checkpoint = torch.load("/global/cscratch1/sd/ajaybati/model_ckptDS" +
                            str(epoch_i) + ".pickle")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch'] + 1
    total_train_loss = 0
    step_resume = 0
    training_stats = checkpoint['training_stats']
    print('step:  ', step_resume, 'total loss:  ', total_train_loss,
          'epoch:   ', epoch)

    # Measure how long the training epoch takes.
    t0 = time.time()

    model.train()

    # For each batch of training data...
    for step, batch in enumerate(train_dataloader):
        b_input_ids = batch[0].to(device)