def create_opt(self) -> torch.optim.Optimizer:
    opt = AdamW(
        self.model.parameters(),
        lr=self.args.lr,
        weight_decay=self.args.weight_decay,
    )
    if os.path.exists(self.opt_path()):
        print("loading optimizer from checkpoint...")
        opt.load_state_dict(torch.load(self.opt_path(), map_location="cpu"))
    return opt
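# `opt_path` is referenced above but not defined in this snippet. A minimal
# sketch of what it might look like on the same class (the checkpoint
# directory attribute and file name are assumptions, not from the source):
def opt_path(self) -> str:
    # Resolve the optimizer checkpoint file inside the run's checkpoint dir.
    return os.path.join(self.args.checkpoint_dir, "opt.pt")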
def main():
    parser = argparse.ArgumentParser(
        description='20bn-jester-v1 Gesture Classification with Backpropamine')
    parser.add_argument('--batch-size', type=int, default=8, metavar='N',
                        help='input batch size for training (default: 8)')
    # parser.add_argument('--validation-batch-size', type=int, default=1000, metavar='N',
    #                     help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num-workers', type=int, default=0, metavar='W',
                        help='number of workers for data loading (default: 0)')
    parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--dataset-dir', type=str, default=r"./dataset", metavar='D',
                        help='dataset location (default: ./dataset)')
    # parser.add_argument('--log-interval', type=int, default=10, metavar='N',
    #                     help='how many batches to wait before logging training status')
    # parser.add_argument('--save-model', action='store_true', default=False,
    #                     help='for saving the current model')
    parser.add_argument('--no-resume', action='store_true', default=False,
                        help='switch to disable resuming from a checkpoint')
    parser.add_argument('--use-lstm', action='store_true', default=False,
                        help='switch to use an LSTM module instead of backpropamine')
    parser.add_argument('--frame-step', type=int, default=2, metavar='FS',
                        help='step of video frame extraction (default: 2)')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    torch.manual_seed(args.seed)

    train_data = MyDataset('train', args.dataset_dir, frame_step=args.frame_step)
    validation_data = MyDataset('validation', args.dataset_dir,
                                frame_step=args.frame_step)
    train_dataloader = DataLoader(train_data, batch_size=args.batch_size,
                                  drop_last=True, shuffle=True,
                                  collate_fn=collate_fn,
                                  num_workers=args.num_workers)
    validation_dataloader = DataLoader(validation_data, batch_size=args.batch_size,
                                       drop_last=True, shuffle=True,
                                       collate_fn=collate_fn,
                                       num_workers=args.num_workers)

    resume = not args.no_resume
    if resume:
        try:
            checkpoint = torch.load("checkpoint.pt")
        except FileNotFoundError:
            resume = False

    mode = 'LSTM' if args.use_lstm else 'backpropamine'
    model = Net(mode=mode).to(device)
    optimizer = AdamW(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    last_epoch, max_epoch = 0, args.epochs
    if resume:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        last_epoch = checkpoint['last_epoch']

    validator = Validator(model, validation_dataloader, device, args.dry_run)
    trainer = Trainer(model, optimizer, train_dataloader, scheduler, last_epoch,
                      max_epoch, device, validator, args.dry_run)
    print(vars(args))
    trainer()
    print("finish.")
class TrainLoop:
    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    ):
        self.model = model
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        if self.use_fp16:
            self._setup_fp16()

        self.opt = AdamW(self.master_params, lr=self.lr,
                         weight_decay=self.weight_decay)
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.master_params)
                for _ in range(len(self.ema_rate))
            ]

        if th.cuda.is_available():
            self.use_ddp = True
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.use_ddp = False
            self.ddp_model = self.model

    def _load_and_sync_parameters(self):
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                self.model.load_state_dict(
                    dist_util.load_state_dict(
                        resume_checkpoint, map_location=dist_util.dev()
                    )
                )
        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        ema_params = copy.deepcopy(self.master_params)
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev()
                )
                ema_params = self._state_dict_to_master_params(state_dict)
        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def _setup_fp16(self):
        self.master_params = make_master_params(self.model_params)
        self.model.convert_to_fp16()

    def run_loop(self):
        while (
            not self.lr_anneal_steps
            or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond = next(self.data)
            self.run_step(batch, cond)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if self.step % self.save_interval == 0:
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
            self.step += 1
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        self.forward_backward(batch, cond)
        if self.use_fp16:
            self.optimize_fp16()
        else:
            self.optimize_normal()
        self.log_step()

    def forward_backward(self, batch, cond):
        zero_grad(self.model_params)
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            if self.use_fp16:
                loss_scale = 2 ** self.lg_loss_scale
                (loss * loss_scale).backward()
            else:
                loss.backward()

    def optimize_fp16(self):
        if any(not th.isfinite(p.grad).all() for p in self.model_params):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            return

        model_grads_to_master_grads(self.model_params, self.master_params)
        self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)
        master_params_to_model_params(self.model_params, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth

    def optimize_normal(self):
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)

    def _log_grad_norm(self):
        sqsum = 0.0
        for p in self.master_params:
            sqsum += (p.grad ** 2).sum().item()
        logger.logkv_mean("grad_norm", np.sqrt(sqsum))

    def _anneal_lr(self):
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
        if self.use_fp16:
            logger.logkv("lg_loss_scale", self.lg_loss_scale)

    def save(self):
        def save_checkpoint(rate, params):
            state_dict = self._master_params_to_state_dict(params)
            if dist.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"model{(self.step+self.resume_step):06d}.pt"
                else:
                    filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        if dist.get_rank() == 0:
            with bf.BlobFile(
                bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
                "wb",
            ) as f:
                th.save(self.opt.state_dict(), f)

        dist.barrier()

    def _master_params_to_state_dict(self, master_params):
        if self.use_fp16:
            master_params = unflatten_master_params(
                self.model.parameters(), master_params
            )
        state_dict = self.model.state_dict()
        for i, (name, _value) in enumerate(self.model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
        return state_dict

    def _state_dict_to_master_params(self, state_dict):
        params = [state_dict[name] for name, _ in self.model.named_parameters()]
        if self.use_fp16:
            return make_master_params(params)
        else:
            return params
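# The fp16 path above implements dynamic loss scaling by hand: the loss is
# multiplied by 2**lg_loss_scale before backward(), the scale is halved on a
# non-finite gradient, and grown slowly otherwise. A minimal standalone sketch
# of the same idea (toy model and data are assumptions; a CUDA device is
# assumed since fp16 layers run on GPU):
import torch

model = torch.nn.Linear(4, 1).half().cuda()  # fp16 weights
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
lg_loss_scale = 16.0  # initial log2 of the loss scale

for _ in range(10):
    x = torch.randn(8, 4, device="cuda", dtype=torch.float16)
    opt.zero_grad()
    loss = model(x).square().mean()
    (loss * 2.0 ** int(lg_loss_scale)).backward()  # scale up before backward
    if any(not torch.isfinite(p.grad).all() for p in model.parameters()):
        lg_loss_scale -= 1  # overflow: shrink the scale and skip this step
        continue
    for p in model.parameters():
        p.grad.mul_(1.0 / 2.0 ** int(lg_loss_scale))  # unscale before stepping
    opt.step()
    lg_loss_scale += 1e-3  # grow the scale slowly while training is stable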
        ],
        'weight_decay': 0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    # -----------------------------------------

    # -----------------------------------------
    # Loading the contents of the auxiliary checkpoint and instantiating the contents
    if not resuming:
        if aux_checkpoint:
            global_step = aux_checkpoint['global_step']
            epoch = aux_checkpoint['epoch']
            optimizer.load_state_dict(aux_checkpoint['optimizer'])
            best_acc = aux_checkpoint['best_acc']
            mlp.load_state_dict(aux_checkpoint['mlp_state_dict'])
            best_checkpoint_path = aux_checkpoint['best_checkpoint_path']
        else:
            global_step = 0
            best_acc = 0.0
            epoch = 0
            best_checkpoint_path = None
    # -----------------------------------------

    # -----------------------------------------
    # Enabling the use of DataParallel for multiple GPUs:
    if args.dataparallel:
        model = nn.DataParallel(model)
    # -----------------------------------------
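# The truncated list above is the tail of the common "no weight decay for
# bias/LayerNorm" grouping that several snippets in this collection use. A
# minimal self-contained sketch of the full pattern (the toy model is an
# assumption for illustration):
import torch
from torch.optim import AdamW

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.LayerNorm = torch.nn.LayerNorm(8)

model = Tiny()
no_decay = ["bias", "LayerNorm.weight"]
grouped = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = AdamW(grouped, lr=5e-5, eps=1e-8)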
def train(args):
    logger = log.get_logger(__name__)
    with open(Path(args.config_base_path, args.config).with_suffix(".yaml"), 'r') as f:
        config = yaml.safe_load(f)

    train_transforms = transforms.get_train_transforms()
    val_transforms = transforms.get_val_transforms()

    logger.info("Loading the dataset...")
    if config['dataset']['name'] == 'coco_subset':
        # TODO: Look into train_transforms hiding the objects.
        # Transform in such a way that this can't be the case.
        train_dataset = CocoSubset(config['dataset']['coco_path'],
                                   config['dataset']['target_classes'],
                                   train_transforms, 'train',
                                   config['dataset']['train_val_split'])
        val_dataset = CocoSubset(config['dataset']['coco_path'],
                                 config['dataset']['target_classes'],
                                 val_transforms, 'val',
                                 config['dataset']['train_val_split'])
    else:
        raise ValueError("Dataset name not recognized or implemented")

    train_loader = DataLoader(train_dataset, config['training']['batch_size'],
                              shuffle=True, collate_fn=data_utils.collate_fn)
    val_loader = DataLoader(val_dataset, config['training']['batch_size'],
                            shuffle=True, collate_fn=data_utils.collate_fn)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint_manager = CheckpointManager(args.config, args.save_every)

    logger.info("Loading model...")
    model = models.DETR(config['dataset']['num_classes'],
                        config['model']['dim_model'],
                        config['model']['n_heads'],
                        n_queries=config['model']['n_queries'],
                        head_type=config['model']['head_type'])

    # TODO: implement scheduler
    optim = AdamW(model.parameters(), config['training']['lr'])  # pending

    if args.mode == 'pretrained':
        model.load_demo_state_dict('data/state_dicts/detr_demo.pth')
    elif args.mode == 'checkpoint':
        state_dict, optim_dict = checkpoint_manager.load_checkpoint('latest')
        model.load_state_dict(state_dict)
        optim.load_state_dict(optim_dict)

    if args.train_section == 'head':
        to_train = ['ffn']
    elif args.train_section == 'backbone':
        to_train = ['backbone', 'conv']
    else:
        to_train = ['ffn', 'backbone', 'conv', 'transformer', 'row', 'col', 'object']

    # Freeze everything but the modules that are in to_train
    for name, param in model.named_parameters():
        if not any(map(name.startswith, to_train)):
            param.requires_grad = False

    model.to(device)

    matcher = models.HungarianMatcher(config['losses']['lambda_matcher_classes'],
                                      config['losses']['lambda_matcher_giou'],
                                      config['losses']['lambda_matcher_l1'])
    loss_fn = models.DETRLoss(config['losses']['lambda_loss_classes'],
                              config['losses']['lambda_loss_giou'],
                              config['losses']['lambda_loss_l1'],
                              config['dataset']['num_classes'],
                              config['losses']['no_class_weight'])

    # writer = SummaryWriter(log_dir=Path(__file__)/'logs/tensorboard')
    # maybe image with boxes every now and then
    # maybe look into add_hparams

    logger.info("Starting training...")
    loss_hist = deque()
    loss_desc = "Loss: n/a"
    update_every_n_steps = (config['training']['effective_batch_size']
                            // config['training']['batch_size'])
    steps = 1
    starting_epoch = checkpoint_manager.current_epoch
    for epoch in range(starting_epoch, config['training']['epochs']):
        epoch_desc = f"Epoch [{epoch}/{config['training']['epochs']}]"
        for images, labels in tqdm(train_loader, f"{epoch_desc} | {loss_desc}"):
            images = images.to(device)
            labels = data_utils.labels_to_device(labels, device)

            output = model(images)
            matching_indices = matcher(output, labels)
            matching_indices = data_utils.indices_to_device(matching_indices, device)

            # Normalize so gradients accumulated over n steps match one big batch.
            loss = loss_fn(output, labels, matching_indices) / update_every_n_steps
            loss_hist.append(loss.item() * update_every_n_steps)
            loss.backward()

            if steps % update_every_n_steps == 0:
                optim.step()
                optim.zero_grad()
            steps += 1

        epoch_avg = sum(loss_hist) / len(loss_hist)
        checkpoint_manager.step(model, optim, epoch_avg)
        loss_desc = f"Loss: {epoch_avg}"
        loss_hist.clear()

        if (epoch % args.eval_every == 0) and epoch != 0:
            validation_loop(model, matcher, val_loader, loss_fn, device)

    # loss_hist is cleared at the end of every epoch, so save the final
    # checkpoint with the last epoch's average rather than an empty history
    checkpoint_manager.save_checkpoint(model, optim, epoch_avg)
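# The loop above divides each loss by update_every_n_steps and only steps the
# optimizer every n-th batch, so the accumulated gradient approximates one
# update on an effective_batch_size batch. A minimal standalone sketch of the
# pattern (model and data are assumptions for illustration):
import torch

model = torch.nn.Linear(16, 1)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 4  # effective batch = 4 * micro-batch

optim.zero_grad()
for step in range(1, 101):
    x, y = torch.randn(8, 16), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y) / accum_steps
    loss.backward()  # gradients accumulate across micro-batches
    if step % accum_steps == 0:
        optim.step()
        optim.zero_grad()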
def fit(self):
    config = self.config
    logging.debug(json.dumps(config, indent=4, sort_keys=True))

    include_passage_masks = self.config["fusion_strategy"] == "passages"
    if self.config["dataset"] in ["nq", "trivia"]:
        fields = FusionInDecoderDataset.prepare_fields(
            pad_t=self.tokenizer.pad_token_id)
        if not config["test_only"]:
            # trivia is too large, create a lightweight training dataset for it instead
            training_dataset = FusionInDecoderDatasetLight if self.config \
                .get("use_lightweight_dataset", False) else FusionInDecoderDataset
            train = training_dataset(config["train_data"], fields=fields,
                                     tokenizer=self.tokenizer,
                                     database=self.db,
                                     transformer=config["reader_transformer_type"],
                                     cache_dir=self.config["data_cache_dir"],
                                     max_len=self.config.get("reader_max_input_length", None),
                                     context_length=self.config["context_length"],
                                     include_golden_passage=self.config["include_golden_passage_in_training"],
                                     include_passage_masks=include_passage_masks,
                                     preprocessing_truncation=self.config["preprocessing_truncation"],
                                     one_answer_per_question=self.config.get("one_question_per_epoch", False),
                                     use_only_human_answer=self.config.get("use_human_answer_only", False),
                                     is_training=True)
            val = FusionInDecoderDataset(config["val_data"], fields=fields,
                                         tokenizer=self.tokenizer,
                                         database=self.db,
                                         transformer=config["reader_transformer_type"],
                                         cache_dir=config["data_cache_dir"],
                                         max_len=self.config.get("reader_max_input_length", None),
                                         context_length=self.config["context_length"],
                                         include_passage_masks=include_passage_masks,
                                         preprocessing_truncation=self.config["preprocessing_truncation"],
                                         use_only_human_answer=self.config.get("use_human_answer_only", False),
                                         is_training=False)
        test = FusionInDecoderDataset(config["test_data"], fields=fields,
                                      tokenizer=self.tokenizer,
                                      database=self.db,
                                      transformer=config["reader_transformer_type"],
                                      cache_dir=config["data_cache_dir"],
                                      max_len=self.config.get("reader_max_input_length", None),
                                      context_length=self.config["context_length"],
                                      include_passage_masks=include_passage_masks,
                                      preprocessing_truncation=self.config["preprocessing_truncation"],
                                      is_training=False)
    else:
        raise NotImplementedError(f"Unknown dataset {self.config['dataset']}")

    if not config["test_only"]:
        logging.info(f"Training data examples: {len(train)}")
        logging.info(f"Validation data examples: {len(val)}")
    logging.info(f"Test data examples: {len(test)}")

    if not config["test_only"]:
        train_iter = Iterator(train,
                              shuffle=training_dataset != FusionInDecoderDatasetLight,
                              sort=False,  # do not sort!
                              batch_size=1, train=True,
                              repeat=False, device=self.device)
        val_iter = Iterator(val, sort=False, shuffle=False, batch_size=1,
                            repeat=False, device=self.device)
    test_iter = Iterator(test, sort=False, shuffle=False, batch_size=1,
                         repeat=False, device=self.device)

    logging.info("Loading model...")
    if config.get("resume_training", False) or config.get("pre_initialize", False):
        if config.get("resume_training", False):
            logging.info("Resuming training...")
        if "resume_checkpoint" not in config:
            config["resume_checkpoint"] = config["pretrained_reader_model"]
        model = torch.load(config["resume_checkpoint"], map_location=self.device)
    else:
        model = torch.load(config["model"], map_location=self.device) \
            if self.config["test_only"] and "model" in config else \
            T5FusionInDecoder.from_pretrained(config).to(self.device)

    logging.info(f"Resizing token embeddings to length {len(self.tokenizer)}")
    model.resize_token_embeddings(len(self.tokenizer))
    logging.info(f"Model has {count_parameters(model)} trainable parameters")
    logging.info(f"Trainable parameter checksum: {sum_parameters(model)}")
    param_sizes, param_shapes = report_parameters(model)
    param_sizes = "\n'".join(str(param_sizes).split(", '"))
    param_shapes = "\n'".join(str(param_shapes).split(", '"))
    logging.debug(f"Model structure:\n{param_sizes}\n{param_shapes}\n")

    if not config["test_only"]:
        # Init optimizer
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                "weight_decay": self.config["weight_decay"],
            },
            {
                "params": [p for n, p in model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        if config["optimizer"] == "adamw":
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=self.config["learning_rate"],
                              eps=self.config["adam_eps"])
        elif config["optimizer"] == "adam":
            optimizer = Adam(optimizer_grouped_parameters,
                             lr=self.config["learning_rate"],
                             eps=self.config["adam_eps"])
        else:
            raise ValueError("Unsupported optimizer")

        if config.get("resume_checkpoint", False):
            optimizer.load_state_dict(model.optimizer_state_dict)

        # Init scheduler
        if "scheduler_warmup_steps" in self.config or "warmup_proportion" in self.config:
            t_total = self.config["max_steps"]
            warmup_steps = round(self.config["scheduler_warmup_proportion"] * t_total) \
                if "scheduler_warmup_proportion" in self.config \
                else self.config["scheduler_warmup_steps"]
            scheduler = self.init_scheduler(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=t_total,
                last_step=get_model(model).training_steps - 1,
            )
            logging.info(f"Scheduler: warmup steps: {warmup_steps}, total_steps: {t_total}")
        else:
            scheduler = None

        if config["lookahead_optimizer"]:
            optimizer = Lookahead(optimizer, k=10, alpha=0.5)

    if not config["test_only"]:
        start_time = time.time()
        try:
            it = 0
            while get_model(model).training_steps < self.config["max_steps"]:
                logging.info(f"Epoch {it}")
                train_loss = self.train_epoch(model=model,
                                              data_iter=train_iter,
                                              val_iter=val_iter,
                                              optimizer=optimizer,
                                              scheduler=scheduler)
                logging.info(f"Training loss: {train_loss:.5f}")
                it += 1
        except KeyboardInterrupt:
            logging.info('-' * 120)
            logging.info('Exit from training early.')
        finally:
            logging.info(f'Finished after {(time.time() - start_time) / 60} minutes.')
            if hasattr(self, "best_ckpt_name"):
                logging.info(f"Loading best checkpoint {self.best_ckpt_name}")
                model = torch.load(self.best_ckpt_name, map_location=self.device)

    logging.info("#" * 50)
    logging.info("Validating on the test data")
    self.validate(model, test_iter)
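# init_scheduler is not shown in this snippet; a common choice it might wrap
# is linear warmup followed by linear decay, expressible with LambdaLR. A
# minimal sketch under that assumption (mirrors the usual
# get_linear_schedule_with_warmup semantics):
import torch
from torch.optim.lr_scheduler import LambdaLR

def linear_warmup_decay(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)  # ramp up from 0 to 1
        # then decay linearly back down to 0 at num_training_steps
        return max(0.0, (num_training_steps - step)
                   / max(1, num_training_steps - num_warmup_steps))
    return LambdaLR(optimizer, lr_lambda)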
def pretrain(data, stats=None):
    # fine tuning
    dataloader = DataLoader(data, batch_size=1, shuffle=True)
    del data

    ## optimizer and scheduler ##
    t_total = len(dataloader) // opts.gradient_accumulation_steps * opts.num_train_epochs
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [{
        "params": [p for n, p in model.named_parameters()
                   if not any(nd in n for nd in no_decay)],
        "weight_decay": opts.weight_decay
    }, {
        "params": [p for n, p in model.named_parameters()
                   if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=opts.lr, eps=opts.eps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=opts.warmup_steps,
        num_training_steps=t_total)

    # loading optimizer settings
    if (opts.model_name_or_path
            and os.path.isfile(os.path.join(opts.model_name_or_path, "pretrain_optimizer.pt"))
            and os.path.isfile(os.path.join(opts.model_name_or_path, "scheduler.pt"))):
        # load optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(opts.model_name_or_path, "pretrain_optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(opts.model_name_or_path, "pretrain_scheduler.pt")))

    # track stats
    if stats is not None:
        global_step = max(stats.keys())
        epochs_trained = global_step // (len(dataloader)
                                         // opts.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(dataloader) // opts.gradient_accumulation_steps)
        print("Resuming Training ... ")
    else:
        stats = {}
        global_step, epochs_trained, steps_trained_in_current_epoch = 0, 0, 0

    tr_loss, logging_loss = 0.0, 0.0

    # very important: set model to TRAINING mode
    model.zero_grad()
    model.train()

    print("Re-sizing model ... ")
    model.resize_token_embeddings(len(tokenizer))

    start_time = time.time()
    for epoch in range(epochs_trained, opts.num_train_epochs):
        data_iter = iter(dataloader)
        for step in range(len(dataloader)):
            if steps_trained_in_current_epoch > 0:
                # skip batches already seen before the resume point,
                # consuming them to keep the iterator position in sync
                steps_trained_in_current_epoch -= 1
                batch = next(data_iter)
                continue

            ### step ###
            batch = next(data_iter)
            loss = fit_on_batch(batch)
            del batch

            # logging (new data only)
            tr_loss += loss.item()

            # gradient accumulation
            if (step + 1) % opts.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), opts.max_grad_norm)
                optimizer.step()
                scheduler.step()  # update learning rate schedule
                model.zero_grad()
                global_step += 1

                # reporting
                if global_step % opts.logging_steps == 0:
                    stats[global_step] = {
                        'pretrain_loss': (tr_loss - logging_loss) / opts.logging_steps,
                        'pretrain_lr': scheduler.get_last_lr()[-1]
                    }
                    logging_loss = tr_loss
                    elapsed_time = time.strftime(
                        "%M:%S", time.gmtime(time.time() - start_time))
                    print('Epoch: %d | Iter: [%d/%d] | loss: %.3f | lr: %s | time: %s'
                          % (epoch, global_step, t_total,
                             stats[global_step]['pretrain_loss'],
                             str(stats[global_step]['pretrain_lr']), elapsed_time))
                    start_time = time.time()

                if global_step % opts.save_steps == 0:
                    print("Saving stuff ... ")
                    checkpoint(model, tokenizer, optimizer, scheduler, stats,
                               title="pretrain_")
                    plot_losses(stats, title='pretrain_loss')
                    plot_losses(stats, title='pretrain_lr')
                    print("Done.")

    return stats
def train(self):
    # get dataloader
    train_dataloader, _ = self.data2loader(self.args['train'], mode='train',
                                           batch_size=self.args['batch_size'])

    # optimizer and scheduler
    param_optimizer = list(self.model.named_parameters())
    other_parameters = [(n, p) for n, p in param_optimizer if 'crf' not in n]
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in other_parameters
                    if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in other_parameters
                    if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0},
        # the CRF transition matrix gets its own, much larger learning rate
        {'params': [p for n, p in param_optimizer if 'crf.transitions' == n],
         'lr': 3e-2}
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=self.args['learning_rate'], eps=1e-8)
    if self.args['load_model'] > 0:
        optimizer.load_state_dict(
            torch.load('models/Opt' + str(self.args['load_model'])))
        print('load optimizer success')

    total_steps = 1000  # len(train_dataloader) * num_epoches
    if self.args['load_model'] <= 0:
        last_epoch = -1
    else:
        last_epoch = self.args['load_model']
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps,
        last_epoch=last_epoch
    )

    # training
    top = 0
    stop = 0
    best_model = None
    start_time = time()
    for i in range(self.args['num_epoches']):
        self.model.train()
        losses = 0
        for idx, batch_data in enumerate(train_dataloader):
            batch_data = tuple(t.to(device) for t in batch_data)
            ids, masks, labels = batch_data
            self.model.zero_grad()
            loss = self.model(ids, masks=masks, labels=labels)

            # process loss
            loss.backward()
            losses += loss.item()

            # tackle exploding gradients
            torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(),
                                           max_norm=self.args['max_grad_norm'])
            optimizer.step()
            scheduler.step()

        F0 = None
        if (i + 1 + self.args['load_model']) % 20 == 0:
            F0, _ = self.evaluate(self.args['train'])
        F1, loss = self.evaluate(self.args['valid'])
        F2, loss2 = self.evaluate(self.args['test'])
        if F1 + F2 > top:
            top = F1 + F2
            # torch.save(self.model, 'models/Mod' + str(self.args['fold']) + '_' + str(i+self.args['load_model']+1))
            best_model = copy.deepcopy(self.model)
            print('save new top', top)
            stop = 0
        else:
            if stop > 7:
                torch.save(best_model, 'models/Mod' + str(self.args['fold'])
                           + '_' + str(i + self.args['load_model'] + 1))
                return
            stop += 1
        print('Epoch', i + self.args['load_model'] + 1,
              losses / len(train_dataloader), loss, 'F1', F1, F2, F0,
              time() - start_time)
        if (i + 1 + self.args['load_model']) % self.args['save_epoch'] == 0:
            torch.save(self.model, 'models/Mod' + str(self.args['fold'])
                       + '_' + str(i + self.args['load_model'] + 1))
            # torch.save(optimizer.state_dict(), 'models/Opt' + str(self.args['fold']) + '_' + str(i+self.args['load_model']+1))
        start_time = time()
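# The third parameter group above overrides the optimizer-wide lr just for
# crf.transitions. A minimal standalone demonstration of per-group learning
# rates (the toy parameters are assumptions for illustration):
import torch
from torch.optim import AdamW

slow = torch.nn.Parameter(torch.zeros(4))
fast = torch.nn.Parameter(torch.zeros(4))
opt = AdamW([
    {"params": [slow]},              # uses the default lr below
    {"params": [fast], "lr": 3e-2},  # group-specific override
], lr=1e-5)
print([g["lr"] for g in opt.param_groups])  # [1e-05, 0.03]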
class Trainer:
    """
    Handles model training and evaluation.

    Arguments:
    ----------
    config: A dictionary of training parameters, likely from a .yaml file
    model: A pytorch segmentation model (e.g. DeepLabV3)
    trn_data: A pytorch dataloader object that will return pairs of images
    and segmentation masks from a training dataset
    val_data: A pytorch dataloader object that will return pairs of images
    and segmentation masks from a validation dataset.
    """
    def __init__(self, config, model, trn_data, val_data=None):
        self.config = config
        self.model = model.cuda()
        self.trn_data = DataFetcher(trn_data)
        self.val_data = val_data

        #create the optimizer
        if config['optim'] == 'SGD':
            self.optimizer = SGD(model.parameters(), lr=config['lr'],
                                 momentum=config['momentum'],
                                 weight_decay=config['wd'])
        elif config['optim'] == 'AdamW':
            self.optimizer = AdamW(model.parameters(), lr=config['lr'],
                                   weight_decay=config['wd'])  #momentum is default
        else:
            optim = config['optim']
            raise Exception(
                f'Optimizer {optim} is not supported! Must be SGD or AdamW')

        #create the learning rate scheduler
        schedule = config['lr_policy']
        if schedule == 'OneCycle':
            self.scheduler = OneCycleLR(self.optimizer, config['lr'],
                                        total_steps=config['iters'])
        elif schedule == 'MultiStep':
            self.scheduler = MultiStepLR(self.optimizer,
                                         milestones=config['lr_decay_epochs'])
        elif schedule == 'Poly':
            func = lambda iteration: (1 - (iteration / config['iters'])) ** config['power']
            self.scheduler = LambdaLR(self.optimizer, func)
        else:
            lr_policy = config['lr_policy']
            raise Exception(
                f'Policy {lr_policy} is not supported! Must be OneCycle, MultiStep or Poly')

        #create the loss criterion
        if config['num_classes'] > 1:
            #load class weights if they were given in the config file
            if 'class_weights' in config:
                weight = torch.Tensor(config['class_weights']).float().cuda()
            else:
                weight = None
            self.criterion = nn.CrossEntropyLoss(weight=weight).cuda()
        else:
            self.criterion = nn.BCEWithLogitsLoss().cuda()

        #define train and validation metrics and class names
        class_names = config['class_names']

        #make training metrics using the EMAMeter. this meter gives extra
        #weight to the most recent metric values calculated during training.
        #this gives a better reflection of how well the model is performing
        #when the metrics are printed
        trn_md = {
            name: metric_lookup[name](EMAMeter()) for name in config['metrics']
        }
        self.trn_metrics = ComposeMetrics(trn_md, class_names)
        self.trn_loss_meter = EMAMeter()

        #the only difference between train and validation metrics
        #is that we use the AverageMeter. this is because there are
        #no weight updates during evaluation, so all batches should
        #count equally
        val_md = {
            name: metric_lookup[name](AverageMeter()) for name in config['metrics']
        }
        self.val_metrics = ComposeMetrics(val_md, class_names)
        self.val_loss_meter = AverageMeter()

        self.logging = config['logging']

        #now, if we're resuming from a previous run we need to load
        #the state for the model, optimizer, and schedule and resume
        #the mlflow run (if there is one and we're using logging)
        if config['resume']:
            self.resume(config['resume'])
        elif self.logging:
            #if we're not resuming, but are logging, then we
            #need to set up mlflow with a new experiment

            #every time that Trainer is instantiated we want to
            #end the current active run and let a new one begin
            mlflow.end_run()

            #extract the experiment name from config so that
            #we know where to save our files; if the experiment name
            #already exists, we'll use it, otherwise we create a
            #new experiment
            mlflow.set_experiment(self.config['experiment_name'])

            #add the config file as an artifact
            mlflow.log_artifact(config['config_file'])

            #we don't want to add everything in the config
            #to mlflow parameters, we'll just add the most
            #likely-to-change parameters
            mlflow.log_param('lr_policy', config['lr_policy'])
            mlflow.log_param('optim', config['optim'])
            mlflow.log_param('lr', config['lr'])
            mlflow.log_param('wd', config['wd'])
            mlflow.log_param('bsz', config['bsz'])
            mlflow.log_param('momentum', config['momentum'])
            mlflow.log_param('iters', config['iters'])
            mlflow.log_param('epochs', config['epochs'])
            mlflow.log_param('encoder', config['encoder'])
            mlflow.log_param('finetune_layer', config['finetune_layer'])
            mlflow.log_param('pretraining', config['pretraining'])

    def resume(self, checkpoint_fpath):
        """
        Sets model parameters, scheduler and optimizer states to the last
        recorded values in the given checkpoint file.
        """
        checkpoint = torch.load(checkpoint_fpath, map_location='cpu')
        self.model.load_state_dict(checkpoint['state_dict'])

        if not self.config['restart_training']:
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        if self.logging and 'run_id' in checkpoint:
            mlflow.start_run(run_id=checkpoint['run_id'])

        print(f'Loaded state from {checkpoint_fpath}')
        print(f'Resuming from epoch {self.scheduler.last_epoch}...')

    def log_metrics(self, step, dataset):
        #get the corresponding losses and metrics dict for
        #either train or validation sets
        if dataset == 'train':
            losses = self.trn_loss_meter
            metric_dict = self.trn_metrics.metrics_dict
        elif dataset == 'valid':
            losses = self.val_loss_meter
            metric_dict = self.val_metrics.metrics_dict

        #log the last loss, using the dataset name as a prefix
        mlflow.log_metric(dataset + '_loss', losses.avg, step=step)

        #log all the metrics in our dict, using dataset as a prefix
        metrics = {}
        for k, v in metric_dict.items():
            values = v.meter.avg
            for class_name, val in zip(self.trn_metrics.class_names, values):
                metrics[dataset + '_' + class_name + '_' + k] = float(val.item())
        mlflow.log_metrics(metrics, step=step)

    def train(self):
        """
        Defines a pytorch-style training loop for the model with a tqdm
        progress bar for each epoch and handles printing loss/metrics at
        the end of each epoch.

        epochs: Number of epochs to train model
        train_iters_per_epoch: Number of training iterations in each epoch.
        Reducing this number will give more frequent updates but result in
        slower training time.

        Results:
        ----------
        After train_iters_per_epoch iterations are completed, it will
        evaluate the model on val_data if there is any, then prints loss
        and metrics for train and validation datasets.
        """
        #set the inner and outer training loop as either
        #iterations or epochs depending on our scheduler
        if self.config['lr_policy'] != 'MultiStep':
            last_epoch = self.scheduler.last_epoch + 1
            total_epochs = self.config['iters']
            iters_per_epoch = 1
            outer_loop = tqdm(range(last_epoch, total_epochs + 1),
                              file=sys.stdout, initial=last_epoch,
                              total=total_epochs)
            inner_loop = range(iters_per_epoch)
        else:
            last_epoch = self.scheduler.last_epoch + 1
            total_epochs = self.config['epochs']
            iters_per_epoch = len(self.trn_data)
            outer_loop = range(last_epoch, total_epochs + 1)
            inner_loop = tqdm(range(iters_per_epoch), file=sys.stdout)

        #determine the epochs at which to print results
        eval_epochs = total_epochs // self.config['num_prints']
        save_epochs = total_epochs // self.config['num_save_checkpoints']

        #the cudnn.benchmark flag speeds up performance
        #when the model input size is constant. See:
        #https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        cudnn.benchmark = True

        #perform training over the outer and inner loops
        for epoch in outer_loop:
            for iteration in inner_loop:
                #load the next batch of training data
                images, masks = self.trn_data.load()

                #run the training iteration
                loss, output = self._train_1_iteration(images, masks)

                #record the loss and evaluate metrics
                self.trn_loss_meter.update(loss)
                self.trn_metrics.evaluate(output, masks)

            #when we're at an eval_epoch we want to print
            #the training results so far and then evaluate
            #the model on the validation data
            if epoch % eval_epochs == 0:
                #before printing results let's record everything in mlflow
                #(if we're using logging)
                if self.logging:
                    self.log_metrics(epoch, dataset='train')

                print('\n')  #print a new line to give space from progress bar
                print(f'train_loss: {self.trn_loss_meter.avg:.3f}')
                self.trn_loss_meter.reset()

                #prints and automatically resets the metric averages to 0
                self.trn_metrics.print()

                #run evaluation if we have validation data
                if self.val_data is not None:
                    #before evaluation we want to turn off cudnn
                    #benchmark because the input sizes of validation
                    #images are not necessarily constant
                    cudnn.benchmark = False
                    self.evaluate()

                    if self.logging:
                        self.log_metrics(epoch, dataset='valid')

                    print('\n')  #print a new line to give space from progress bar
                    print(f'valid_loss: {self.val_loss_meter.avg:.3f}')
                    self.val_loss_meter.reset()

                    #prints and automatically resets the metric averages to 0
                    self.val_metrics.print()

                    #turn cudnn.benchmark back on before returning to training
                    cudnn.benchmark = True

            #update the optimizer schedule
            self.scheduler.step()

            #the last step is to save the training state if
            #at a checkpoint
            if epoch % save_epochs == 0:
                self.save_state(epoch)

    def _train_1_iteration(self, images, masks):
        #run a training step
        self.model.train()
        self.optimizer.zero_grad()

        #forward pass
        output = self.model(images)
        loss = self.criterion(output, masks)

        #backward pass
        loss.backward()
        self.optimizer.step()

        #return the loss value and the output
        return loss.item(), output.detach()

    def evaluate(self):
        """
        Evaluation method used at the end of each epoch. Not intended to
        generate predictions for the validation dataset; it only returns
        average loss and stores metrics for the validation dataset. Use
        the Validator class for generating masks on a dataset.
        """
        #set the model into eval mode
        self.model.eval()

        val_iter = DataFetcher(self.val_data)
        for _ in range(len(val_iter)):
            with torch.no_grad():
                #load batch of data
                images, masks = val_iter.load()
                output = self.model.eval()(images)
                loss = self.criterion(output, masks)
                self.val_loss_meter.update(loss.item())
                self.val_metrics.evaluate(output.detach(), masks)

        #loss and metrics are updated in place, so there's nothing to return
        return None

    def save_state(self, epoch):
        """
        Saves the self.model state dict together with the optimizer and
        scheduler states.

        Arguments:
        ------------
        epoch: The current epoch number; used in the saved file's name.
        """
        #save the state together with the norms that we're using
        state = {
            'state_dict': self.model.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'norms': self.config['training_norms']
        }
        if self.logging:
            state['run_id'] = mlflow.active_run().info.run_id

        #the last step is to create the name of the file to save;
        #the format is: name-of-experiment_pretraining_epoch.pth
        model_dir = self.config['model_dir']
        exp_name = self.config['experiment_name']
        pretraining = self.config['pretraining']
        ft_layer = self.config['finetune_layer']
        if self.config['lr_policy'] != 'MultiStep':
            total_epochs = self.config['iters']
        else:
            total_epochs = self.config['epochs']

        if os.path.isfile(pretraining):
            #this is slightly clunky, but it handles the case
            #of using custom pretrained weights from a file.
            #usually there aren't any '.'s other than the file
            #extension
            pretraining = pretraining.split('/')[-2]  #.split('.')[0]

        save_path = os.path.join(
            model_dir,
            f'{exp_name}-{pretraining}_ft_{ft_layer}_epoch{epoch}_of_{total_epochs}.pth')
        torch.save(state, save_path)
class Model:
    def __init__(self, epochs=50, fc=FC_62):
        self.model = CNN(fc)
        self.model.to(device)
        self.num_epochs = epochs
        self.epochs = 0  # epochs completed so far (restored by load())
        self.loss = 0
        self.optimizer = AdamW(params=self.model.parameters())
        self.loss_fn = nn.CrossEntropyLoss()
        self.transform2 = [
            # transforms.CenterCrop(256),
            # Crop(28),
            # transforms.Resize(256),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(device)
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epochs = checkpoint['epoch']
        self.loss = checkpoint['loss']
        print(f'\nmodel loaded from path : {path}')

    def save(self, epoch, model, optimizer, loss, path):
        save_path = root_dir + '/models/'
        if not os.path.isdir(save_path):
            os.makedirs(save_path)
        path = save_path + path
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
            }, path)
        print(f'\nsaved model to path : {path}')

    def test(self, testloader, progress, type='validation'):
        print(f'Starting testing on {type} dataset')
        print('-------------------------------')
        correct, total = 0, 0
        with torch.no_grad():
            for i, data in enumerate(testloader, 0):
                inputs, targets = data
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
                progress.update(self.batch_size)
        print(f'\nAccuracy on {type} dataset : {correct} / {total} = {100.0 * correct / total}')
        print('--------------------------------')
        return 100.0 * correct / total

    def train(self, trainloader, epoch, progress):
        print(f'\nStarting epoch {epoch+1}')
        current_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, targets = data
            inputs, targets = inputs.to(device), targets.to(device)
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, targets)
            loss.backward()
            self.optimizer.step()
            current_loss += loss.item()
            progress.update(self.batch_size)
        print(f'\nloss at epoch {epoch + 1} : {current_loss}')
        return current_loss

    def train_validate(self, name, mnist=False, batch_size=64,
                       validation_split=0, save_name=None):
        self.batch_size = batch_size
        if save_name is None:
            save_name = name
        progress = None
        np.random.seed(42)
        epochs_plot = []
        accuracy_plot = []
        loss_plot = []
        for epoch in range(0, self.num_epochs):
            if mnist:
                self.transform1 = [
                    transforms.RandomRotation(degrees=10),
                ]
                train_data = torchvision.datasets.MNIST(
                    'mnist', download=True,
                    transform=transforms.Compose(self.transform1 + self.transform2))
                trainloader = torch.utils.data.DataLoader(
                    train_data, batch_size=self.batch_size, num_workers=2)
                dataset_size = len(trainloader.dataset)
            else:
                data = get_data_set(name)
                dataset_size = len(data)
                ids = list(range(dataset_size))
                split = int(np.floor(validation_split * dataset_size))
                np.random.shuffle(ids)
                train_ids, val_ids = ids[split:], ids[:split]
                train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
                test_subsampler = torch.utils.data.SubsetRandomSampler(val_ids)
                trainloader = torch.utils.data.DataLoader(
                    data, batch_size=batch_size, sampler=train_subsampler,
                    num_workers=2)
                testloader = torch.utils.data.DataLoader(
                    data, batch_size=batch_size, sampler=test_subsampler,
                    num_workers=2)
            if progress is None:
                progress = tqdm.tqdm(
                    total=(2 + validation_split) * dataset_size * self.num_epochs,
                    position=0, leave=True)
            current_loss = self.train(trainloader, epoch, progress)
            accuracy = self.test(trainloader, progress, 'train')
            if validation_split:
                self.test(testloader, progress, 'validation')
            epochs_plot.append(epoch)
            accuracy_plot.append(accuracy)
            loss_plot.append(current_loss)
            self.save(epoch, self.model, self.optimizer, current_loss,
                      f'{save_name}-{epoch}.pth')
        return epochs_plot, accuracy_plot, loss_plot

    def test_mnist(self):
        test_data = torchvision.datasets.MNIST(
            'mnist', False, download=True,
            transform=transforms.Compose(self.transform2))
        testloader = torch.utils.data.DataLoader(
            test_data, batch_size=self.batch_size, num_workers=2)
        progress = tqdm.tqdm(total=len(testloader.dataset), position=0, leave=True)
        self.test(testloader, progress, 'test')
class Detector(object):
    def __init__(self, cfg):
        self.device = cfg["device"]
        self.model = Models().get_model(cfg["network"])  # cfg.network
        self.model.to(self.device)
        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(params, lr=0.00001)
        self.lr_scheduler = OneCycleLR(self.optimizer,
                                       max_lr=1e-4,
                                       epochs=cfg["nepochs"],
                                       steps_per_epoch=169,   # len(dataloader)/accumulations
                                       div_factor=25,         # for initial lr, default: 25
                                       final_div_factor=1e3)  # for final lr, default: 1e4

    def fit(self, data_loader, accumulation_steps=4, wandb=None):
        self.model.train()
        # metric_logger = utils.MetricLogger(delimiter="  ")
        # metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        avg_loss = MetricLogger('scalar')
        total_loss = MetricLogger('dict')
        lr_log = MetricLogger('list')
        self.optimizer.zero_grad()
        device = self.device
        for i, (images, targets) in enumerate(data_loader):
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = self.model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.detach().item()

            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            losses.backward()
            if (i + 1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                    lr_log.update(self.lr_scheduler.get_last_lr())

            print(f"\rTrain iteration: [{i+1}/{len(data_loader)}]", end="")
            avg_loss.update(loss_value)
            total_loss.update(loss_dict)
            # metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
            # metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        print()
        return {"train_avg_loss": avg_loss.avg}, total_loss.avg

    def mixup_fit(self, data_loader, accumulation_steps=4, wandb=None):
        self.model.train()
        torch.cuda.empty_cache()
        avg_loss = MetricLogger('scalar')
        total_loss = MetricLogger('dict')
        self.optimizer.zero_grad()
        device = self.device
        for i, (batch1, batch2) in enumerate(data_loader):
            images1, targets1 = batch1
            images2, targets2 = batch2
            images = mixup_images(images1, images2)
            targets = merge_targets(targets1, targets2)
            del images1, images2, targets1, targets2, batch1, batch2

            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = self.model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.detach().item()

            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            losses.backward()
            if (i + 1) % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

            print(f"Train iteration: [{i+1}/{674}]\r", end="")
            avg_loss.update(loss_value)
            total_loss.update(loss_dict)
        print()
        return {"train_avg_loss": avg_loss.avg}, total_loss.avg

    def evaluate(self, val_dataloader):
        device = self.device
        torch.cuda.empty_cache()
        self.model.eval()
        mAp_logger = MetricLogger('list')
        with torch.no_grad():
            for j, batch in enumerate(val_dataloader):
                print(f"\rValidation: [{j+1}/{len(val_dataloader)}]", end="")
                images, targets = batch
                del batch
                images = [img.to(device) for img in images]
                predictions = self.model(images)
                for i, pred in enumerate(predictions):
                    probas = pred["scores"].detach().cpu().numpy()
                    mask = probas > 0.6
                    preds = pred["boxes"].detach().cpu().numpy()[mask]
                    gts = targets[i]["boxes"].detach().cpu().numpy()
                    score, scores = map_score(gts, preds,
                                              thresholds=[.5, .55, .6, .65, .7, .75])
                    mAp_logger.update(scores)
        print()
        return {"validation_mAP_score": mAp_logger.avg}

    def get_checkpoint(self):
        self.model.eval()
        model_state = self.model.state_dict()
        optimizer_state = self.optimizer.state_dict()
        checkpoint = {'model_state_dict': model_state,
                      'optimizer_state_dict': optimizer_state}
        # if self.lr_scheduler:
        #     scheduler_state = self.lr_scheduler.state_dict()
        #     checkpoint['lr_scheduler_state_dict'] = scheduler_state
        return checkpoint

    def load_checkpoint(self, checkpoint):
        self.model.eval()
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
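# A possible save/restore round trip using the two methods above; `det` is a
# Detector built as in this class, and the file name is an assumption:
ckpt = det.get_checkpoint()
torch.save(ckpt, "detector_ckpt.pt")
# ...later, in a fresh process with a newly constructed Detector:
det.load_checkpoint(torch.load("detector_ckpt.pt", map_location="cpu"))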
def main():
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_classifier_and_diffusion(
        **args_to_dict(args, classifier_and_diffusion_defaults().keys()))
    model.to(dist_util.dev())
    if args.noised:
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion)

    resume_step = 0
    if args.resume_checkpoint:
        resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
        if dist.get_rank() == 0:
            logger.log(
                f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step")
            model.load_state_dict(
                dist_util.load_state_dict(args.resume_checkpoint,
                                          map_location=dist_util.dev()))

    # Needed for creating correct EMAs and fp16 parameters.
    dist_util.sync_params(model.parameters())

    mp_trainer = MixedPrecisionTrainer(model=model,
                                       use_fp16=args.classifier_use_fp16,
                                       initial_lg_loss_scale=16.0)

    model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=True,
        random_crop=True,
    )
    if args.val_data_dir:
        val_data = load_data(
            data_dir=args.val_data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=True,
        )
    else:
        val_data = None

    logger.log("creating optimizer...")
    opt = AdamW(mp_trainer.master_params, lr=args.lr,
                weight_decay=args.weight_decay)
    if args.resume_checkpoint:
        opt_checkpoint = bf.join(bf.dirname(args.resume_checkpoint),
                                 f"opt{resume_step:06}.pt")
        logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
        opt.load_state_dict(
            dist_util.load_state_dict(opt_checkpoint,
                                      map_location=dist_util.dev()))

    logger.log("training classifier model...")

    def forward_backward_log(data_loader, prefix="train"):
        batch, extra = next(data_loader)
        labels = extra["y"].to(dist_util.dev())
        batch = batch.to(dist_util.dev())
        # Noisy images
        if args.noised:
            t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
            batch = diffusion.q_sample(batch, t)
        else:
            t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())

        for i, (sub_batch, sub_labels, sub_t) in enumerate(
                split_microbatches(args.microbatch, batch, labels, t)):
            logits = model(sub_batch, timesteps=sub_t)
            loss = F.cross_entropy(logits, sub_labels, reduction="none")

            losses = {}
            losses[f"{prefix}_loss"] = loss.detach()
            losses[f"{prefix}_acc@1"] = compute_top_k(logits, sub_labels,
                                                      k=1, reduction="none")
            losses[f"{prefix}_acc@5"] = compute_top_k(logits, sub_labels,
                                                      k=5, reduction="none")
            log_loss_dict(diffusion, sub_t, losses)
            del losses
            loss = loss.mean()
            if loss.requires_grad:
                if i == 0:
                    mp_trainer.zero_grad()
                mp_trainer.backward(loss * len(sub_batch) / len(batch))

    for step in range(args.iterations - resume_step):
        logger.logkv("step", step + resume_step)
        logger.logkv(
            "samples",
            (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
        )
        if args.anneal_lr:
            set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations)
        forward_backward_log(data)
        mp_trainer.optimize(opt)
        if val_data is not None and not step % args.eval_interval:
            with th.no_grad():
                with model.no_sync():
                    model.eval()
                    forward_backward_log(val_data, prefix="val")
                    model.train()
        if not step % args.log_interval:
            logger.dumpkvs()
        if (step
                and dist.get_rank() == 0
                and not (step + resume_step) % args.save_interval):
            logger.log("saving model...")
            save_model(mp_trainer, opt, step + resume_step)

    if dist.get_rank() == 0:
        logger.log("saving model...")
        save_model(mp_trainer, opt, step + resume_step)
    dist.barrier()
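# set_annealed_lr is not defined in this snippet; judging from _anneal_lr in
# the TrainLoop class earlier in this collection, it presumably applies the
# same linear decay. A sketch under that assumption:
def set_annealed_lr(opt, base_lr, frac_done):
    # Linearly anneal from base_lr down to 0 over the run.
    lr = base_lr * (1 - frac_done)
    for param_group in opt.param_groups:
        param_group["lr"] = lr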
def main() -> None:
    global best_loss
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    start_epoch = 0
    vcf_reader = VCFReader(args.train_data, args.classification_map,
                           args.chromosome, args.class_hierarchy)
    vcf_writer = vcf_reader.get_vcf_writer()
    train_dataset, validation_dataset = vcf_reader.get_datasets(args.validation_split)
    train_sampler = BatchByLabelRandomSampler(args.batch_size, train_dataset.labels)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler)

    if args.validation_split != 0:
        validation_sampler = BatchByLabelRandomSampler(args.batch_size,
                                                       validation_dataset.labels)
        validation_loader = DataLoader(validation_dataset,
                                       batch_sampler=validation_sampler)

    kwargs = {
        'total_size': vcf_reader.positions.shape[0],
        'window_size': args.window_size,
        'num_layers': args.layers,
        'num_classes': len(vcf_reader.label_encoder.classes_),
        'num_super_classes': len(vcf_reader.super_label_encoder.classes_)
    }
    model = WindowedMLP(**kwargs)
    model.to(get_device(args))
    optimizer = AdamW(model.parameters(), lr=args.learning_rate)

    if args.resume_path is not None:
        if os.path.isfile(args.resume_path):
            print("=> loading checkpoint '{}'".format(args.resume_path))
            checkpoint = torch.load(args.resume_path)
            if kwargs != checkpoint['model_kwargs']:
                raise ValueError(
                    "The checkpoint's kwargs don't match the ones used to initialize the model")
            if vcf_reader.snps.shape[0] != checkpoint['vcf_writer'].snps.shape[0]:
                raise ValueError(
                    'The data on which the checkpoint was trained had a different number of snp positions')
            start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume_path, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume_path))

    if args.validate:
        validate(validation_loader, model,
                 nn.functional.binary_cross_entropy_with_logits,
                 len(vcf_reader.label_encoder.classes_),
                 len(vcf_reader.super_label_encoder.classes_),
                 vcf_reader.maf, args)
        return

    for epoch in range(start_epoch, args.epochs + start_epoch):
        loss = train(train_loader, model,
                     nn.functional.binary_cross_entropy_with_logits, optimizer,
                     len(vcf_reader.label_encoder.classes_),
                     len(vcf_reader.super_label_encoder.classes_),
                     vcf_reader.maf, epoch, args)
        if epoch % args.save_freq == 0 or epoch == args.epochs + start_epoch - 1:
            if args.validation_split != 0:
                validation_loss = validate(
                    validation_loader, model,
                    nn.functional.binary_cross_entropy_with_logits,
                    len(vcf_reader.label_encoder.classes_),
                    len(vcf_reader.super_label_encoder.classes_),
                    vcf_reader.maf, args)
                is_best = validation_loss < best_loss
                best_loss = min(validation_loss, best_loss)
            else:
                is_best = loss < best_loss
                best_loss = min(loss, best_loss)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'model_kwargs': kwargs,
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                    'vcf_writer': vcf_writer,
                    'label_encoder': vcf_reader.label_encoder,
                    'super_label_encoder': vcf_reader.super_label_encoder,
                    'maf': vcf_reader.maf
                }, is_best, args.chromosome, args.model_name, args.model_dir)
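# save_checkpoint is not defined in this snippet; a common shape for it,
# given the is_best flag used above (the file naming is an assumption):
import os
import shutil
import torch

def save_checkpoint(state, is_best, chromosome, model_name, model_dir):
    path = os.path.join(model_dir, f"{model_name}.chr{chromosome}.pt")
    torch.save(state, path)
    if is_best:
        # keep a separate copy of the best-scoring checkpoint
        shutil.copyfile(path,
                        os.path.join(model_dir, f"{model_name}.chr{chromosome}.best.pt"))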
def train(args, train_dataset, model, tokenizer, writer):
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)
    train_total = (len(train_dataloader) // args.gradient_accumulation_steps
                   * args.num_train_epochs)

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps,
        num_training_steps=train_total)

    if os.path.isfile(os.path.join(args.pretrain_model_path, "optimizer.pt")) \
            and os.path.isfile(os.path.join(args.pretrain_model_path, "scheduler.pt")):
        optimizer.load_state_dict(
            torch.load(os.path.join(args.pretrain_model_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.pretrain_model_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.fp16_opt_level)

    print("***** Running training *****")
    global_step = 0
    steps_trained_in_current_epoch = 0
    if os.path.exists(args.pretrain_model_path) and "checkpoint" in args.pretrain_model_path:
        global_step = int(args.pretrain_model_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader)
                                         // args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

    train_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    for _ in range(int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        for step, batch in enumerate(train_dataloader):
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "start_positions": batch[3],
                "end_positions": batch[4]
            }
            inputs["token_type_ids"] = (batch[2] if args.model_type in ["bert"] else None)
            outputs = model(**inputs)
            loss = outputs[0]
            writer.add_scalar("Train_loss", loss.item(), step)

            if args.n_gpu > 1:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            pbar(step, {'loss': loss.item()})
            train_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                                   args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                # optimizer.step() must precede scheduler.step() (PyTorch >= 1.1),
                # otherwise the first learning-rate value of the schedule is skipped
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    if args.local_rank == -1:
                        evaluate(args, model, tokenizer, writer)

                if (args.local_rank in [-1, 0] and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    output_dir = os.path.join(args.output_dir,
                                              "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (model.module if hasattr(model, "module") else model)
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    tokenizer.save_vocabulary(output_dir)
                    print("Saving model checkpoint to %s" % output_dir)
                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
        print(" ")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
    return global_step, train_loss / global_step
def run_pretraining(args): if args.parallel and args.local_rank == -1: run_parallel_pretraining(args) return if args.local_rank == -1: if args.cpu: print("CPU!!!") device = torch.device("cpu") else: device = torch.device("cuda") num_workers = 1 worker_index = 0 else: torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend="nccl") device = torch.device("cuda", args.local_rank) num_workers = torch.distributed.get_world_size() worker_index = torch.distributed.get_rank() if args.local_rank not in (-1, 0): logging.getLogger().setLevel(logging.WARN) logger.info( "Starting pretraining with the following arguments: %s", json.dumps(vars(args), indent=2, sort_keys=True) ) # if args.multilingual: # dataset_dir_list = args.dataset_dir.split(",") # dataset_list = [MedMentionsPretrainingDataset(d) for d in dataset_dir_list] # else: dataset_list = [MedMentionsPretrainingDataset(args.dataset_dir)] bert_config = AutoConfig.from_pretrained(args.bert_model_name) dataset_size = sum([len(d) for d in dataset_list]) num_train_steps_per_epoch = math.ceil(dataset_size / args.batch_size) num_train_steps = math.ceil(dataset_size / args.batch_size * args.num_epochs) print("The Number of Training Steps is: ", num_train_steps) train_batch_size = int(args.batch_size / args.gradient_accumulation_steps / num_workers) entity_vocab = dataset_list[0].entity_vocab config = LukeConfig( entity_vocab_size=entity_vocab.size, bert_model_name=args.bert_model_name, entity_emb_size=args.entity_emb_size, **bert_config.to_dict(), ) model = LukePretrainingModel(config) global_step = args.global_step batch_generator_args = dict( batch_size=train_batch_size, masked_lm_prob=args.masked_lm_prob, masked_entity_prob=args.masked_entity_prob, whole_word_masking=args.whole_word_masking, unmasked_word_prob=args.unmasked_word_prob, random_word_prob=args.random_word_prob, unmasked_entity_prob=args.unmasked_entity_prob, random_entity_prob=args.random_entity_prob, mask_words_in_entity_span=args.mask_words_in_entity_span, num_workers=num_workers, worker_index=worker_index, skip=global_step * args.batch_size, ) # if args.multilingual: # data_size_list = [len(d) for d in dataset_list] # batch_generator = MultilingualBatchGenerator( # dataset_dir_list, data_size_list, args.sampling_smoothing, **batch_generator_args, # ) # else: batch_generator = LukePretrainingBatchGenerator(args.dataset_dir, **batch_generator_args) logger.info("Model configuration: %s", config) if args.fix_bert_weights: for param in model.parameters(): param.requires_grad = False for param in model.entity_embeddings.parameters(): param.requires_grad = True for param in model.entity_predictions.parameters(): param.requires_grad = True model.to(device) param_optimizer = list(model.named_parameters()) no_decay = ["bias", "LayerNorm.weight"] optimizer_parameters = [ { "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay, }, {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, ] if args.original_adam: optimizer = AdamW( optimizer_parameters, lr=args.learning_rate, betas=(args.adam_b1, args.adam_b2), eps=args.adam_eps, ) else: optimizer = LukeAdamW( optimizer_parameters, lr=args.learning_rate, betas=(args.adam_b1, args.adam_b2), eps=args.adam_eps, grad_avg_device=torch.device("cpu") if args.grad_avg_on_cpu else device, ) if args.fp16: from apex import amp if args.fp16_opt_level == "O2": model, optimizer = amp.initialize( model, optimizer, 
opt_level=args.fp16_opt_level, master_weights=args.fp16_master_weights, min_loss_scale=args.fp16_min_loss_scale, max_loss_scale=args.fp16_max_loss_scale, ) else: model, optimizer = amp.initialize( model, optimizer, opt_level=args.fp16_opt_level, min_loss_scale=args.fp16_min_loss_scale, max_loss_scale=args.fp16_max_loss_scale, ) if args.model_file is None: bert_model = AutoModelForPreTraining.from_pretrained(args.bert_model_name) bert_state_dict = bert_model.state_dict() model.load_bert_weights(bert_state_dict) else: model_state_dict = torch.load(args.model_file, map_location="cpu") model.load_state_dict(model_state_dict, strict=False) if args.optimizer_file is not None: optimizer.load_state_dict(torch.load(args.optimizer_file, map_location="cpu")) if args.amp_file is not None: amp.load_state_dict(torch.load(args.amp_file, map_location="cpu")) if args.lr_schedule == "warmup_constant": scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps) elif args.lr_schedule == "warmup_linear": scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=num_train_steps ) print(f"Scheduler data: Warmup steps: {args.warmup_steps}; total training steps: {num_train_steps}") else: raise RuntimeError(f"Invalid scheduler: {args.lr_schedule}") if args.scheduler_file is not None: scheduler.load_state_dict(torch.load(args.scheduler_file, map_location="cpu")) if args.local_rank != -1: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.local_rank], output_device=args.local_rank, broadcast_buffers=False, find_unused_parameters=True, ) model.train() if args.local_rank == -1 or worker_index == 0: entity_vocab.save(os.path.join(args.output_dir, ENTITY_VOCAB_FILE)) metadata = dict( model_config=config.to_dict(), max_seq_length=dataset_list[0].max_seq_length, max_entity_length=dataset_list[0].max_entity_length, max_mention_length=dataset_list[0].max_mention_length, arguments=vars(args), ) with open(os.path.join(args.output_dir, "metadata.json"), "w") as metadata_file: json.dump(metadata, metadata_file, indent=2, sort_keys=True) def save_model(model, suffix): if args.local_rank != -1: model = model.module model_file = f"model_{suffix}.bin" torch.save(model.state_dict(), os.path.join(args.output_dir, model_file)) optimizer_file = f"optimizer_{suffix}.bin" torch.save(optimizer.state_dict(), os.path.join(args.output_dir, optimizer_file)) scheduler_file = f"scheduler_{suffix}.bin" torch.save(scheduler.state_dict(), os.path.join(args.output_dir, scheduler_file)) metadata = dict( global_step=global_step, model_file=model_file, optimizer_file=optimizer_file, scheduler_file=scheduler_file ) if args.fp16: amp_file = f"amp_{suffix}.bin" torch.save(amp.state_dict(), os.path.join(args.output_dir, amp_file)) metadata["amp_file"] = amp_file with open(os.path.join(args.output_dir, f"metadata_{suffix}.json"), "w") as f: json.dump(metadata, f, indent=2, sort_keys=True) if args.local_rank == -1 or worker_index == 0: summary_writer = SummaryWriter(args.log_dir) pbar = tqdm(total=num_train_steps, initial=global_step) tr_loss = 0 accumulation_count = 0 results = [] prev_error = False prev_step_time = time.time() prev_save_time = time.time() for batch in batch_generator.generate_batches(): try: batch = {k: torch.from_numpy(v).to(device) for k, v in batch.items()} result = model(**batch) loss = result["loss"] result = {k: v.to("cpu").detach().numpy() for k, v in result.items()} if args.gradient_accumulation_steps > 1: loss = loss / 
args.gradient_accumulation_steps def maybe_no_sync(): if ( hasattr(model, "no_sync") and num_workers > 1 and accumulation_count + 1 != args.gradient_accumulation_steps ): return model.no_sync() else: return contextlib.ExitStack() with maybe_no_sync(): if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() except RuntimeError: if prev_error: logger.exception("Consecutive errors have been observed. Exiting...") raise logger.exception("An unexpected error has occurred. Skipping a batch...") prev_error = True loss = None torch.cuda.empty_cache() continue accumulation_count += 1 prev_error = False tr_loss += loss.item() loss = None results.append(result) if accumulation_count == args.gradient_accumulation_steps: if args.max_grad_norm != 0.0: if args.fp16: torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() model.zero_grad() accumulation_count = 0 summary = {} summary["learning_rate"] = max(scheduler.get_last_lr()) summary["loss"] = tr_loss tr_loss = 0 current_time = time.time() summary["batch_run_time"] = current_time - prev_step_time prev_step_time = current_time for name in ("masked_lm", "masked_entity"): try: summary[name + "_loss"] = np.concatenate([r[name + "_loss"].flatten() for r in results]).mean() correct = np.concatenate([r[name + "_correct"].flatten() for r in results]).sum() total = np.concatenate([r[name + "_total"].flatten() for r in results]).sum() if total > 0: summary[name + "_acc"] = correct / total except KeyError: continue results = [] if args.local_rank == -1 or worker_index == 0: for (name, value) in summary.items(): summary_writer.add_scalar(name, value, global_step) desc = ( f"epoch: {int(global_step / num_train_steps_per_epoch)} " f'loss: {summary["loss"]:.4f} ' f'time: {datetime.datetime.now().strftime("%H:%M:%S")}' ) pbar.set_description(desc) pbar.update() global_step += 1 if args.local_rank == -1 or worker_index == 0: if global_step == num_train_steps: # save the final model save_model(model, f"epoch{args.num_epochs}") time.sleep(60) elif global_step % num_train_steps_per_epoch == 0: # save the model at each epoch epoch = int(global_step / num_train_steps_per_epoch) save_model(model, f"epoch{epoch}") if args.save_interval_sec and time.time() - prev_save_time > args.save_interval_sec: save_model(model, f"step{global_step:07}") prev_save_time = time.time() if args.save_interval_steps and global_step % args.save_interval_steps == 0: save_model(model, f"step{global_step}") if global_step == num_train_steps: break if args.local_rank == -1 or worker_index == 0: summary_writer.close()
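# maybe_no_sync() above suppresses DistributedDataParallel's gradient
# all-reduce on every micro-batch except the last one of each accumulation
# window. The same pattern as a standalone helper; contextlib.nullcontext()
# is equivalent to the ExitStack no-op used above:
import contextlib

def grad_sync_context(model, is_last_micro_batch):
    if hasattr(model, "no_sync") and not is_last_micro_batch:
        return model.no_sync()   # skip the all-reduce for this backward pass
    return contextlib.nullcontext()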
def run_training(args, ls): ls.print('Training started: ' + datetime.now().strftime("%Y-%m-%d %H:%M:%S")) # Misc setup os.makedirs(args.model_dir, exist_ok=True) assert len(args.cnn_filters)%2 == 0 args.cnn_filters = list(zip(args.cnn_filters[:-1:2], args.cnn_filters[1::2])) # Load the vocabs vocabs = get_vocabs(os.path.join(args.model_dir, args.vocab_dir)) bert_tokenizer = None if args.with_bert: bert_tokenizer = BertEncoderTokenizer.from_pretrained(args.bert_path, do_lower_case=False) vocabs['bert_tokenizer'] = bert_tokenizer for name in vocabs: if name == 'bert_tokenizer': continue ls.print('Vocab %-20s size %5d coverage %.3f' % (name, vocabs[name].size, vocabs[name].coverage)) # Setup BERT encoder bert_encoder = None if args.with_bert: bert_encoder = BertEncoder.from_pretrained(args.bert_path) for p in bert_encoder.parameters(): p.requires_grad = False # Device and random setup torch.manual_seed(19940117) torch.cuda.manual_seed_all(19940117) random.seed(19940117) device = torch.device(args.device) # Create the model ls.print('Setting up the model') model = Parser(vocabs, args.word_char_dim, args.word_dim, args.pos_dim, args.ner_dim, args.concept_char_dim, args.concept_dim, args.cnn_filters, args.char2word_dim, args.char2concept_dim, args.embed_dim, args.ff_embed_dim, args.num_heads, args.dropout, args.snt_layers, args.graph_layers, args.inference_layers, args.rel_dim, device, args.pretrained_file, bert_encoder,) model = model.to(device) # Optimizer and weight decay params weight_decay_params = [] no_weight_decay_params = [] for name, param in model.named_parameters(): if name.endswith('bias') or 'layer_norm' in name: no_weight_decay_params.append(param) else: weight_decay_params.append(param) grouped_params = [{'params':weight_decay_params, 'weight_decay':1e-4}, {'params':no_weight_decay_params, 'weight_decay':0.}] optimizer = AdamW(grouped_params, 1., betas=(0.9, 0.999), eps=1e-6) # Re-load an existing model if requested used_batches = 0 batches_acm = 0 if args.resume_ckpt: ls.print('Resuming from checkpoint', args.resume_ckpt) ckpt = torch.load(args.resume_ckpt) model.load_state_dict(ckpt['model']) if ckpt.get('optimizer', {}): optimizer.load_state_dict(ckpt['optimizer']) else: ls.print('No optimizer state saved in checkpoint, using default initial optimizer') batches_acm = ckpt['batches_acm'] start_epoch = ckpt['epoch'] + 1 del ckpt else: start_epoch = 1 # don't start at 0 # Load data ls.print('Loading training data') train_data = DataLoader(vocabs, args.train_data, args.train_batch_size, for_train=True) train_data.set_unk_rate(args.unk_rate) # Train ls.print('Training') epoch, loss_avg, concept_loss_avg, arc_loss_avg, rel_loss_avg = 0, 0, 0, 0, 0 for epoch in range(start_epoch, args.epochs+1): st = time.time() for batch in train_data: model.train() batch = move_to_device(batch, model.device) concept_loss, arc_loss, rel_loss, graph_arc_loss = model(batch) loss = (concept_loss + arc_loss + rel_loss) / args.batches_per_update loss_value = loss.item() concept_loss_value = concept_loss.item() arc_loss_value = arc_loss.item() rel_loss_value = rel_loss.item() loss_avg = loss_avg * args.batches_per_update * 0.8 + 0.2 * loss_value concept_loss_avg = concept_loss_avg * 0.8 + 0.2 * concept_loss_value arc_loss_avg = arc_loss_avg * 0.8 + 0.2 * arc_loss_value rel_loss_avg = rel_loss_avg * 0.8 + 0.2 * rel_loss_value loss.backward() used_batches += 1 if not (used_batches % args.batches_per_update == -1 % args.batches_per_update): continue batches_acm += 1 
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) lr = update_lr(optimizer, args.lr_scale, args.embed_dim, batches_acm, args.warmup_steps) optimizer.step() optimizer.zero_grad() # Summary at the end of the epoch dur = time.time() - st ls.print('Epoch %4d, Batch %5d, LR %.6f, conc_loss %.3f, arc_loss %.3f, rel_loss %.3f, duration %.1f seconds' % (epoch, batches_acm, lr, concept_loss_avg, arc_loss_avg, rel_loss_avg, dur)) # Evaluate and save the data every so often if (epoch > args.skip_evals or args.resume_ckpt is not None) and epoch % args.eval_every == 0: model.eval() ls.print('Evaluating and saving the model') fname = '%s/epoch%d.pt'%(args.model_dir, epoch) optim = optimizer.state_dict() if args.save_optimizer else {} torch.save({'args':vars(args), 'model':model.state_dict(), 'batches_acm': batches_acm, 'optimizer': optim, 'epoch':epoch}, fname) try: out_fn = 'epoch%d.pt.dev_generated' % (epoch) inference = Inference.build_from_model(model, vocabs) f_score, ctr = inference.reparse_annotated_file('.', args.dev_data, args.model_dir, out_fn, print_summary=False) ls.print('Smatch F: %.3f. Wrote %d AMR graphs to %s' % (f_score, ctr, os.path.join(args.model_dir, out_fn))) except Exception: ls.print('Exception during generation') traceback.print_exc() model.train() # End time-stamp ls.print('Training finished: ' + datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
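# update_lr() is not shown in this file. Given its arguments and the base lr
# of 1.0 passed to AdamW above, it is presumably the inverse-square-root
# "Noam" schedule from "Attention Is All You Need"; a sketch under that
# assumption, not necessarily the project's exact formula:
def update_lr(optimizer, lr_scale, embed_dim, step, warmup_steps):
    lr = lr_scale * embed_dim ** -0.5 * min(step ** -0.5,
                                            step * warmup_steps ** -1.5)
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr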
class Trainer(): def __init__(self, alphabets_, list_ngram): self.vocab = Vocab(alphabets_) self.synthesizer = SynthesizeData(vocab_path="") self.list_ngrams_train, self.list_ngrams_valid = self.train_test_split( list_ngram, test_size=0.1) print("Loaded data!!!") print("Total training samples: ", len(self.list_ngrams_train)) print("Total valid samples: ", len(self.list_ngrams_valid)) INPUT_DIM = self.vocab.__len__() OUTPUT_DIM = self.vocab.__len__() self.device = DEVICE self.num_iters = NUM_ITERS self.beamsearch = BEAM_SEARCH self.batch_size = BATCH_SIZE self.print_every = PRINT_PER_ITER self.valid_every = VALID_PER_ITER self.checkpoint = CHECKPOINT self.export_weights = EXPORT self.metrics = MAX_SAMPLE_VALID logger = LOG if logger: self.logger = Logger(logger) self.iter = 0 self.model = Seq2Seq(input_dim=INPUT_DIM, output_dim=OUTPUT_DIM, encoder_embbeded=ENC_EMB_DIM, decoder_embedded=DEC_EMB_DIM, encoder_hidden=ENC_HID_DIM, decoder_hidden=DEC_HID_DIM, encoder_dropout=ENC_DROPOUT, decoder_dropout=DEC_DROPOUT) self.optimizer = AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09) self.scheduler = OneCycleLR(self.optimizer, total_steps=self.num_iters, pct_start=PCT_START, max_lr=MAX_LR) self.criterion = LabelSmoothingLoss(len(self.vocab), padding_idx=self.vocab.pad, smoothing=0.1) self.train_gen = self.data_gen(self.list_ngrams_train, self.synthesizer, self.vocab, is_train=True) self.valid_gen = self.data_gen(self.list_ngrams_valid, self.synthesizer, self.vocab, is_train=False) self.train_losses = [] # to device self.model.to(self.device) self.criterion.to(self.device) def train_test_split(self, list_phrases, test_size=0.1): list_phrases = list_phrases train_idx = int(len(list_phrases) * (1 - test_size)) list_phrases_train = list_phrases[:train_idx] list_phrases_valid = list_phrases[train_idx:] return list_phrases_train, list_phrases_valid def data_gen(self, list_ngrams_np, synthesizer, vocab, is_train=True): dataset = AutoCorrectDataset(list_ngrams_np, transform_noise=synthesizer, vocab=vocab, maxlen=MAXLEN) shuffle = True if is_train else False gen = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, drop_last=False) return gen def step(self, batch): self.model.train() batch = self.batch_to_device(batch) src, tgt = batch['src'], batch['tgt'] src, tgt = src.transpose(1, 0), tgt.transpose( 1, 0) # batch x src_len -> src_len x batch outputs = self.model( src, tgt) # src : src_len x B, outpus : B x tgt_len x vocab # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)')) outputs = outputs.view(-1, outputs.size(2)) # flatten(0, 1) tgt_output = tgt.transpose(0, 1).reshape( -1) # flatten() # tgt: tgt_len xB , need convert to B x tgt_len loss = self.criterion(outputs, tgt_output) self.optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1) self.optimizer.step() self.scheduler.step() loss_item = loss.item() return loss_item def train(self): print("Begin training from iter: ", self.iter) total_loss = 0 total_loader_time = 0 total_gpu_time = 0 best_acc = -1 data_iter = iter(self.train_gen) for i in range(self.num_iters): self.iter += 1 start = time.time() try: batch = next(data_iter) except StopIteration: data_iter = iter(self.train_gen) batch = next(data_iter) total_loader_time += time.time() - start start = time.time() loss = self.step(batch) total_gpu_time += time.time() - start total_loss += loss self.train_losses.append((self.iter, loss)) if self.iter % self.print_every == 0: info = 'iter: 
{:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format( self.iter, total_loss / self.print_every, self.optimizer.param_groups[0]['lr'], total_loader_time, total_gpu_time) total_loss = 0 total_loader_time = 0 total_gpu_time = 0 print(info) self.logger.log(info) if self.iter % self.valid_every == 0: val_loss, preds, actuals, inp_sents = self.validate() acc_full_seq, acc_per_char, cer = self.precision(self.metrics) info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f} - CER: {:.4f} '.format( self.iter, val_loss, acc_full_seq, acc_per_char, cer) print(info) print("--- Sentence predict ---") for pred, inp, label in zip(preds, inp_sents, actuals): infor_predict = 'Pred: {} - Inp: {} - Label: {}'.format( pred, inp, label) print(infor_predict) self.logger.log(infor_predict) self.logger.log(info) if acc_full_seq > best_acc: self.save_weights(self.export_weights) best_acc = acc_full_seq self.save_checkpoint(self.checkpoint) def validate(self): self.model.eval() total_loss = [] max_step = self.metrics / self.batch_size with torch.no_grad(): for step, batch in enumerate(self.valid_gen): batch = self.batch_to_device(batch) src, tgt = batch['src'], batch['tgt'] src, tgt = src.transpose(1, 0), tgt.transpose(1, 0) outputs = self.model(src, tgt, 0) # turn off teaching force outputs = outputs.flatten(0, 1) tgt_output = tgt.flatten() loss = self.criterion(outputs, tgt_output) total_loss.append(loss.item()) preds, actuals, inp_sents, probs = self.predict(5) del outputs del loss if step > max_step: break total_loss = np.mean(total_loss) self.model.train() return total_loss, preds[:3], actuals[:3], inp_sents[:3] def predict(self, sample=None): pred_sents = [] actual_sents = [] inp_sents = [] for batch in self.valid_gen: batch = self.batch_to_device(batch) if self.beamsearch: translated_sentence = batch_translate_beam_search( batch['src'], self.model) prob = None else: translated_sentence, prob = translate(batch['src'], self.model) pred_sent = self.vocab.batch_decode(translated_sentence.tolist()) actual_sent = self.vocab.batch_decode(batch['tgt'].tolist()) inp_sent = self.vocab.batch_decode(batch['src'].tolist()) pred_sents.extend(pred_sent) actual_sents.extend(actual_sent) inp_sents.extend(inp_sent) if sample is not None and len(pred_sents) > sample: break return pred_sents, actual_sents, inp_sents, prob def precision(self, sample=None): pred_sents, actual_sents, _, _ = self.predict(sample=sample) acc_full_seq = compute_accuracy(actual_sents, pred_sents, mode='full_sequence') acc_per_char = compute_accuracy(actual_sents, pred_sents, mode='per_char') cer = compute_accuracy(actual_sents, pred_sents, mode='CER') return acc_full_seq, acc_per_char, cer def visualize_prediction(self, sample=16, errorcase=False, fontname='serif', fontsize=16): pred_sents, actual_sents, img_files, probs = self.predict(sample) if errorcase: wrongs = [] for i in range(len(img_files)): if pred_sents[i] != actual_sents[i]: wrongs.append(i) pred_sents = [pred_sents[i] for i in wrongs] actual_sents = [actual_sents[i] for i in wrongs] img_files = [img_files[i] for i in wrongs] probs = [probs[i] for i in wrongs] img_files = img_files[:sample] fontdict = {'family': fontname, 'size': fontsize} def visualize_dataset(self, sample=16, fontname='serif'): n = 0 for batch in self.train_gen: for i in range(self.batch_size): img = batch['img'][i].numpy().transpose(1, 2, 0) sent = self.vocab.decode(batch['tgt_input'].T[i].tolist()) n += 1 if n >= sample: return def load_checkpoint(self, 
filename): checkpoint = torch.load(filename) self.optimizer.load_state_dict(checkpoint['optimizer']) self.scheduler.load_state_dict(checkpoint['scheduler']) self.model.load_state_dict(checkpoint['state_dict']) self.iter = checkpoint['iter'] self.train_losses = checkpoint['train_losses'] def save_checkpoint(self, filename): state = { 'iter': self.iter, 'state_dict': self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'train_losses': self.train_losses, 'scheduler': self.scheduler.state_dict() } path, _ = os.path.split(filename) os.makedirs(path, exist_ok=True) torch.save(state, filename) def load_weights(self, filename): state_dict = torch.load(filename, map_location=torch.device(self.device)) for name, param in self.model.named_parameters(): if name not in state_dict: print('{} not found'.format(name)) elif state_dict[name].shape != param.shape: print('{} mismatching shape, required {} but found {}'.format( name, param.shape, state_dict[name].shape)) del state_dict[name] self.model.load_state_dict(state_dict, strict=False) def save_weights(self, filename): path, _ = os.path.split(filename) os.makedirs(path, exist_ok=True) torch.save(self.model.state_dict(), filename) def batch_to_device(self, batch): src = batch['src'].to(self.device, non_blocking=True) tgt = batch['tgt'].to(self.device, non_blocking=True) batch = {'src': src, 'tgt': tgt} return batch
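# Both this Trainer and the OCR Trainer below pair AdamW with OneCycleLR and
# call scheduler.step() once per optimizer step, so total_steps must equal the
# planned number of iterations. A minimal usage sketch with toy values, not
# the projects' real hyperparameters:
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

params = [torch.nn.Parameter(torch.zeros(1))]
opt = AdamW(params, lr=1e-4)
sched = OneCycleLR(opt, max_lr=3e-4, total_steps=100, pct_start=0.1)
for _ in range(100):
    opt.step()
    sched.step()   # stepping more than total_steps times raises a ValueError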
def train_loop(new_data, old_data, stats=None): ## prep dataloaders ## X, y = new_data['X'], new_data['y'] dataloader_new = DataLoader(list(zip(X, y)), batch_size=1, shuffle=True) dataloader_old = DataLoader(old_data, batch_size=1, shuffle=True) del X, y ## optimizer and scheduler ## # calculate total steps opts.gradient_accumulation_steps, opts.num_train_epochs = 64, 1 t_total = len(dataloader_old ) // opts.gradient_accumulation_steps * opts.num_train_epochs ## set up optimizers and schedulers ## with torch.no_grad(): fast_group = flatten([[p[act_tok], p[start_tok], p[p1_tok], p[p2_tok]] for n, p in model.named_parameters() if n == 'transformer.wte.weight' ]) #['transformer.wte.weight'] freeze_group = [ p[:start_tok] for n, p in model.named_parameters() if n == 'transformer.wte.weight' ] #['transformer.wte.weight'] slow_group = [ p for n, p in model.named_parameters() if n == 'transformer.wpe.weight' ] normal_group = [ p for n, p in model.named_parameters() if n not in ('transformer.wte.weight', 'transformer.wpe.weight') ] # different learn rates for different param groups optimizer_grouped_parameters = [{ "params": fast_group, 'lr': 5e-4 }, { "params": freeze_group, 'lr': 1e-8 }, { "params": slow_group, 'lr': 1e-6 }, { "params": normal_group, 'lr': opts.lr }] optimizer = AdamW(optimizer_grouped_parameters, lr=opts.lr, eps=opts.eps) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=opts.warmup_steps, num_training_steps=t_total) # loading optimizer settings if (opts.model_name_or_path and os.path.isfile( os.path.join(opts.model_name_or_path, "train_optimizer.pt")) and os.path.isfile( os.path.join(opts.model_name_or_path, "train_scheduler.pt"))): # load optimizer and scheduler states optimizer.load_state_dict( torch.load( os.path.join(opts.model_name_or_path, "train_optimizer.pt"))) scheduler.load_state_dict( torch.load( os.path.join(opts.model_name_or_path, "train_scheduler.pt"))) # track stats if stats is not None: global_step = max(stats.keys()) epochs_trained = global_step // (len(dataloader_old) // opts.gradient_accumulation_steps) steps_trained_in_current_epoch = global_step % ( len(dataloader_old) // opts.gradient_accumulation_steps) print("Resuming Training ... ") else: stats = {} global_step, epochs_trained, steps_trained_in_current_epoch = 0, 0, 0 tr_loss, logging_loss = 0.0, 0.0 tr_loss_old, logging_loss_old = 0.0, 0.0 model.zero_grad() print("Re-sizing model ... 
") model.resize_token_embeddings(len(tokenizer)) # training mode model.train() data_iter_new = iter(dataloader_new) data_iter_old = iter(dataloader_old) for epoch in range(epochs_trained, opts.num_train_epochs): for step in range(len(dataloader_old)): if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 batch = data_iter_old.next() continue ### new data step ### try: batch = data_iter_new.next() except: X, y = new_data['X'], new_data['y'] dataloader_new = DataLoader(list(zip(X, y)), batch_size=1, shuffle=True) del X, y data_iter_new = iter(dataloader_new) batch = data_iter_new.next() new_loss = fit_on_batch(batch) del batch tr_loss += new_loss.item() ## old data step ### try: batch = data_iter_old.next() except: data_iter_old = iter(dataloader_old) batch = data_iter_old.next() old_loss = fit_on_batch(batch) del batch tr_loss_old += old_loss.item() # gradient accumulation if (step + 1) % opts.gradient_accumulation_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), opts.max_grad_norm) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 # reporting if global_step % opts.logging_steps == 0: stats[global_step] = { 'persona_loss': (tr_loss - logging_loss) / opts.logging_steps, 'ctrl_loss': (tr_loss_old - logging_loss_old) / opts.logging_steps, 'train_lr': scheduler.get_last_lr()[-1] } logging_loss = tr_loss logging_loss_old = tr_loss_old print( 'Epoch: %d | Iter: [%d/%d] | new_loss: %.3f | old_loss: %.3f | lr: %s ' % (epoch, step, len(dataloader_old), stats[global_step]['persona_loss'], stats[global_step]['ctrl_loss'], str(stats[global_step]['train_lr']))) if global_step % opts.save_steps == 0: print("Saving stuff ... ") checkpoint(model, tokenizer, optimizer, scheduler, stats, title="train_") plot_losses(stats, title='persona_loss') plot_losses(stats, title='ctrl_loss') plot_losses(stats, title='train_lr') print("Done.") return stats
def optim_config(args: dict, model): # Prepare optimizer and schedule (linear warmup and decay) no_decay = ["bias", "LayerNorm.weight"] bert_param_optimizer = list(model.bert.named_parameters()) crf_param_optimizer = list(model.crf.named_parameters()) linear_param_optimizer = list(model.classifier.named_parameters()) optimizer_grouped_parameters = [{ 'params': [ p for n, p in bert_param_optimizer if not any(nd in n for nd in no_decay) ], 'weight_decay': args['weight_decay'], 'lr': args['learning_rate'] }, { 'params': [ p for n, p in bert_param_optimizer if any(nd in n for nd in no_decay) ], 'weight_decay': 0.0, 'lr': args['learning_rate'] }, { 'params': [ p for n, p in crf_param_optimizer if not any(nd in n for nd in no_decay) ], 'weight_decay': args['weight_decay'], 'lr': args['crf_learning_rate'] }, { 'params': [p for n, p in crf_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args['crf_learning_rate'] }, { 'params': [ p for n, p in linear_param_optimizer if not any(nd in n for nd in no_decay) ], 'weight_decay': args['weight_decay'], 'lr': args['crf_learning_rate'] }, { 'params': [ p for n, p in linear_param_optimizer if any(nd in n for nd in no_decay) ], 'weight_decay': 0.0, 'lr': args['crf_learning_rate'] }] optimizer = AdamW(optimizer_grouped_parameters, lr=args['learning_rate'], eps=args['adam_epsilon']) if os.path.isfile(os.path.join(args['model_name_or_path'], "optimizer.pt")): # Load in optimizer states optimizer.load_state_dict( torch.load(os.path.join(args['model_name_or_path'], "optimizer.pt"))) return optimizer
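# Each dict above becomes one entry in optimizer.param_groups, which is how
# the BERT body, the CRF and the linear head get separate learning rates and
# weight decay. A quick structural check with toy values (not this project's
# configuration):
import torch
from torch.optim import AdamW

toy = torch.nn.Linear(4, 2)
opt = AdamW([
    {'params': [toy.weight], 'lr': 3e-5, 'weight_decay': 0.01},
    {'params': [toy.bias], 'lr': 1e-3, 'weight_decay': 0.0},
])
assert [g['lr'] for g in opt.param_groups] == [3e-5, 1e-3]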
class Seq2seqKpGen(object): """High level model that handles intializing the underlying network architecture, saving, updating examples, and predicting examples. """ # -------------------------------------------------------------------------- # Initialization # -------------------------------------------------------------------------- def __init__(self, args, word_dict, state_dict=None): # Book-keeping. self.args = args self.word_dict = word_dict self.args.vocab_size = len(word_dict) self.updates = 0 self.network = Sequence2Sequence(self.args, self.word_dict) if state_dict: self.network.load_state_dict(state_dict) def activate_fp16(self): if not hasattr(self, 'optimizer'): self.network.half() # for testing only return try: global amp from apex import amp except ImportError: raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") # https://github.com/NVIDIA/apex/issues/227 assert self.optimizer is not None self.network, self.optimizer = amp.initialize(self.network, self.optimizer, opt_level=self.args.fp16_opt_level) def init_optimizer(self, optim_state=None, sched_state=None): def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1): def lr_lambda(current_step: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1.0, num_warmup_steps)) return 1.0 return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [p for n, p in self.network.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": self.args.weight_decay, }, {"params": [p for n, p in self.network.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, ] self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate) self.scheduler = get_constant_schedule_with_warmup(self.optimizer, self.args.warmup_steps) if optim_state: self.optimizer.load_state_dict(optim_state) if self.args.device.type == 'cuda': for state in self.optimizer.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.to(self.args.device) if sched_state: self.scheduler.load_state_dict(sched_state) # -------------------------------------------------------------------------- # Learning # -------------------------------------------------------------------------- def update(self, ex): """Forward a batch of examples; step the optimizer to update weights.""" if not self.optimizer: raise RuntimeError('No optimizer set.') # Train mode self.network.train() source_map, alignment = None, None if self.args.copy_attn: source_map = make_src_map(ex['src_map']).to(self.args.device) alignment = align(ex['alignment']).to(self.args.device) source_rep = ex['source_rep'].to(self.args.device) source_len = ex['source_len'].to(self.args.device) target_rep = ex['target_rep'].to(self.args.device) target_len = ex['target_len'].to(self.args.device) # Run forward ml_loss, loss_per_token = self.network(source=source_rep, source_len=source_len, target=target_rep, target_len=target_len, src_map=source_map, alignment=alignment) loss = ml_loss.mean() if self.args.n_gpu > 1 else ml_loss if self.args.fp16: global amp with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() clip_grad_norm_(amp.master_params(self.optimizer), self.args.grad_clipping) else: loss.backward() clip_grad_norm_(self.network.parameters(), self.args.grad_clipping) self.updates += 1 self.optimizer.step() self.scheduler.step() # Update learning rate 
schedule self.optimizer.zero_grad() loss_per_token = loss_per_token.mean() if self.args.n_gpu > 1 else loss_per_token loss_per_token = loss_per_token.item() loss_per_token = 10 if loss_per_token > 10 else loss_per_token perplexity = math.exp(loss_per_token) return { 'ml_loss': loss.item(), 'perplexity': perplexity } # -------------------------------------------------------------------------- # Prediction # -------------------------------------------------------------------------- def predict(self, ex, replace_unk=False): """Forward a batch of examples only to get predictions. Args: ex: the batch examples replace_unk: replace `unk` tokens while generating predictions src_raw: raw source (passage); required to replace `unk` term Output: predictions: #batch predicted sequences """ def convert_text_to_string(text): """ Converts a sequence of tokens (string) in a single string. """ out_string = text.replace(" ##", "").strip() return out_string self.network.eval() source_map, alignment = None, None blank, fill = None, None if self.args.copy_attn: source_map = make_src_map(ex['src_map']).to(self.args.device) alignment = align(ex['alignment']).to(self.args.device) blank, fill = collapse_copy_scores(self.word_dict, ex['src_vocab']) source_rep = ex['source_rep'].to(self.args.device) source_len = ex['source_len'].to(self.args.device) decoder_out = self.network(source=source_rep, source_len=source_len, target=None, target_len=None, src_map=source_map, alignment=alignment, max_len=self.args.max_tgt_len, tgt_dict=self.word_dict, blank=blank, fill=fill, source_vocab=ex['src_vocab']) dec_probs = torch.exp(decoder_out['dec_log_probs']) predictions, scores = tens2sen_score(decoder_out['predictions'], dec_probs, self.word_dict, ex['src_vocab']) if replace_unk: for i in range(len(predictions)): enc_dec_attn = decoder_out['attentions'][i] if self.args.model_type == 'transformer': # tgt_len x num_heads x src_len assert enc_dec_attn.dim() == 3 enc_dec_attn = enc_dec_attn.mean(1) predictions[i] = replace_unknown(predictions[i], enc_dec_attn, src_raw=ex['source'][i].tokens) for bidx in range(ex['batch_size']): for i in range(len(predictions[bidx])): if predictions[bidx][i] == constants.KP_SEP: scores[bidx][i] = constants.KP_SEP elif predictions[bidx][i] == constants.PRESENT_EOS: scores[bidx][i] = constants.PRESENT_EOS else: assert isinstance(scores[bidx][i], float) scores[bidx][i] = str(scores[bidx][i]) predictions = [' '.join(item) for item in predictions] scores = [' '.join(item) for item in scores] present_kps = [] absent_kps = [] present_kp_scores = [] absent_kp_scores = [] for bidx in range(ex['batch_size']): keyphrases = predictions[bidx].split(constants.PRESENT_EOS) kp_scores = scores[bidx].split(constants.PRESENT_EOS) pkps = (' %s ' % constants.KP_SEP).join(keyphrases[:-1]) pkp_scores = (' %s ' % constants.KP_SEP).join(kp_scores[:-1]) akps = keyphrases[-1] akp_scores = kp_scores[-1] pre_kps = [] pre_kp_scores = [] for pkp, pkp_s in zip(pkps.split(constants.KP_SEP), pkp_scores.split(constants.KP_SEP)): pkp = pkp.strip() if pkp: pre_kps.append(convert_text_to_string(pkp)) t_scores = [float(i) for i in pkp_s.strip().split()] _score = np.prod(t_scores) / len(t_scores) pre_kp_scores.append(_score) present_kps.append(pre_kps) present_kp_scores.append(pre_kp_scores) abs_kps = [] abs_kp_scores = [] for akp, akp_s in zip(akps.split(constants.KP_SEP), akp_scores.split(constants.KP_SEP)): akp = akp.strip() if akp: abs_kps.append(convert_text_to_string(akp)) t_scores = [float(i) for i in akp_s.strip().split()] _score 
= np.prod(t_scores) / len(t_scores) abs_kp_scores.append(_score) absent_kps.append(abs_kps) absent_kp_scores.append(abs_kp_scores) return { 'present_kps': present_kps, 'absent_kps': absent_kps, 'present_kp_scores': present_kp_scores, 'absent_kp_scores': absent_kp_scores } # -------------------------------------------------------------------------- # Saving and loading # -------------------------------------------------------------------------- def save(self, filename): network = self.network.module if hasattr(self.network, "module") \ else self.network state_dict = copy.copy(network.state_dict()) params = { 'state_dict': state_dict, 'word_dict': self.word_dict, 'args': self.args, } try: torch.save(params, filename) except BaseException: logger.warning('WARN: Saving failed... continuing anyway.') def checkpoint(self, filename, epoch): network = self.network.module if hasattr(self.network, "module") \ else self.network params = { 'state_dict': network.state_dict(), 'word_dict': self.word_dict, 'args': self.args, 'epoch': epoch, 'updates': self.updates, 'optim_dict': self.optimizer.state_dict(), 'sched_dict': self.scheduler.state_dict(), } try: torch.save(params, filename) except BaseException: logger.warning('WARN: Saving failed... continuing anyway.') @staticmethod def load(filename, new_args=None): logger.info('Loading model %s' % filename) saved_params = torch.load( filename, map_location=lambda storage, loc: storage ) word_dict = saved_params['word_dict'] state_dict = saved_params['state_dict'] args = saved_params['args'] if new_args: args = override_model_args(args, new_args) return Seq2seqKpGen(args, word_dict, state_dict) @staticmethod def load_checkpoint(filename): logger.info('Loading model %s' % filename) saved_params = torch.load( filename, map_location=lambda storage, loc: storage ) word_dict = saved_params['word_dict'] state_dict = saved_params['state_dict'] epoch = saved_params['epoch'] updates = saved_params['updates'] optim_dict = saved_params['optim_dict'] sched_dict = saved_params['sched_dict'] args = saved_params['args'] model = Seq2seqKpGen(args, word_dict, state_dict) model.updates = updates model.init_optimizer(optim_dict, sched_dict) return model, epoch # -------------------------------------------------------------------------- # Runtime # -------------------------------------------------------------------------- def to(self, device): self.network = self.network.to(device) def parallelize(self): self.network = torch.nn.DataParallel(self.network)
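# init_optimizer() above moves loaded optimizer state back onto the GPU:
# a state_dict deserialized with map_location (or saved from CPU) leaves the
# exp_avg / exp_avg_sq buffers on CPU. The same idiom as a small helper:
import torch

def optimizer_state_to(optimizer, device):
    # Move every tensor in the optimizer state onto the training device.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(device)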
class Trainer(): def __init__(self, config, pretrained=True, augmentor=ImgAugTransform()): self.config = config self.model, self.vocab = build_model(config) self.device = config['device'] self.num_iters = config['trainer']['iters'] self.beamsearch = config['predictor']['beamsearch'] self.data_root = config['dataset']['data_root'] self.train_annotation = config['dataset']['train_annotation'] self.valid_annotation = config['dataset']['valid_annotation'] self.train_lmdb = config['dataset']['train_lmdb'] self.valid_lmdb = config['dataset']['valid_lmdb'] self.dataset_name = config['dataset']['name'] self.batch_size = config['trainer']['batch_size'] self.print_every = config['trainer']['print_every'] self.valid_every = config['trainer']['valid_every'] self.image_aug = config['aug']['image_aug'] self.masked_language_model = config['aug']['masked_language_model'] self.metrics = config['trainer']['metrics'] self.is_padding = config['dataset']['is_padding'] self.tensorboard_dir = config['monitor']['log_dir'] if not os.path.exists(self.tensorboard_dir): os.makedirs(self.tensorboard_dir, exist_ok=True) self.writer = SummaryWriter(self.tensorboard_dir) # LOGGER self.logger = Logger(config['monitor']['log_dir']) self.logger.info(config) self.iter = 0 self.best_acc = 0 self.scheduler = None self.is_finetuning = config['trainer']['is_finetuning'] if self.is_finetuning: self.logger.info("Finetuning model ---->") if self.model.seq_modeling == 'crnn': self.optimizer = Adam(lr=0.0001, params=self.model.parameters(), betas=(0.5, 0.999)) else: self.optimizer = AdamW(lr=0.0001, params=self.model.parameters(), betas=(0.9, 0.98), eps=1e-09) else: self.optimizer = AdamW(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09) self.scheduler = OneCycleLR(self.optimizer, total_steps=self.num_iters, **config['optimizer']) if self.model.seq_modeling == 'crnn': self.criterion = torch.nn.CTCLoss(self.vocab.pad, zero_infinity=True) else: self.criterion = LabelSmoothingLoss(len(self.vocab), padding_idx=self.vocab.pad, smoothing=0.1) # Pretrained model if config['trainer']['pretrained']: self.load_weights(config['trainer']['pretrained']) self.logger.info("Loaded trained model from: {}".format( config['trainer']['pretrained'])) # Resume elif config['trainer']['resume_from']: self.load_checkpoint(config['trainer']['resume_from']) for state in self.optimizer.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.to(torch.device(self.device)) self.logger.info("Resume training from {}".format( config['trainer']['resume_from'])) # DATASET transforms = None if self.image_aug: transforms = augmentor train_lmdb_paths = [ os.path.join(self.data_root, lmdb_path) for lmdb_path in self.train_lmdb ] self.train_gen = self.data_gen( lmdb_paths=train_lmdb_paths, data_root=self.data_root, annotation=self.train_annotation, masked_language_model=self.masked_language_model, transform=transforms, is_train=True) if self.valid_annotation: self.valid_gen = self.data_gen( lmdb_paths=[os.path.join(self.data_root, self.valid_lmdb)], data_root=self.data_root, annotation=self.valid_annotation, masked_language_model=False) self.train_losses = [] self.logger.info("Number batch samples of training: %d" % len(self.train_gen)) self.logger.info("Number batch samples of valid: %d" % len(self.valid_gen)) config_savepath = os.path.join(self.tensorboard_dir, "config.yml") if not os.path.exists(config_savepath): self.logger.info("Saving config file at: %s" % config_savepath) Cfg(config).save(config_savepath) def train(self): total_loss = 0 
total_loader_time = 0 total_gpu_time = 0 latest_loss = 0 data_iter = iter(self.train_gen) for i in range(self.num_iters): self.iter += 1 start = time.time() try: batch = next(data_iter) except StopIteration: data_iter = iter(self.train_gen) batch = next(data_iter) total_loader_time += time.time() - start start = time.time() # LOSS loss = self.step(batch) total_loss += loss self.train_losses.append((self.iter, loss)) total_gpu_time += time.time() - start if self.iter % self.print_every == 0: info = 'Iter: {:06d} - Train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format( self.iter, total_loss / self.print_every, self.optimizer.param_groups[0]['lr'], total_loader_time, total_gpu_time) latest_loss = total_loss / self.print_every total_loss = 0 total_loader_time = 0 total_gpu_time = 0 self.logger.info(info) if self.valid_annotation and self.iter % self.valid_every == 0: val_time = time.time() self.logger.info("Iter: {:06d}, start validating".format( self.iter)) val_loss = self.validate() acc_full_seq, acc_per_char, wer = self.precision(self.metrics) info = 'Iter: {:06d} - Valid loss: {:.3f} - Acc full seq: {:.4f} - Acc per char: {:.4f} - WER: {:.4f} - Time: {:.4f}'.format( self.iter, val_loss, acc_full_seq, acc_per_char, wer, time.time() - val_time) self.logger.info(info) if acc_full_seq > self.best_acc: self.save_weights(self.tensorboard_dir + "/best.pt") self.best_acc = acc_full_seq self.logger.info("Iter: {:06d} - Best acc: {:.4f}".format( self.iter, self.best_acc)) filename = 'last.pt' filepath = os.path.join(self.tensorboard_dir, filename) self.logger.info("Save checkpoint %s" % filename) self.save_checkpoint(filepath) log_loss = {'train loss': latest_loss, 'val loss': val_loss} self.writer.add_scalars('Loss', log_loss, self.iter) self.writer.add_scalar('WER', wer, self.iter) def validate(self): self.model.eval() total_loss = [] with torch.no_grad(): for step, batch in enumerate(self.valid_gen): batch = self.batch_to_device(batch) img, tgt_input, tgt_output, tgt_padding_mask = batch[ 'img'], batch['tgt_input'], batch['tgt_output'], batch[ 'tgt_padding_mask'] outputs = self.model(img, tgt_input, tgt_padding_mask) if self.model.seq_modeling == 'crnn': length = batch['labels_len'] preds_size = torch.autograd.Variable( torch.IntTensor([outputs.size(0)] * self.batch_size)) loss = self.criterion(outputs, tgt_output, preds_size, length) else: outputs = outputs.flatten(0, 1) tgt_output = tgt_output.flatten() loss = self.criterion(outputs, tgt_output) total_loss.append(loss.item()) del outputs del loss total_loss = np.mean(total_loss) self.model.train() return total_loss def predict(self, sample=None): pred_sents = [] actual_sents = [] img_files = [] probs_sents = [] imgs_sents = [] for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)): batch = self.batch_to_device(batch) if self.model.seq_modeling != 'crnn': if self.beamsearch: translated_sentence = batch_translate_beam_search( batch['img'], self.model) prob = None else: translated_sentence, prob = translate( batch['img'], self.model) pred_sent = self.vocab.batch_decode( translated_sentence.tolist()) else: translated_sentence, prob = translate_crnn( batch['img'], self.model) pred_sent = self.vocab.batch_decode( translated_sentence.tolist(), crnn=True) actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist()) pred_sents.extend(pred_sent) actual_sents.extend(actual_sent) imgs_sents.extend(batch['img'])
img_files.extend(batch['filenames']) probs_sents.extend(prob) # Visualize in tensorboard if idx == 0: try: num_samples = self.config['monitor']['num_samples'] fig = plt.figure(figsize=(12, 15)) imgs_samples = imgs_sents[:num_samples] preds_samples = pred_sents[:num_samples] actuals_samples = actual_sents[:num_samples] probs_samples = probs_sents[:num_samples] for id_img in range(len(imgs_samples)): img = imgs_samples[id_img] img = img.permute(1, 2, 0) img = img.cpu().detach().numpy() ax = fig.add_subplot(num_samples, 1, id_img + 1, xticks=[], yticks=[]) plt.imshow(img) ax.set_title( "LB: {} \n Pred: {:.4f}-{}".format( actuals_samples[id_img], probs_samples[id_img], preds_samples[id_img]), color=('green' if actuals_samples[id_img] == preds_samples[id_img] else 'red'), fontdict={ 'fontsize': 18, 'fontweight': 'medium' }) self.writer.add_figure('predictions vs. actuals', fig, global_step=self.iter) except Exception as error: print(error) continue if sample != None and len(pred_sents) > sample: break return pred_sents, actual_sents, img_files, probs_sents, imgs_sents def precision(self, sample=None, measure_time=True): t1 = time.time() pred_sents, actual_sents, _, _, _ = self.predict(sample=sample) time_predict = time.time() - t1 sensitive_case = self.config['predictor']['sensitive_case'] acc_full_seq = compute_accuracy(actual_sents, pred_sents, sensitive_case, mode='full_sequence') acc_per_char = compute_accuracy(actual_sents, pred_sents, sensitive_case, mode='per_char') wer = compute_accuracy(actual_sents, pred_sents, sensitive_case, mode='wer') if measure_time: print("Time: {:.4f}".format(time_predict / len(actual_sents))) return acc_full_seq, acc_per_char, wer def visualize_prediction(self, sample=16, errorcase=False, fontname='serif', fontsize=16, save_fig=False): pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample) if errorcase: wrongs = [] for i in range(len(img_files)): if pred_sents[i] != actual_sents[i]: wrongs.append(i) pred_sents = [pred_sents[i] for i in wrongs] actual_sents = [actual_sents[i] for i in wrongs] img_files = [img_files[i] for i in wrongs] probs = [probs[i] for i in wrongs] imgs = [imgs[i] for i in wrongs] img_files = img_files[:sample] fontdict = {'family': fontname, 'size': fontsize} ncols = 5 nrows = int(math.ceil(len(img_files) / ncols)) fig, ax = plt.subplots(nrows, ncols, figsize=(12, 15)) for vis_idx in range(0, len(img_files)): row = vis_idx // ncols col = vis_idx % ncols pred_sent = pred_sents[vis_idx] actual_sent = actual_sents[vis_idx] prob = probs[vis_idx] img = imgs[vis_idx].permute(1, 2, 0).cpu().detach().numpy() ax[row, col].imshow(img) ax[row, col].set_title( "Pred: {: <2} \n Actual: {} \n prob: {:.2f}".format( pred_sent, actual_sent, prob), fontname=fontname, color='r' if pred_sent != actual_sent else 'g') ax[row, col].get_xaxis().set_ticks([]) ax[row, col].get_yaxis().set_ticks([]) plt.subplots_adjust() if save_fig: fig.savefig('vis_prediction.png') plt.show() def log_prediction(self, sample=16, csv_file='model.csv'): pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample) save_predictions(csv_file, pred_sents, actual_sents, img_files) def vis_data(self, sample=20): ncols = 5 nrows = int(math.ceil(sample / ncols)) fig, ax = plt.subplots(nrows, ncols, figsize=(12, 12)) num_plots = 0 for idx, batch in enumerate(self.train_gen): for vis_idx in range(self.batch_size): row = num_plots // ncols col = num_plots % ncols img = batch['img'][vis_idx].numpy().transpose(1, 2, 0) sent = self.vocab.decode( 
batch['tgt_input'].T[vis_idx].tolist()) ax[row, col].imshow(img) ax[row, col].set_title("Label: {: <2}".format(sent), fontsize=16, color='g') ax[row, col].get_xaxis().set_ticks([]) ax[row, col].get_yaxis().set_ticks([]) num_plots += 1 if num_plots >= sample: plt.subplots_adjust() fig.savefig('vis_dataset.png') return def load_checkpoint(self, filename): checkpoint = torch.load(filename) self.optimizer.load_state_dict(checkpoint['optimizer']) self.model.load_state_dict(checkpoint['state_dict']) self.iter = checkpoint['iter'] self.train_losses = checkpoint['train_losses'] if self.scheduler is not None: self.scheduler.load_state_dict(checkpoint['scheduler']) self.best_acc = checkpoint['best_acc'] def save_checkpoint(self, filename): state = { 'iter': self.iter, 'state_dict': self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'train_losses': self.train_losses, 'scheduler': None if self.scheduler is None else self.scheduler.state_dict(), 'best_acc': self.best_acc } path, _ = os.path.split(filename) os.makedirs(path, exist_ok=True) torch.save(state, filename) def load_weights(self, filename): state_dict = torch.load(filename, map_location=torch.device(self.device)) if self.is_checkpoint(state_dict): self.model.load_state_dict(state_dict['state_dict']) else: for name, param in self.model.named_parameters(): if name not in state_dict: print('{} not found'.format(name)) elif state_dict[name].shape != param.shape: print('{} mismatching shape, required {} but found {}'. format(name, param.shape, state_dict[name].shape)) del state_dict[name] self.model.load_state_dict(state_dict, strict=False) def save_weights(self, filename): path, _ = os.path.split(filename) os.makedirs(path, exist_ok=True) torch.save(self.model.state_dict(), filename) def is_checkpoint(self, checkpoint): try: checkpoint['state_dict'] except Exception: return False else: return True def batch_to_device(self, batch): img = batch['img'].to(self.device, non_blocking=True) tgt_input = batch['tgt_input'].to(self.device, non_blocking=True) tgt_output = batch['tgt_output'].to(self.device, non_blocking=True) tgt_padding_mask = batch['tgt_padding_mask'].to(self.device, non_blocking=True) batch = { 'img': img, 'tgt_input': tgt_input, 'tgt_output': tgt_output, 'tgt_padding_mask': tgt_padding_mask, 'filenames': batch['filenames'], 'labels_len': batch['labels_len'] } return batch def data_gen(self, lmdb_paths, data_root, annotation, masked_language_model=True, transform=None, is_train=False): datasets = [] for lmdb_path in lmdb_paths: dataset = OCRDataset( lmdb_path=lmdb_path, root_dir=data_root, annotation_path=annotation, vocab=self.vocab, transform=transform, image_height=self.config['dataset']['image_height'], image_min_width=self.config['dataset']['image_min_width'], image_max_width=self.config['dataset']['image_max_width'], separate=self.config['dataset']['separate'], batch_size=self.batch_size, is_padding=self.is_padding) datasets.append(dataset) if len(self.train_lmdb) > 1: dataset = torch.utils.data.ConcatDataset(datasets) if self.is_padding: sampler = None else: sampler = ClusterRandomSampler(dataset, self.batch_size, True) collate_fn = Collator(masked_language_model) gen = DataLoader(dataset, batch_size=self.batch_size, sampler=sampler, collate_fn=collate_fn, shuffle=(is_train and sampler is None), drop_last=self.model.seq_modeling == 'crnn', **self.config['dataloader']) return gen def step(self, batch): self.model.train() batch = self.batch_to_device(batch) img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[ 'tgt_input'],
batch['tgt_output'], batch['tgt_padding_mask'] outputs = self.model(img, tgt_input, tgt_key_padding_mask=tgt_padding_mask) # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)')) if self.model.seq_modeling == 'crnn': length = batch['labels_len'] preds_size = torch.autograd.Variable( torch.IntTensor([outputs.size(0)] * self.batch_size)) loss = self.criterion(outputs, tgt_output, preds_size, length) else: outputs = outputs.view( -1, outputs.size(2)) # flatten(0, 1) # B*S x N_class tgt_output = tgt_output.view(-1) # flatten() # B*S loss = self.criterion(outputs, tgt_output) self.optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1) self.optimizer.step() if not self.is_finetuning: self.scheduler.step() loss_item = loss.item() return loss_item def count_parameters(self, model): return sum(p.numel() for p in model.parameters() if p.requires_grad) def gen_pseudo_labels(self, outfile=None): pred_sents = [] img_files = [] probs_sents = [] for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)): batch = self.batch_to_device(batch) if self.model.seq_modeling != 'crnn': if self.beamsearch: translated_sentence = batch_translate_beam_search( batch['img'], self.model) prob = None else: translated_sentence, prob = translate( batch['img'], self.model) pred_sent = self.vocab.batch_decode( translated_sentence.tolist()) else: translated_sentence, prob = translate_crnn( batch['img'], self.model) pred_sent = self.vocab.batch_decode( translated_sentence.tolist(), crnn=True) pred_sents.extend(pred_sent) img_files.extend(batch['filenames']) probs_sents.extend(prob) assert len(pred_sents) == len(img_files) and len(img_files) == len( probs_sents) with open(outfile, 'w', encoding='utf-8') as f: for anno in zip(img_files, pred_sents, probs_sents): f.write('||||'.join([anno[0], anno[1], str(float(anno[2]))]) + '\n')
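# step() and validate() above switch between two losses with different shape
# contracts: CTCLoss wants (T, B, C) log-probs plus per-sample lengths, while
# the label-smoothing / cross-entropy path wants flattened (B*S, C) against
# (B*S,). Toy shapes only, not the project's real dimensions:
import torch

T, B, C, S = 30, 8, 100, 12
ctc = torch.nn.CTCLoss(blank=0, zero_infinity=True)
log_probs = torch.randn(T, B, C).log_softmax(2)        # (T, B, C)
targets = torch.randint(1, C, (B, S))                  # labels, no blank index
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), S, dtype=torch.long)
loss = ctc(log_probs, targets, input_lengths, target_lengths)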
def main():
    args = parseArguments()

    os.makedirs(args.modelDir, exist_ok=True)
    checkpointDir = os.path.join(args.modelDir, 'checkpoints')
    os.makedirs(checkpointDir, exist_ok=True)
    os.makedirs(args.ensembleDir, exist_ok=True)

    with EventTimer('Preparing for dataset / dataloader'):
        trainDataset = ProductDataset(os.path.join(args.dataDir, 'train'),
                                      os.path.join(args.trainImages),
                                      transform=trainingPreprocessing)
        validDataset = ProductDataset(os.path.join(args.dataDir, 'train'),
                                      os.path.join(args.validImages),
                                      transform=inferencePreprocessing)

        trainDataloader = DataLoader(trainDataset,
                                     batch_size=args.batchSize,
                                     num_workers=args.numWorkers,
                                     shuffle=True)
        validDataloader = DataLoader(validDataset,
                                     batch_size=args.batchSize,
                                     num_workers=args.numWorkers,
                                     shuffle=False)

        print(f'> Training dataset:\t{len(trainDataset)}')
        print(f'> Validation dataset:\t{len(validDataset)}')

    with EventTimer(f'Load pretrained model - {args.pretrainModel}'):
        model = models.GetPretrainedModel(args.pretrainModel,
                                          fcDims=args.fcDims + [42])
        print(model)
        # torchsummary crashes under densenet, so skip the summary.
        # torchsummary.summary(model, (3, 224, 224), device='cpu')

    with EventTimer('Train model'):
        model.cuda()

        criterion = CrossEntropyLoss()
        optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.l2)
        scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)

        history = []
        if args.retrain != 0:
            checkpoint = torch.load(
                os.path.join(checkpointDir,
                             f'checkpoint-{args.retrain:03d}.pt'))
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            history = checkpoint['history']

        def runEpoch(dataloader, train=False, name=''):
            # Guard against an empty validation dataloader.
            if len(dataloader) == 0:
                return 0, 0

            # Enable grad only while training.
            with (torch.enable_grad() if train else torch.no_grad()):
                if train:
                    model.train()
                else:
                    model.eval()

                losses = []
                for img, label, imgPath in tqdm(dataloader, desc=name, ncols=80):
                    if train:
                        optimizer.zero_grad()

                    output = model(img.cuda()).cpu()
                    loss = criterion(output, label)

                    if train:
                        loss.backward()
                        optimizer.step()

                    accu = accuracy(output.data.numpy(), label.numpy())
                    losses.append((loss.item(), accu))

            return map(np.mean, zip(*losses))

        def cleanUp():
            # Drop the least-confident 5% of training images: score every
            # image, take the 5th-percentile max-logit as the threshold, and
            # keep only samples at or above it.
            model.eval()
            with torch.no_grad():  # inference only; no graphs needed
                train_pred = np.zeros(len(trainDataloader) * args.batchSize)
                cnt = 0
                for i, (data, label, path) in enumerate(trainDataloader):
                    test_pred = model(data.cuda())
                    pred = np.max(test_pred.cpu().data.numpy(), axis=1)
                    train_pred[cnt:cnt + len(pred)] = pred
                    cnt += len(pred)

                sorted_pred = train_pred
                sorted_pred.sort()
                threshold = sorted_pred[len(sorted_pred) // 20]

                data_set = [[], []]
                for i, (data, label, path) in enumerate(trainDataloader):
                    test_pred = model(data.cuda())
                    pred = np.max(test_pred.cpu().data.numpy(), axis=1)
                    for j in range(len(pred)):
                        if pred[j] >= threshold:
                            data_set[0].append(path[j])
                            data_set[1].append(label[j])

            newDataset = ProductDataset(os.path.join(args.dataDir, 'train'),
                                        os.path.join(args.trainImages),
                                        transform=trainingPreprocessing,
                                        data=data_set)
            newDataloader = DataLoader(newDataset,
                                       batch_size=args.batchSize,
                                       num_workers=args.numWorkers,
                                       shuffle=True)
            print(f'{len(newDataloader) * args.batchSize} images remain after cleanup')
            return newDataloader

        for epoch in range(args.retrain + 1, args.epochs + 1):
            with EventTimer(verbose=False) as et:
                print(f'====== Epoch {epoch:3d} / {args.epochs:3d} ======')
                trainLoss, trainAccu = runEpoch(trainDataloader, train=True,
                                                name='training ')
                validLoss, validAccu = runEpoch(validDataloader,
                                                name='validation')
                history.append(((trainLoss, trainAccu), (validLoss, validAccu)))
                scheduler.step()
            print(f'[{et.gettime():.4f}s] Training: {trainLoss:.6f} / {trainAccu:.4f} ; '
                  f'Validation {validLoss:.6f} / {validAccu:.4f}')

            if args.cleanup and epoch % args.cleanup_epoch == 0:
                with EventTimer('Cleaning Training Set'):
                    trainDataloader = cleanUp()

            if epoch % 5 == 0:
                torch.save(
                    {
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'history': history,
                    }, os.path.join(checkpointDir, f'checkpoint-{epoch:03d}.pt'))

        # Save the model under its corresponding name.
        torch.save(model.state_dict(),
                   os.path.join(args.modelDir, 'model-weights.pt'))
        utils.pickleSave(history, os.path.join(args.modelDir, 'history.pkl'))
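

# --- Hedged sketch (assumption: accuracy() is defined elsewhere) ---
# runEpoch() above calls an accuracy(outputs, labels) helper that is not in
# this file. Given that outputs are raw class scores of shape (B, 42) and
# labels are integer class ids, a plausible implementation is mean top-1 match:
def accuracy(outputs, labels):
    """Fraction of samples whose argmax score matches the integer label."""
    return float((outputs.argmax(axis=1) == labels).mean())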
class Training:
    def __init__(self, model, device, config, name, fold_num, imsize):
        self.config = config
        self.epoch = 0
        self.base_dir = './models/'
        os.makedirs('./models', exist_ok=True)
        self.model = model
        self.best_loss = 10**5
        self.device = device
        self.name = name
        self.fold_num = fold_num
        self.imsize = imsize

        # Optimizer: no weight decay on biases and LayerNorm parameters.
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.001
        }, {
            'params': [
                p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.00
        }]
        # Pass the grouped parameters to the optimizer; the original passed
        # self.model.parameters(), which silently discarded the grouping.
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr)
        self.scheduler = config.SchedulerClass(self.optimizer,
                                               **config.scheduler_params)

        # Early stopping
        self.patience = config.patience

        # GradScaler for mixed-precision training
        self.scaler = GradScaler()

    def train_one_epoch(self, train_loader):
        self.model.train()
        showloss = Showloss()
        for step, (images, targets) in tqdm(enumerate(train_loader),
                                            total=len(train_loader)):
            self.optimizer.zero_grad()
            with autocast():
                # Stack images into a batch (default: dim=0) -> [B, C, H, W]
                images = torch.stack(images)
                images = images.to(self.device).float()
                batch_size = images.shape[0]
                boxes = [
                    target['bbox'].to(self.device).float()
                    for target in targets
                ]
                labels = [
                    target['cls'].to(self.device).float()
                    for target in targets
                ]
                img_scale = torch.tensor([
                    target['img_scale'].to(self.device).float()
                    for target in targets
                ])
                img_size = torch.tensor([
                    (self.imsize, self.imsize) for target in targets
                ]).to(self.device).float()

                # Since the effdet update, forward takes the images plus a
                # target dict as arguments.
                target_res = {}
                target_res['bbox'] = boxes
                target_res['cls'] = labels
                target_res['img_scale'] = img_scale
                target_res['img_size'] = img_size

                # Forward pass
                output = self.model(images, target_res)
                loss = output['loss']
                showloss.update(loss.detach().item(), batch_size)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

        return showloss

    def val_one_epoch(self, val_loader):
        self.model.eval()
        showloss = Showloss()
        for step, (images, targets) in tqdm(enumerate(val_loader),
                                            total=len(val_loader)):
            with torch.no_grad():
                images = torch.stack(images)
                batch_size = images.shape[0]
                images = images.to(self.device).float()
                boxes = [
                    target['bbox'].to(self.device).float()
                    for target in targets
                ]
                labels = [
                    target['cls'].to(self.device).float()
                    for target in targets
                ]
                img_scale = torch.tensor([
                    target['img_scale'].to(self.device).float()
                    for target in targets
                ])
                img_size = torch.tensor([
                    (self.imsize, self.imsize) for target in targets
                ]).to(self.device).float()

                target_res = {}
                target_res['bbox'] = boxes
                target_res['cls'] = labels
                target_res['img_scale'] = img_scale
                target_res['img_size'] = img_size

                # loss, _, _ = self.model(images, boxes, labels)
                output = self.model(images, target_res)
                loss = output['loss']
                showloss.update(loss.detach().item(), batch_size)

        return showloss

    def save(self, path):
        # Save the model and training state.
        self.model.eval()
        torch.save(
            {
                'model_state_dict': self.model.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'loss': self.best_loss,  # best validation loss so far
                'epoch': self.epoch,
            }, path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # save() stores the best validation loss under 'loss', not 'best_loss'.
        self.best_loss = checkpoint['loss']
        self.epoch = checkpoint['epoch'] + 1

    def fit(self, train_loader, val_loader):
        early_stopping = EarlyStopping(self.patience)
        for epoch in range(self.config.n_epochs):
            print('{} / {} Epoch'.format(epoch, self.config.n_epochs))

            train_loss = self.train_one_epoch(train_loader)
            print('[Train] loss: {}'.format(train_loss.avg))
            self.save(self.base_dir +
                      '{}_{}_last.pt'.format(self.name, self.fold_num))

            val_loss = self.val_one_epoch(val_loader)
            print('[Valid] loss: {}'.format(val_loss.avg))
            if val_loss.avg < self.best_loss:
                self.best_loss = val_loss.avg
                self.save(self.base_dir +
                          '{}_{}_best.pt'.format(self.name, self.fold_num))

            # Early stopping
            early_stopping(val_loss.avg, self.best_loss)
            if early_stopping.early_stop:
                break

            if self.config.val_scheduler:
                self.scheduler.step(metrics=val_loss.avg)

            self.epoch += 1
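

# --- Hedged usage sketch (names and values are illustrative) ---
# Training expects a config object exposing lr, SchedulerClass,
# scheduler_params, patience, n_epochs and val_scheduler; something like:
#
#     class TrainConfig:
#         lr = 2e-4
#         SchedulerClass = torch.optim.lr_scheduler.ReduceLROnPlateau
#         scheduler_params = dict(mode='min', factor=0.5, patience=1)
#         patience = 5          # early-stopping patience
#         n_epochs = 30
#         val_scheduler = True  # step the scheduler on validation loss
#
#     trainer = Training(model, device, TrainConfig(), name='effdet',
#                        fold_num=0, imsize=512)
#     trainer.fit(train_loader, val_loader)
#
# ReduceLROnPlateau fits the scheduler.step(metrics=...) call in fit(); any
# scheduler used here must accept that keyword.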
class Trainer():
    def __init__(self, config, pretrained=True):
        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.checkpoint = config['trainer']['checkpoint']
        self.export_weights = config['trainer']['export']
        self.metrics = config['trainer']['metrics']
        logger = config['trainer']['log']
        if logger:
            self.logger = Logger(logger)

        if pretrained:
            weight_file = download_weights(**config['pretrain'],
                                           quiet=config['quiet'])
            self.load_weights(weight_file)

        self.iter = 0

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer, **config['optimizer'])
        # self.optimizer = ScheduledOptim(
        #     Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        #     #config['transformer']['d_model'],
        #     512,
        #     **config['optimizer'])

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        transforms = ImgAugTransform()

        self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
                                       self.data_root,
                                       self.train_annotation,
                                       transform=transforms)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                'valid_{}'.format(self.dataset_name), self.data_root,
                self.valid_annotation)

        self.train_losses = []

    def train(self):
        total_loss = 0

        total_loader_time = 0
        total_gpu_time = 0
        best_acc = 0

        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)
            total_loader_time += time.time() - start

            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            if self.iter % self.print_every == 0:
                info = ('iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - '
                        'load time: {:.2f} - gpu time: {:.2f}').format(
                            self.iter, total_loss / self.print_every,
                            self.optimizer.param_groups[0]['lr'],
                            total_loader_time, total_gpu_time)
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)

            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_loss = self.validate()
                acc_full_seq, acc_per_char = self.precision(self.metrics)

                info = ('iter: {:06d} - valid loss: {:.3f} - '
                        'acc full seq: {:.4f} - acc per char: {:.4f}').format(
                            self.iter, val_loss, acc_full_seq, acc_per_char)
                print(info)
                self.logger.log(info)

                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq

    def validate(self):
        self.model.eval()

        total_loss = []
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = (
                    batch['img'], batch['tgt_input'], batch['tgt_output'],
                    batch['tgt_padding_mask'])

                # Keyword form for consistency with step().
                outputs = self.model(img,
                                     tgt_input,
                                     tgt_key_padding_mask=tgt_padding_mask)
                # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))

                outputs = outputs.flatten(0, 1)
                tgt_output = tgt_output.flatten()

                loss = self.criterion(outputs, tgt_output)
                total_loss.append(loss.item())

                del outputs
                del loss

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss

    def predict(self, sample=None):
        pred_sents = []
        actual_sents = []
        img_files = []

        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)

            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['img'], self.model)
            else:
                translated_sentence = translate(batch['img'], self.model)

            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())

            img_files.extend(batch['filenames'])
            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files

    def precision(self, sample=None):
        pred_sents, actual_sents, _ = self.predict(sample=sample)

        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')

        return acc_full_seq, acc_per_char

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):
        pred_sents, actual_sents, img_files = self.predict(sample)

        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}

        for vis_idx in range(0, len(img_files)):
            img_path = img_files[vis_idx]
            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]

            img = Image.open(open(img_path, 'rb'))
            plt.figure()
            plt.imshow(img)
            plt.title('pred: {} - actual: {}'.format(pred_sent, actual_sent),
                      loc='left',
                      fontdict=fontdict)
            plt.axis('off')

        plt.show()

    def visualize_dataset(self, sample=16, fontname='serif'):
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())

                plt.figure()
                plt.title('sent: {}'.format(sent),
                          loc='center',
                          fontname=fontname)
                plt.imshow(img)
                plt.axis('off')

                n += 1
                if n >= sample:
                    plt.show()
                    return

    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)

        # The legacy ScheduledOptim construction here was dead code; the
        # saved state is loaded straight into the current optimizer.
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} mismatched shape'.format(name))
                del state_dict[name]

        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device,
                                                        non_blocking=True)

        batch = {
            'img': img,
            'tgt_input': tgt_input,
            'tgt_output': tgt_output,
            'tgt_padding_mask': tgt_padding_mask,
            'filenames': batch['filenames']
        }

        return batch

    def data_gen(self, lmdb_path, data_root, annotation, transform=None):
        dataset = OCRDataset(
            lmdb_path=lmdb_path,
            root_dir=data_root,
            annotation_path=annotation,
            vocab=self.vocab,
            transform=transform,
            image_height=self.config['dataset']['image_height'],
            image_min_width=self.config['dataset']['image_min_width'],
            image_max_width=self.config['dataset']['image_max_width'])

        sampler = ClusterRandomSampler(dataset, self.batch_size, True)
        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=False,
                         drop_last=False,
                         **self.config['dataloader'])

        return gen

    def data_gen_v1(self, lmdb_path, data_root, annotation):
        data_gen = DataGen(
            data_root,
            annotation,
            self.vocab,
            'cpu',
            image_height=self.config['dataset']['image_height'],
            image_min_width=self.config['dataset']['image_min_width'],
            image_max_width=self.config['dataset']['image_max_width'])

        return data_gen

    def step(self, batch):
        self.model.train()

        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = (
            batch['img'], batch['tgt_input'], batch['tgt_output'],
            batch['tgt_padding_mask'])

        outputs = self.model(img,
                             tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)
        # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        outputs = outputs.view(-1, outputs.size(2))  # flatten(0, 1)
        tgt_output = tgt_output.view(-1)  # flatten()

        loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        self.scheduler.step()

        loss_item = loss.item()

        return loss_item
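

# --- Hedged usage sketch (config values are illustrative, keys are the ones
# actually read in Trainer.__init__ above) ---
#
#     config = {
#         'device': 'cuda:0',
#         'quiet': False,
#         'pretrain': {...},                     # forwarded to download_weights()
#         'trainer': {'iters': 100000, 'batch_size': 32, 'print_every': 200,
#                     'valid_every': 4000, 'checkpoint': './ckpt/ocr.pt',
#                     'export': './weights/ocr.pth', 'metrics': 10000,
#                     'log': './train.log'},
#         'dataset': {'name': 'hw', 'data_root': './data',
#                     'train_annotation': 'train.txt',
#                     'valid_annotation': 'valid.txt',
#                     'image_height': 32, 'image_min_width': 32,
#                     'image_max_width': 512},
#         'optimizer': {...},                    # OneCycleLR kwargs, e.g. max_lr, total_steps
#         'dataloader': {'num_workers': 4},
#         'predictor': {'beamsearch': False},
#     }
#
#     trainer = Trainer(config, pretrained=True)
#     trainer.train()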
        },
    ]

    lr = args.lr
    query_optimizer = AdamW(optimizer_grouped_parameters1, lr=lr, eps=1e-8)
    t_total = epoch_len * args.epochs
    num_warmup_steps = int(args.warmup * t_total)
    query_scheduler = get_linear_schedule_with_warmup(
        query_optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=t_total)

    if (ckpt_dir
            and os.path.isfile(os.path.join(ckpt_dir, "query_optimizer.pt"))
            and os.path.isfile(os.path.join(ckpt_dir, "query_scheduler.pt"))):
        # Load in optimizer and scheduler states
        query_optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_dir, "query_optimizer.pt"),
                       map_location='cpu'))
        query_scheduler.load_state_dict(
            torch.load(os.path.join(ckpt_dir, "query_scheduler.pt"),
                       map_location='cpu'))
        logger.info(
            f'Load query optimizer states from {os.path.join(ckpt_dir, "query_optimizer.pt")}'
        )

    if not args.share:
        optimizer_grouped_parameters2 = [
            {
                "params": [
                    p for n, p in doc_bert.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
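

# --- Hedged sketch (mirror of the load path above; helper name is ours) ---
# The resume branch expects query_optimizer.pt / query_scheduler.pt inside
# ckpt_dir, so the save side presumably does something like this (os and
# torch are assumed to be imported, as elsewhere in this file):
def save_query_states(ckpt_dir, query_optimizer, query_scheduler):
    """Persist optimizer/scheduler state under the names the loader expects."""
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(query_optimizer.state_dict(),
               os.path.join(ckpt_dir, "query_optimizer.pt"))
    torch.save(query_scheduler.state_dict(),
               os.path.join(ckpt_dir, "query_scheduler.pt"))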
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            inputs["token_type_ids"] = (
                batch[2] if args.model_type in ["bert", "xlnet", "albert"]
                else None
            )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if step % 10 == 0:
                print(step, loss.item())
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)):
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    return global_step, tr_loss / global_step
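

# --- Hedged illustration of the resume bookkeeping in train() above ---
# When --model_name_or_path points at a checkpoint directory such as
# "output/checkpoint-500", train() recovers the step counter by parsing the
# path suffix; the same arithmetic in isolation:
#
#     path = "output/checkpoint-500"
#     global_step = int(path.split("-")[-1].split("/")[0])        # -> 500
#     steps_per_epoch = len_dataloader // gradient_accumulation_steps
#     epochs_trained = global_step // steps_per_epoch
#     steps_to_skip = global_step % steps_per_epoch               # skipped in the first epoch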
class Trainer():
    def __init__(self,
                 train_dataloader,
                 test_dataloader,
                 lr,
                 betas,
                 weight_decay,
                 log_freq,
                 with_cuda,
                 model=None):
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda" if cuda_condition else "cpu")
        print("Use:", "cuda:0" if cuda_condition else "cpu")

        self.model = Classifier_M3().to(self.device)
        self.optim = AdamW(self.model.parameters(),
                           lr=lr,
                           betas=betas,
                           weight_decay=weight_decay)
        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optim, 5)
        self.criterion = nn.BCEWithLogitsLoss()

        if model is not None:
            checkpoint = torch.load(model)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epoch = checkpoint['epoch']
            # save() stores the loss function under 'criterion', not 'loss'.
            self.criterion = checkpoint['criterion']

        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            print("Using %d GPUS for Converter" % torch.cuda.device_count())

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.log_freq = log_freq
        print("Total Parameters:",
              sum([p.nelement() for p in self.model.parameters()]))

        self.test_loss = []
        self.train_loss = []
        self.train_f1_score = []
        self.test_f1_score = []

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        :param epoch: the current epoch
        :param data_loader: torch.utils.data.DataLoader
        :param train: bool flag selecting train or test mode
        """
        str_code = "train" if train else "test"

        data_iter = tqdm(enumerate(data_loader),
                         desc="EP_%s:%d" % (str_code, epoch),
                         total=len(data_loader),
                         bar_format="{l_bar}{r_bar}")

        total_element = 0
        loss_store = 0.0
        f1_score_store = 0.0
        total_correct = 0

        for i, data in data_iter:
            specgram = data[0].to(self.device)
            label = data[2].to(self.device)
            one_hot_label = data[1].to(self.device)

            predict_label = self.model(specgram, train)

            predict_f1_score = get_F1_score(
                label.cpu().detach().numpy(),
                convert_label(predict_label.cpu().detach().numpy()),
                average='micro')

            loss = self.criterion(predict_label, one_hot_label)

            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                # Note: the scheduler is stepped per batch, not per epoch.
                self.scheduler.step()

            loss_store += loss.item()
            f1_score_store += predict_f1_score
            self.avg_loss = loss_store / (i + 1)
            self.avg_f1_score = f1_score_store / (i + 1)

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": round(self.avg_loss, 5),
                "loss": round(loss.item(), 5),
                "avg_f1_score": round(self.avg_f1_score, 5)
            }
            data_iter.write(str(post_fix))

        self.train_loss.append(
            self.avg_loss) if train else self.test_loss.append(self.avg_loss)
        self.train_f1_score.append(
            self.avg_f1_score) if train else self.test_f1_score.append(
                self.avg_f1_score)

    def save(self, epoch, file_path="../models/2k/"):
        """Save model, optimizer and loss state for the given epoch."""
        output_path = file_path + f"crnn_ep{epoch}.model"
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.cpu().state_dict(),
                'optimizer_state_dict': self.optim.state_dict(),
                'criterion': self.criterion
            }, output_path)
        self.model.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path

    def export_log(self, epoch, file_path="../../logs/2k/"):
        df = pd.DataFrame({
            "train_loss": self.train_loss,
            "test_loss": self.test_loss,
            "train_F1_score": self.train_f1_score,
            "test_F1_score": self.test_f1_score
        })
        output_path = file_path + "loss_timestrech.log"
        print("EP:%d logs Saved on:" % epoch, output_path)
        df.to_csv(output_path)
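

# --- Hedged usage sketch (dataloaders are assumed to exist) ---
# A typical driver loop for this Trainer, given DataLoaders that yield
# (specgram, one_hot_label, label) tuples as iteration() expects:
#
#     trainer = Trainer(train_dataloader, test_dataloader,
#                       lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01,
#                       log_freq=10, with_cuda=True)
#     for epoch in range(10):
#         trainer.train(epoch)
#         trainer.test(epoch)
#         trainer.save(epoch)
#         trainer.export_log(epoch)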
def run_training(opt):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    work_dir, epochs, train_batch, valid_batch, weights = \
        opt.work_dir, opt.epochs, opt.train_bs, opt.valid_bs, opt.weights

    # Directories
    last = os.path.join(work_dir, 'last.pt')
    best = os.path.join(work_dir, 'best.pt')

    # --------------------------------------
    # Setup train and validation set
    # --------------------------------------
    data = pd.read_csv(opt.train_csv)
    images_path = opt.data_dir
    n_classes = 6  # hard-coded :V
    data['class'] = data.apply(lambda row: categ[row["class"]], axis=1)
    train_loader, val_loader = prepare_dataloader(data,
                                                  opt.fold,
                                                  train_batch,
                                                  valid_batch,
                                                  opt.img_size,
                                                  opt.num_workers,
                                                  data_root=images_path)

    # NOTE: the non-overall validation path below expects these loaders;
    # re-enable this block before running without --ovr-val.
    # if not opt.ovr_val:
    #     handwritten_data = pd.read_csv(opt.handwritten_csv)
    #     printed_data = pd.read_csv(opt.printed_csv)
    #     handwritten_data['class'] = handwritten_data.apply(lambda row: categ[row["class"]], axis=1)
    #     printed_data['class'] = printed_data.apply(lambda row: categ[row["class"]], axis=1)
    #     _, handwritten_val_loader = prepare_dataloader(
    #         handwritten_data, opt.fold, train_batch, valid_batch, opt.img_size, opt.num_workers, data_root=images_path)
    #     _, printed_val_loader = prepare_dataloader(
    #         printed_data, opt.fold, train_batch, valid_batch, opt.img_size, opt.num_workers, data_root=images_path)

    # --------------------------------------
    # Models
    # --------------------------------------
    model = Classifier(model_name=opt.model_name,
                       n_classes=n_classes,
                       pretrained=True).to(device)
    if opt.weights is not None:
        cp = torch.load(opt.weights)
        model.load_state_dict(cp['model'])

    # -------------------------------------------
    # Setup optimizer, scheduler, criterion loss
    # -------------------------------------------
    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
    scheduler = CosineAnnealingWarmRestarts(optimizer,
                                            T_0=10,
                                            T_mult=1,
                                            eta_min=1e-6,
                                            last_epoch=-1)
    scaler = GradScaler()
    loss_tr = nn.CrossEntropyLoss().to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)

    # --------------------------------------
    # Setup training
    # --------------------------------------
    os.makedirs(work_dir, exist_ok=True)

    best_loss = 1e5
    start_epoch = 0
    best_epoch = 0  # for early stopping
    if opt.resume:
        checkpoint = torch.load(last)
        start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint["scheduler"])
        best_loss = checkpoint["best_loss"]

    # --------------------------------------
    # Start training
    # --------------------------------------
    print("[INFO] Start training...")
    for epoch in range(start_epoch, epochs):
        train_one_epoch(epoch,
                        model,
                        loss_tr,
                        optimizer,
                        train_loader,
                        device,
                        scheduler=scheduler,
                        scaler=scaler)
        with torch.no_grad():
            if opt.ovr_val:
                val_loss = valid_one_epoch_overall(epoch,
                                                   model,
                                                   loss_fn,
                                                   val_loader,
                                                   device,
                                                   scheduler=None)
            else:
                # Requires the handwritten/printed loaders from the
                # commented-out block above.
                val_loss = valid_one_epoch(epoch,
                                           model,
                                           loss_fn,
                                           handwritten_val_loader,
                                           printed_val_loader,
                                           device,
                                           scheduler=None)

            if val_loss < best_loss:
                best_loss = val_loss
                best_epoch = epoch
                torch.save(
                    {
                        'epoch': epoch,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'best_loss': best_loss
                    }, best)
                print('best model found for epoch {}'.format(epoch + 1))

        torch.save(
            {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_loss': best_loss
            }, last)

        if epoch - best_epoch > opt.patience:
            print("Early stop achieved at", epoch + 1)
            break

    del model, optimizer, train_loader, val_loader, scheduler, scaler
    torch.cuda.empty_cache()
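

# --- Hedged sketch (flag names and defaults are illustrative assumptions) ---
# run_training(opt) reads work_dir, epochs, train_bs, valid_bs, weights,
# train_csv, data_dir, fold, img_size, num_workers, model_name, resume,
# ovr_val and patience from its argument; an argparse namespace covering all
# of them could be built like this (assumes argparse is imported):
def parse_opt():
    parser = argparse.ArgumentParser('classifier training')
    parser.add_argument('--work-dir', dest='work_dir', type=str, default='./runs')
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--train-bs', dest='train_bs', type=int, default=32)
    parser.add_argument('--valid-bs', dest='valid_bs', type=int, default=32)
    parser.add_argument('--weights', type=str, default=None)
    parser.add_argument('--train-csv', dest='train_csv', type=str, required=True)
    parser.add_argument('--data-dir', dest='data_dir', type=str, required=True)
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument('--img-size', dest='img_size', type=int, default=224)
    parser.add_argument('--num-workers', dest='num_workers', type=int, default=4)
    parser.add_argument('--model-name', dest='model_name', type=str,
                        default='tf_efficientnet_b0')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--ovr-val', dest='ovr_val', action='store_true')
    parser.add_argument('--patience', type=int, default=5)
    return parser.parse_args()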
break_factor = False

# Measure the total training time for the whole run.
total_t0 = time.time()
print("starting...")

# For each epoch...
for epoch_i in range(0, EPOCHS):
    print("")
    # The +5/+4 offsets appear to re-label epochs for a run resumed from
    # epoch 5.
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 5, EPOCHS + 4))
    print('Training...')

    checkpoint = torch.load("/global/cscratch1/sd/ajaybati/model_ckptDS" +
                            str(epoch_i) + ".pickle")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch'] + 1
    total_train_loss = 0
    step_resume = 0
    training_stats = checkpoint['training_stats']
    print('step: ', step_resume, 'total loss: ', total_train_loss,
          'epoch: ', epoch)

    # Measure how long the training epoch takes.
    t0 = time.time()

    model.train()

    # For each batch of training data...
    for step, batch in enumerate(train_dataloader):
        b_input_ids = batch[0].to(device)