def init_trainer_network(self):
    """Reset all sub-network weights, then build one AdamW optimizer per component.

    Side effects: stores parameter lists and five optimizers on ``self``.
    """
    self.reset_parameters()

    # Snapshot parameter lists for every trainable component.
    # NOTE(review): 'gen_paramas' looks like a typo for 'gen_params', but the
    # attribute name is kept — other code may read it.
    self.gen_paramas = list(self.generator.parameters())
    self.dsc_params = list(self.discriminator.parameters())
    self.language_params = list(self.language_model.parameters())
    self.style_params = list(self.style_model.parameters())
    self.content_params = list(self.content_model.parameters())

    # One decoupled-weight-decay optimizer per component; style and content
    # share the `lmf_*` hyper-parameter group from the config.
    self.optim_language = AdamW(self.language_params,
                                lr=self.cfg.lang_lr,
                                betas=self.cfg.lang_betas,
                                weight_decay=self.cfg.lang_wd)
    self.optim_generator = AdamW(self.gen_paramas,
                                 lr=self.cfg.gen_lr,
                                 betas=self.cfg.gen_betas,
                                 weight_decay=self.cfg.gen_wd)
    self.optim_discriminator = AdamW(self.dsc_params,
                                     lr=self.cfg.dsc_lr,
                                     betas=self.cfg.dsc_betas,
                                     weight_decay=self.cfg.dsc_wd)
    self.optim_style = AdamW(self.style_params,
                             lr=self.cfg.lmf_lr,
                             betas=self.cfg.lmf_betas,
                             weight_decay=self.cfg.lmf_wd)
    self.optim_content = AdamW(self.content_params,
                               lr=self.cfg.lmf_lr,
                               betas=self.cfg.lmf_betas,
                               weight_decay=self.cfg.lmf_wd)
def init_trainer_network(self):
    """Build optimizers for generator, discriminator(s) and the feature stack.

    Runs one generator/discriminator forward pass first — presumably so that
    lazily-constructed submodules exist before ``parameters()`` is collected
    (TODO confirm), then resets weights and creates the AdamW optimizers.
    """
    images = self.forward_generator()
    _ = self.forward_discriminator(images, self.condition)
    self.reset_parameters()

    generator_params = list(self.generator.parameters())
    discriminator_params = list(self.discriminators.parameters())
    self.optim_generator = AdamW(generator_params,
                                 lr=self.cfg.gen_lr,
                                 betas=self.cfg.gen_betas,
                                 weight_decay=self.cfg.gen_wd)
    self.optim_discriminator = AdamW(discriminator_params,
                                     lr=self.cfg.dsc_lr,
                                     betas=self.cfg.dsc_betas,
                                     weight_decay=self.cfg.dsc_wd)

    # TODO add to yaml params for language optimizer
    # The feature optimizer reuses the generator's hyper-parameters; the
    # language model joins the group only when it is marked trainable.
    feature_params = list(self.features_model.parameters())
    if self.language_model.trainable:
        feature_params = feature_params + list(self.language_model.parameters())
    self.optim_feature = AdamW(feature_params,
                               lr=self.cfg.gen_lr,
                               betas=self.cfg.gen_betas,
                               weight_decay=self.cfg.gen_wd)
def __init__(
        self, save_path, log_path, d_features, d_meta, max_length,
        d_classifier, n_classes, threshold=None, optimizer=None, **kwargs):
    """Trainer for a time/facility-encoding Transformer classifier.

    Arguments:
        save_path -- model file path
        log_path -- log file path
        d_features -- how many PMs
        d_meta -- how many facility types
        max_length -- input sequence length
        d_classifier -- classifier hidden unit
        n_classes -- output dim
        threshold -- if not None, n_classes should be 1 (regression).
        optimizer -- pre-built optimizer; a default AdamW is created when None.
        **kwargs -- n_layers, n_head, dropout, use_bottleneck, d_bottleneck
    """
    super().__init__(save_path, log_path)
    self.d_output = n_classes
    self.threshold = threshold
    self.max_length = max_length

    # ----------------------------- Model ------------------------------ #
    self.model = Encoder(TimeFacilityEncoding, d_features=d_features,
                         max_seq_length=max_length, d_meta=d_meta, **kwargs)

    # --------------------------- Classifier --------------------------- #
    # Classifier consumes the flattened (d_features * max_length) encoding.
    self.classifier = LinearClassifier(
        d_features * max_length, d_classifier, n_classes)

    # ------------------------------ CUDA ------------------------------ #
    # If GPU available, move the graph to GPU(s)
    self.CUDA_AVAILABLE = self.check_cuda()
    if self.CUDA_AVAILABLE:
        device_ids = list(range(torch.cuda.device_count()))
        self.model = nn.DataParallel(self.model, device_ids)
        self.classifier = nn.DataParallel(self.classifier, device_ids)
        self.model.to('cuda')
        self.classifier.to('cuda')
        # Sanity check that the move actually happened.
        assert next(self.model.parameters()).is_cuda
        assert next(self.classifier.parameters()).is_cuda
    else:
        print('CUDA not found or not enabled, use CPU instead')

    # ---------------------------- Optimizer --------------------------- #
    self.parameters = list(self.model.parameters()) + \
        list(self.classifier.parameters())
    # FIX: compare with `is None` (identity), not `== None`.
    if optimizer is None:
        self.optimizer = AdamW(self.parameters, lr=0.001,
                               betas=(0.9, 0.999), weight_decay=0.001)

    # ------------------------ training control ------------------------ #
    self.controller = TrainingControl(max_step=100000,
                                      evaluate_every_nstep=100,
                                      print_every_nstep=10)
    self.early_stopping = EarlyStopping(patience=50)

    # --------------------- logging and tensorboard -------------------- #
    self.count_parameters()
    self.set_logger()
    self.set_summary_writer()
def configure_optimizers(self):
    """Build the optimizer (and optional LR scheduler) from ``self.cfg``.

    Returns:
        ``([optimizer], [scheduler_dict])`` when a scheduler is configured
        (Lightning-style), otherwise the bare optimizer.

    Raises:
        Exception: for an unsupported optimizer or scheduler name.
    """
    opt_cfg = self.cfg['optimizer']
    lr = float(opt_cfg['lr'])
    if opt_cfg['name'] == 'AdamW':
        optimizer = AdamW(self.model.parameters(), lr=lr)
    elif opt_cfg['name'] == 'Adam_GCC':
        optimizer = Adam_GCC(self.model.parameters(), lr=lr)
    elif opt_cfg['name'] == 'AdamW_GCC2':
        optimizer = AdamW_GCC2(self.model.parameters(), lr=lr)
    else:
        # FIX: previously an unknown name left `optimizer` unbound (NameError
        # further down); fail fast with a clear message instead.
        raise Exception('optimizer {} not supported'.format(opt_cfg['name']))

    sched_cfg = self.cfg['scheduler']
    if sched_cfg['type'] == 'none':
        sched = None
    elif sched_cfg['type'] == 'CosineAnnealingWarmRestarts':
        sched = CosineAnnealingWarmRestarts(optimizer,
                                            T_0=sched_cfg['T_0'],
                                            T_mult=sched_cfg['T_mult'],
                                            eta_min=float(sched_cfg['eta_min']),
                                            last_epoch=-1)
    elif sched_cfg['type'] == 'OneCycleLR':
        # BUG FIX: this branch read the bare name `cfg` (NameError at runtime)
        # instead of `self.cfg`.
        sched = OneCycleLR(optimizer,
                           max_lr=float(sched_cfg['max_lr']),
                           steps_per_epoch=sched_cfg['steps_per_epoch'],
                           epochs=sched_cfg['epochs'])
    else:
        raise Exception('scheduler {} not supported'.format(sched_cfg['type']))

    if sched is not None:
        sched = {'scheduler': sched, 'name': format(sched_cfg['type'])}
        return [optimizer], [sched]
    # FIX: removed the unreachable duplicated `return optimizer`.
    return optimizer
def train_end_to_end(args, graph_dataset, model):
    """Train `model` end-to-end on `graph_dataset` until early stopping fires,
    then log validation/test metrics and save the weights.

    Side effects: logs via LOG, writes `node_encoder_end_to_end.pt` into
    args.model_dir. Also stops if the module-level `_STOP_TRAINING` flag is
    set (presumably by a signal handler — confirm elsewhere in the file).
    """
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # LR is halved after 5 epochs without val-loss improvement (2-epoch cooldown).
    lr_scheduler = ReduceLROnPlateau(optimizer, patience=5, cooldown=2, factor=0.5)
    early_stopper = EarlyStopper(args.supervised_model_patience)
    epoch = 0
    start_time = time.time()
    LOG.info('Started training supervised model.')
    while True:
        train_loss = end_to_end_train_epoch(graph_dataset, model, optimizer, args, epoch)
        val_loss, val_acc, val_f1 = end_to_end_val_epoch(graph_dataset, model, args)
        lr_scheduler.step(val_loss)
        epoch_end_time = round(time.time() - start_time, 2)
        LOG.info(f'Epoch {epoch + 1} [{epoch_end_time}s]: '
                 f'Training loss = {train_loss:.4f}, Validation loss = {val_loss:.4f}, '
                 f'Val accuracy = {val_acc:.4f}, Val f1 = {val_f1:.4f}')
        epoch += 1
        if early_stopper.should_stop(model, val_loss) or _STOP_TRAINING:
            break
    # Compute metrics.
    # Predictions are flattened across all graphs, masked to the relevant split.
    val_y_pred = predict(graph_dataset, model, args, mode='val')
    val_y_pred = raw_output_to_prediction(val_y_pred, args.loss_fn)
    val_y_true = [label for graph in graph_dataset.graphs
                  for label in graph.y[graph.val_mask].cpu().numpy()]
    test_y_pred = predict(graph_dataset, model, args, mode='test')
    test_y_pred = raw_output_to_prediction(test_y_pred, args.loss_fn)
    test_y_true = [label for graph in graph_dataset.graphs
                  for label in graph.y[graph.test_mask].cpu().numpy()]
    log_metrics(graph_dataset, val_y_true, val_y_pred.cpu().numpy(),
                test_y_true, test_y_pred.cpu().numpy(),
                args.show_detailed_metrics, args.attention_radius, args.max_neighbors)
    # Save trained end-to-end model.
    torch.save(model.state_dict(), os.path.join(args.model_dir, 'node_encoder_end_to_end.pt'))
def get_optimizer(parameters, optim, learning_rate, lr_decay, amsgrad, weight_decay, warmup_steps):
    """Create an optimizer plus its exponential warmup scheduler.

    ``optim == 'sgd'`` selects Nesterov SGD (momentum 0.9); any other value
    falls through to AdamW. Returns ``(optimizer, scheduler)``.
    """
    if optim == 'sgd':
        optimizer = SGD(parameters, lr=learning_rate, momentum=0.9,
                        weight_decay=weight_decay, nesterov=True)
    else:
        optimizer = AdamW(parameters, lr=learning_rate, betas=(0.9, 0.999),
                          eps=1e-8, amsgrad=amsgrad, weight_decay=weight_decay)
    # Warmup ramps from a tiny fixed initial LR.
    init_lr = 1e-7
    scheduler = ExponentialScheduler(optimizer, lr_decay, warmup_steps, init_lr)
    return optimizer, scheduler
def __init__(self, save_path, log_path, d_features, d_out_list, d_classifier, d_output,
             threshold=None, optimizer=None, **kwargs):
    """Trainer for a self-attention feature-selection classifier.

    Arguments:
        save_path -- model file path
        log_path -- log file path
        d_features -- number of input features
        d_out_list -- per-stage output sizes for the feature-selection model
        d_classifier -- classifier hidden units
        d_output -- output dim
        threshold -- if not None, d_output should be 1 (regression)
        optimizer -- pre-built optimizer; a default AdamW is created when None
        **kwargs -- n_layers, n_head, n_channel, n_vchannel, dropout
    """
    super().__init__(save_path, log_path)
    self.d_output = d_output
    self.threshold = threshold

    # ----------------------------- Model ------------------------------ #
    self.model = SelfAttentionFeatureSelection(d_features, d_out_list)
    # --------------------------- Classifier --------------------------- #
    self.classifier = LinearClassifier(d_features, d_classifier, d_output)
    # ------------------------------ CUDA ------------------------------ #
    self.data_parallel()
    # ---------------------------- Optimizer --------------------------- #
    self.parameters = list(self.model.parameters()) + list(
        self.classifier.parameters())
    # FIX: compare with `is None` (identity), not `== None`.
    if optimizer is None:
        self.optimizer = AdamW(self.parameters, lr=0.002,
                               betas=(0.9, 0.999), weight_decay=0.001)
    # ------------------------ training control ------------------------ #
    self.controller = TrainingControl(max_step=100000,
                                      evaluate_every_nstep=100,
                                      print_every_nstep=10)
    self.early_stopping = EarlyStopping(patience=50)
    # --------------------- logging and tensorboard -------------------- #
    self.set_logger()
    self.set_summary_writer()
def train_single_epoch(model: RedisSingleDNN, trainDataloader: DataLoader, optimizer: AdamW) -> Tuple[float, float]:
    """Run one MSE training epoch.

    Returns (mean per-batch loss, accuracy placeholder). The second value is
    always 0 here — accuracy is not computed in this loop.
    """
    running_loss = 0.0
    train_ACC = 0
    step_count = 0  # counted but unused beyond bookkeeping
    model.train()
    for _, batch in enumerate(tqdm(trainDataloader, desc="Iteration")):
        optimizer.zero_grad()
        inputs = batch[0].to(DEVICE)
        labels = batch[1].to(DEVICE)
        predictions = model(inputs)
        batch_loss = F.mse_loss(predictions, labels)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
        step_count += 1
    return running_loss / len(trainDataloader), train_ACC
def configure_optimizers(self):
    """AdamW plus cosine warm restarts wrapped in a linear warmup phase.

    Hyper-parameters come from self.config.kv: lr_t0 (restart period),
    lr_f (period multiplier), lr_w (warmup epochs).
    """
    optimizer = AdamW(self.optimizer_params)
    t0 = int(round(self.config.kv['lr_t0']))
    t_mult = int(round(self.config.kv['lr_f']))
    restarts = CosineAnnealingWarmRestarts(optimizer, T_0=t0, T_mult=t_mult)
    # Warmup delegates to the cosine schedule once its epochs elapse.
    scheduler = WarmupScheduler(optimizer,
                                epochs=self.config.kv['lr_w'],
                                next_scheduler=restarts)
    return [optimizer], [scheduler]
def get_optimizer(learning_rate, parameters, betas, eps, amsgrad, step_decay,
                  weight_decay, warmup_steps, init_lr):
    """Create an AdamW optimizer and its exponential warmup scheduler.

    Returns ``(optimizer, scheduler)``.
    """
    optimizer = AdamW(parameters, lr=learning_rate, betas=betas, eps=eps,
                      amsgrad=amsgrad, weight_decay=weight_decay)
    # FIX: removed the no-op self-assignment `step_decay = step_decay`.
    scheduler = ExponentialScheduler(optimizer, step_decay, warmup_steps, init_lr)
    return optimizer, scheduler
def get_optimisers(G: "nn.Module", args: "tupperware") -> "Union[optim, lr_scheduler]":
    """Build the generator's AdamW optimiser and its cosine-restart schedule.

    Hyper-parameters (learning_rate, beta_1/beta_2, T_0, T_mult) are read
    from `args`; eta_min is pinned near zero.
    """
    optimiser = AdamW(G.parameters(),
                      lr=args.learning_rate,
                      betas=(args.beta_1, args.beta_2))
    schedule = CosineAnnealingWarmRestarts(optimizer=optimiser,
                                           T_0=args.T_0,
                                           T_mult=args.T_mult,
                                           eta_min=2e-10)
    return optimiser, schedule
def split_optimizer(model: nn.Module, cfg: dict):
    """Build an optimizer with three parameter groups.

    Group 0: 'other' params without weight decay; group 1: weights with
    cfg['weight_decay']; group 2: biases without decay.
    """
    param_weight_decay, param_bias, param_other = split_params(model)
    name = cfg['optimizer']
    if name == 'Adam':
        optimizer = Adam(param_other, lr=cfg['lr'])
    elif name == 'SGD':
        optimizer = SGD(param_other, lr=cfg['lr'], momentum=cfg['momentum'])
    elif name == "AdamW":
        optimizer = AdamW(param_other, lr=cfg['lr'])
    else:
        raise NotImplementedError("optimizer {:s} is not support!".format(cfg['optimizer']))
    # add pg1 with weight_decay
    optimizer.add_param_group({'params': param_weight_decay,
                               'weight_decay': cfg['weight_decay']})
    optimizer.add_param_group({'params': param_bias})
    return optimizer
def load_parameters(self, filename, ctx=None, strict=True):
    '''
    Load model (and, when possible, optimizer) state from a checkpoint file.

    Args:
        filename:   File name of the checkpoint written by the save side
                    (expects keys 'model_state_dict', 'optimizer_state_dict',
                    'epoch', 'loss').
        ctx:        Context-manager that changes the selected device; when
                    given, the module is moved to it after loading.
        strict:     Whether to strictly enforce that the keys in state_dict
                    match the keys returned by this module's state_dict()
                    function. Default: `True`.
    '''
    checkpoint = torch.load(filename)
    self.load_state_dict(checkpoint['model_state_dict'], strict=strict)
    try:
        # Lazily build the optimizer first so its state can be restored:
        # prefer the factory if one was supplied, else default AdamW.
        if self.optimizer is None:
            if self.optimizer_f is not None:
                self.optimizer = self.optimizer_f(
                    self.parameters()
                )
            else:
                self.optimizer = AdamW(
                    self.parameters(),
                    lr=self.learning_rate,
                    weight_decay=self.weight_decay
                )
        self.optimizer.load_state_dict(
            checkpoint['optimizer_state_dict']
        )
    except ValueError as e:
        # Param-shape mismatch between checkpoint and current model:
        # deliberately best-effort, keep the fresh optimizer state.
        self.logger.debug(e)
        self.logger.debug("The state of the optimizer in `TransformerModel` was not updated.")
    self.epoch = checkpoint['epoch']
    # checkpoint['loss'] is array-like; keep a plain Python list.
    self.__loss_list = checkpoint['loss'].tolist()
    if ctx is not None:
        self.to(ctx)
        self.__ctx = ctx
def __init__(
    self,
    config: TrainConfig,
    model: NSMCModel,
    train_data_loader: DataLoader,
    dev_data_loader: DataLoader,
    test_data_loader: DataLoader,
    logger: Logger,
    summary_writer: SummaryWriter,
):
    """Trainer for an NSMC sentiment classifier.

    Sets up device placement, cross-entropy loss, an AdamW optimizer and a
    linear warmup/decay schedule sized from the train loader length and
    config.num_epochs.
    """
    self.config = config
    # Prefer GPU when available; the model is moved accordingly.
    self.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    self.model = model
    self.model.to(self.device)
    self.train_data_loader = train_data_loader
    self.dev_data_loader = dev_data_loader
    self.test_data_loader = test_data_loader
    self.logger = logger
    self.summary_writer = summary_writer
    self.criterion = nn.CrossEntropyLoss()
    self.optimizer = AdamW(model.parameters(), lr=config.learning_rate)
    # Compute total training steps (warmup is a ratio of the total).
    self.steps_per_epoch = len(train_data_loader)
    self.total_steps = self.steps_per_epoch * config.num_epochs
    self.warmup_steps = config.warmup_step_ratio * self.total_steps
    self.scheduler = get_linear_schedule_with_warmup(
        self.optimizer,
        num_warmup_steps=self.warmup_steps,
        num_training_steps=self.total_steps)
    self.global_step = 0
def train_stage_two(dataset, best_model_file, model_file):
    """Stage-two training of a ResNet detector/regressor head.

    Output column 0 is a binary class logit (BCE loss); remaining columns are
    regression targets (L1 loss). Saves a checkpoint to `best_model_file`
    whenever the running accuracy improves and the final weights to
    `model_file` after each epoch.
    """
    bestaccuracy = 0.9
    # BUG FIX: device string was 'cudo:0' (typo) — torch.device would reject
    # it, so CUDA was never actually used.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net = ResNet(BasicBlock, [3, 3, 4, 3]).to(device)  # [2,2,2,2]
    net.train()
    # Xavier-init all weight matrices (skip 1-D bias/batch-norm params).
    for parameter in net.parameters():
        if len(parameter.shape) > 1:
            torch.nn.init.xavier_uniform_(parameter)
    # Resume from the best checkpoint when one exists.
    if isfile(best_model_file):
        net.load_state_dict(torch.load(best_model_file))
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
    optimizer = AdamW(net.parameters(), lr=0.0001)
    scheduler = CyclicLR(optimizer, 0.000001, 0.0001, step_size_up=200,
                         mode='triangular2', cycle_momentum=False, last_epoch=-1)
    L1 = torch.nn.L1Loss()
    BCE = torch.nn.BCEWithLogitsLoss()
    for epoch in range(50):
        running_accuracy = []
        for (images, targets) in tqdm(train_loader):
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(images)
            clsloss = BCE(outputs[:, 0], targets[:, 0])
            regloss = L1(outputs[:, 1:], targets[:, 1:])
            loss = clsloss + regloss
            cls_preds = np.greater(outputs[:, 0].cpu().detach().numpy(), 0)
            cls_truth = targets[:, 0].cpu().detach().numpy()
            correctness = np.equal(cls_preds, cls_truth).astype(int)
            # BUG FIX: divide by the actual batch size instead of a
            # hard-coded 64 — the last batch may be smaller.
            accuracy = sum(correctness) / len(correctness)
            running_accuracy.append(accuracy)
            # Smooth over the last 10 batches only.
            running_accuracy = running_accuracy[-10:]
            print(' clsloss ' + str(clsloss.cpu().detach().numpy())[:4] +
                  ' regloss ' + str(regloss.cpu().detach().numpy())[:4] +
                  ' accuracy ' + str(np.mean(running_accuracy)), end='\r')
            if np.mean(running_accuracy) > bestaccuracy:
                bestaccuracy = np.mean(running_accuracy)
                torch.save(net.state_dict(), best_model_file)
            loss.backward()
            optimizer.step()
            scheduler.step()
        torch.save(net.state_dict(), model_file)
        print(epoch)
def train_twice_epoch(model: RedisTwiceDNN, trainDataloader: DataLoader, optimizer: AdamW) -> Tuple[float, float]:
    """Run one training epoch over a two-headed model.

    The total loss is a fixed 0.6/0.4 weighted sum of per-head MSE losses.
    Returns (mean per-batch loss, accuracy placeholder); the second value is
    always 0 — accuracy is not computed here.
    """
    running_loss = 0.0
    train_ACC = 0
    step_count = 0  # counted but otherwise unused
    model.train()
    weight = [0.6, 0.4]
    for _, batch in enumerate(tqdm(trainDataloader, desc="Iteration")):
        optimizer.zero_grad()
        inputs = batch[0].to(DEVICE)
        labels = batch[1].to(DEVICE)
        head_outputs = model(inputs)
        batch_loss = 0.
        for head_idx, head_out in enumerate(head_outputs):
            batch_loss += weight[head_idx] * F.mse_loss(head_out.squeeze(1), labels[:, head_idx])
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
        step_count += 1
    return running_loss / len(trainDataloader), train_ACC
def get_optimizer(opt, learning_rate, parameters, hyper1, hyper2, eps, rebound,
                  lr_decay, decay_rate, milestone, weight_decay, weight_decay_type,
                  warmup_updates, init_lr, last_lr, num_epochs, world_size):
    """Factory for (optimizer, lr scheduler, description string).

    `opt` selects the optimizer family; `hyper1`/`hyper2` are reused as
    momentum or betas depending on the family. Note that `opt` and
    `weight_decay_type` are re-assigned inside and the final `opt` return
    value is a human-readable settings summary, not the original name.

    Raises ValueError for unknown `opt` or `lr_decay` values.
    """
    if opt == 'sgd':
        optimizer = SGD(parameters, lr=learning_rate, momentum=hyper1,
                        weight_decay=weight_decay, nesterov=True)
        opt = 'momentum=%.1f, ' % (hyper1)
        weight_decay_type = 'L2'
    elif opt == 'radamw':
        optimizer = RAdamW(parameters, lr=learning_rate, betas=(hyper1, hyper2),
                           eps=eps, weight_decay=weight_decay)
        opt = 'betas=(%.1f, %.3f), eps=%.1e, ' % (hyper1, hyper2, eps)
        weight_decay_type = 'decoupled'
    elif opt == 'adamw':
        optimizer = AdamW(parameters, lr=learning_rate, betas=(hyper1, hyper2),
                          eps=eps, weight_decay=weight_decay)
        opt = 'betas=(%.1f, %.3f), eps=%.1e, ' % (hyper1, hyper2, eps)
        weight_decay_type = 'decoupled'
    elif opt == 'adabelief':
        optimizer = AdaBelief(parameters, lr=learning_rate, betas=(hyper1, hyper2),
                              eps=eps, weight_decay=weight_decay,
                              weight_decay_type=weight_decay_type)
        opt = 'betas=(%.1f, %.3f), eps=%.1e, ' % (hyper1, hyper2, eps)
    elif opt == 'apollo':
        optimizer = Apollo(parameters, lr=learning_rate, beta=hyper1, eps=eps,
                           rebound=rebound, warmup=warmup_updates, init_lr=init_lr,
                           weight_decay=weight_decay,
                           weight_decay_type=weight_decay_type)
        opt = 'beta=%.1f, eps=%.1e, rebound=%s, ' % (hyper1, eps, rebound)
    elif opt == 'adahessian':
        optimizer = AdaHessian(parameters, lr=learning_rate, betas=(hyper1, hyper2),
                               eps=eps, warmup=warmup_updates, init_lr=init_lr,
                               weight_decay=weight_decay, num_threads=world_size)
        opt = 'betas=(%.1f, %.3f), eps=%.1e, ' % (hyper1, hyper2, eps)
        weight_decay_type = 'decoupled'
    else:
        raise ValueError('unknown optimizer: {}'.format(opt))
    # Attach the LR schedule and extend the summary string.
    if lr_decay == 'exp':
        opt = opt + 'lr decay={}, decay rate={:.3f}, '.format(lr_decay, decay_rate)
        scheduler = ExponentialLR(optimizer, decay_rate)
    elif lr_decay == 'milestone':
        opt = opt + 'lr decay={} {}, decay rate={:.3f}, '.format(lr_decay, milestone, decay_rate)
        scheduler = MultiStepLR(optimizer, milestones=milestone, gamma=decay_rate)
    elif lr_decay == 'cosine':
        opt = opt + 'lr decay={}, lr_min={}, '.format(lr_decay, last_lr)
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=last_lr)
    else:
        raise ValueError('unknown lr decay: {}'.format(lr_decay))
    opt += 'warmup={}, init_lr={:.1e}, wd={:.1e} ({})'.format(warmup_updates, init_lr, weight_decay, weight_decay_type)
    return optimizer, scheduler, opt
def train(sqlite_file, model_selection_file, model_file):
    """
    Selects the best model/training parameter combination from a model
    selection and trains a final image classifier on the entire annotated
    part of the image database.

    SQLITE_FILE: An annotated image database (see create-database).

    MODEL_SELECTION_FILE: Results of a model selection created by the
    "model-selection" tool.

    MODEL_FILE: Store the finally trained image classification model in this
    file.
    """
    X, class_to_label, label_to_class = load_ground_truth(sqlite_file)
    X['file'] = X['file'].astype(str)
    y = X['class'].astype(int)
    # Winning hyper-parameters from the prior model-selection run.
    batch_size, decrease_epochs, decrease_factor, epochs, model_name, num_trained, start_lr = load_model_selection(
        model_selection_file)
    model, device, fit_transform, predict_transform, logits_func = \
        load_pretrained_model(model_name, len(label_to_class), num_trained)
    # optimizer = optim.SGD(model_ft.parameters(), lr=start_lr, momentum=momentum)
    optimizer = AdamW(model.parameters(), lr=start_lr)
    # Step-decay the LR by `decrease_factor` every `decrease_epochs` epochs.
    sched = lr_scheduler.StepLR(optimizer, step_size=decrease_epochs, gamma=decrease_factor)
    estimator = ImageClassifier(model=model,
                                model_weights=copy.deepcopy(model.state_dict()),
                                device=device,
                                criterion=nn.CrossEntropyLoss(),
                                optimizer=optimizer,
                                scheduler=sched,
                                fit_transform=fit_transform,
                                predict_transform=predict_transform,
                                batch_size=batch_size,
                                logits_func=logits_func)
    # One estimator.fit call per epoch over the full annotated set.
    for _ in range(0, epochs):
        estimator.fit(X, y)
    torch.save(model.state_dict(), model_file)
def configure_optimizers(self):
    """AdamW, optionally paired with an inverse-sqrt warmup schedule.

    When `self.lr_warmup_steps` is None the bare optimizer is returned;
    otherwise a Lightning-style per-step scheduler dict accompanies it,
    monitoring the (possibly label-smoothed) validation loss.
    """
    optimizer = AdamW(self.parameters(), lr=self.learning_rate)
    if self.lr_warmup_steps is None:
        return optimizer
    warmup = InverseSquareRootLR(optimizer, self.lr_warmup_steps)
    monitored = ('label_smoothed_val_loss'
                 if self.optimize_on_smoothed_loss else 'val_loss')
    scheduler_conf = {
        'scheduler': warmup,
        'interval': 'step',
        'frequency': 1,
        'reduce_on_plateau': False,
        'monitor': monitored,
    }
    return [optimizer], [scheduler_conf]
def train_unsupervised(args, graph_dataset, model):
    """Train the node encoder unsupervised for up to args.epochs epochs.

    Stops early via `early_stopper` or the module-level `_STOP_TRAINING`
    flag (presumably set by a signal handler — confirm elsewhere in the
    file). Writes `node_encoder_unsupervised.pt` into args.model_dir.
    """
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Halve the LR after 5 epochs without val-loss improvement (2-epoch cooldown).
    lr_scheduler = ReduceLROnPlateau(optimizer, patience=5, cooldown=2, factor=0.5)
    early_stopper = EarlyStopper(args.unsupervised_model_patience)
    start_time = time.time()
    LOG.info("Started training unsupervised model.")
    for epoch in range(args.epochs):
        train_loss, train_count = node_encoder_train_epoch(graph_dataset, model, optimizer, epoch, args)
        val_loss, val_count = node_encoder_val_epoch(graph_dataset, model, epoch, args)
        lr_scheduler.step(val_loss)
        epoch_end_time = round(time.time() - start_time, 2)
        LOG.info(f'Epoch {epoch + 1} [{epoch_end_time}s]: Training loss [over {train_count} nodes] = {train_loss:.4f}, '
                 f'Validation loss [over {val_count} nodes] = {val_loss:.4f}')
        if early_stopper.should_stop(model, val_loss) or _STOP_TRAINING:
            break
    LOG.info("Unsupervised training complete.")
    torch.save(model.state_dict(), os.path.join(args.model_dir, 'node_encoder_unsupervised.pt'))
def _fit(self, modules: nn.ModuleDict, train_dl: DeviceDataLoader, valid_dl: DeviceDataLoader):
    r""" Fits \p modules' learners to the training and validation \p DataLoader objects """
    self._configure_fit_vars(modules)
    # Per-module optimizer construction: learning rate and weight decay come
    # from the learner-specific config.
    for mod_name, module in modules.items():
        lr = config.get_learner_val(mod_name, LearnerParams.Attribute.LEARNING_RATE)
        wd = config.get_learner_val(mod_name, LearnerParams.Attribute.WEIGHT_DECAY)
        # Linear feed-forward on synthetic data: LBFGS converges better there;
        # everything else gets AdamW with AMSGrad.
        is_lin_ff = config.DATASET.is_synthetic(
        ) and module.module.num_hidden_layers == 0
        if is_lin_ff:
            module.optim = LBFGS(module.parameters(), lr=lr)
        else:
            module.optim = AdamW(module.parameters(), lr=lr, weight_decay=wd,
                                 amsgrad=True)
        logging.debug(
            f"{mod_name} Optimizer: {module.optim.__class__.__name__}")
    # All modules are stepped in lockstep over the same batches each epoch.
    for ep in range(1, config.NUM_EPOCH + 1):
        # noinspection PyUnresolvedReferences
        for _, module in modules.items():
            module.epoch_start()
        for batch in train_dl:
            for _, module in modules.items():
                module.process_batch(batch)
        for _, module in modules.items():
            module.calc_valid_loss(valid_dl)
        self._log_epoch(ep, modules)
    self._restore_best_model(modules)
    self.eval()
def build_optimizer(self, args: ClassifierArgs, **kwargs):
    """Create an AdamW optimizer with two parameter groups.

    Weights get args.weight_decay; biases and LayerNorm weights get no decay
    (the usual transformer fine-tuning split).
    """
    no_decay = ['bias', 'LayerNorm.weight']
    decayed, exempt = [], []
    for param_name, param in self.model.named_parameters():
        if any(nd in param_name for nd in no_decay):
            exempt.append(param)
        else:
            decayed.append(param)
    optimizer_grouped_parameters = [
        {"params": decayed, "weight_decay": args.weight_decay},
        {"params": exempt, "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    return optimizer
def create_optim(params, args=None, optim_name=None, lr=None):
    """Build an optimizer either from `args` (optim/lr attributes) or from the
    explicit `optim_name`/`lr` pair — exactly one of the two forms is allowed.

    Supported names: sgd, adam, adamw, amsgrad, rmsprop.
    """
    if args is None:
        # Explicit form: both overrides must be given.
        assert optim_name is not None and lr is not None
    else:
        # Args form: overrides must NOT be given.
        assert optim_name is None and lr is None
        optim_name = args.optim
        lr = args.lr
    if optim_name == 'sgd':
        return SGD(params, lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
    if optim_name == 'adam':
        return Adam(params, lr=lr, weight_decay=0)
    if optim_name == 'adamw':
        return AdamW(params, lr=lr, weight_decay=1e-2)
    if optim_name == 'amsgrad':
        return Adam(params, lr=lr, weight_decay=0, amsgrad=True)
    if optim_name == 'rmsprop':
        return RMSprop(params, lr=lr, momentum=0.9, weight_decay=0)
    raise NotImplementedError(
        'Unsupported optimizer_memory: {}'.format(optim_name))
def configure_optimizers(self):
    """Return a fixed-hyper-parameter AdamW over the wrapped model.

    NOTE(review): `correct_bias=False` is a kwarg of the Hugging Face
    transformers AdamW (BERT-style, no bias correction), not of
    torch.optim.AdamW — confirm which AdamW this file imports.
    """
    return AdamW(params=self.model.parameters(), lr=2e-5, eps=1e-6, correct_bias=False)
def main():
    """Train and evaluate a keyword/speech-commands classifier.

    CLI flags: --model, --workspace, --load-weights, --eval. Builds the ZMUV
    normaliser (cached in the workspace), then either evaluates (--eval) or
    trains with per-epoch LR decay, checkpointing the best dev accuracy.
    """
    def evaluate_accuracy(data_loader, prefix: str, save: bool = False):
        # Closure over model/transforms/args defined below in main().
        std_transform.eval()
        model.eval()
        pbar = tqdm(data_loader, desc=prefix, leave=True, total=len(data_loader))
        num_corr = 0
        num_tot = 0
        for idx, batch in enumerate(pbar):
            batch = batch.to(device)
            scores = model(zmuv_transform(std_transform(batch.audio_data)),
                           std_transform.compute_lengths(batch.lengths))
            num_tot += scores.size(0)
            # Labels are folded into the configured label range via modulo.
            labels = torch.tensor([
                l % SETTINGS.training.num_labels for l in batch.labels.tolist()
            ]).to(device)
            num_corr += (scores.max(1)[1] == labels).float().sum().item()
            acc = num_corr / num_tot
            pbar.set_postfix(accuracy=f'{acc:.4}')
        if save and not args.eval:
            writer.add_scalar(f'{prefix}/Metric/acc', acc, epoch_idx)
            ws.increment_model(model, acc / 10)

    apb = ArgumentParserBuilder()
    apb.add_options(
        opt('--model', type=str, choices=model_names(), default='las'),
        opt('--workspace', type=str, default=str(Path('workspaces') / 'default')),
        opt('--load-weights', action='store_true'),
        opt('--eval', action='store_true'))
    args = apb.parser.parse_args()
    ws = Workspace(Path(args.workspace), delete_existing=not args.eval)
    writer = ws.summary_writer
    set_seed(SETTINGS.training.seed)
    loader = GoogleSpeechCommandsDatasetLoader()
    sr = SETTINGS.audio.sample_rate
    ds_kwargs = dict(sr=sr, mono=SETTINGS.audio.use_mono)
    train_ds, dev_ds, test_ds = loader.load_splits(
        SETTINGS.dataset.dataset_path, **ds_kwargs)
    sr = SETTINGS.audio.sample_rate  # NOTE(review): redundant re-assignment
    device = torch.device(SETTINGS.training.device)
    std_transform = StandardAudioTransform().to(device).eval()
    zmuv_transform = ZmuvTransform().to(device)
    batchifier = partial(batchify, label_provider=lambda x: x.label)
    truncater = partial(truncate_length,
                        length=int(SETTINGS.training.max_window_size_seconds * sr))
    # Training pipeline adds timeshift/noise augmentation before batching.
    train_comp = compose(truncater,
                         TimeshiftTransform().train(),
                         NoiseTransform().train(),
                         batchifier)
    # Single-example loader used only to fit the ZMUV statistics below.
    prep_dl = StandardAudioDataLoaderBuilder(train_ds, collate_fn=batchifier).build(1)
    prep_dl.shuffle = True
    train_dl = StandardAudioDataLoaderBuilder(
        train_ds, collate_fn=train_comp).build(SETTINGS.training.batch_size)
    dev_dl = StandardAudioDataLoaderBuilder(
        dev_ds, collate_fn=compose(truncater, batchifier)).build(SETTINGS.training.batch_size)
    test_dl = StandardAudioDataLoaderBuilder(
        test_ds, collate_fn=compose(truncater, batchifier)).build(SETTINGS.training.batch_size)
    model = find_model(args.model)().to(device)
    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = AdamW(params, SETTINGS.training.learning_rate,
                      weight_decay=SETTINGS.training.weight_decay)
    logging.info(f'{sum(p.numel() for p in params)} parameters')
    criterion = nn.CrossEntropyLoss()
    # Reuse cached ZMUV stats when present, otherwise estimate them once.
    if (ws.path / 'zmuv.pt.bin').exists():
        zmuv_transform.load_state_dict(torch.load(str(ws.path / 'zmuv.pt.bin')))
    else:
        for idx, batch in enumerate(tqdm(prep_dl, desc='Constructing ZMUV')):
            batch.to(device)
            zmuv_transform.update(std_transform(batch.audio_data))
            if idx == 2000:  # TODO: quick debugging, remove later
                break
        logging.info(
            dict(zmuv_mean=zmuv_transform.mean, zmuv_std=zmuv_transform.std))
        torch.save(zmuv_transform.state_dict(), str(ws.path / 'zmuv.pt.bin'))
    if args.load_weights:
        ws.load_model(model, best=True)
    if args.eval:
        # Evaluation-only path: load the best checkpoint and exit.
        ws.load_model(model, best=True)
        evaluate_accuracy(dev_dl, 'Dev')
        evaluate_accuracy(test_dl, 'Test')
        return
    ws.write_args(args)
    ws.write_setting(SETTINGS)
    writer.add_scalar('Meta/Parameters', sum(p.numel() for p in params))
    for epoch_idx in trange(SETTINGS.training.num_epochs, position=0, leave=True):
        model.train()
        std_transform.train()
        pbar = tqdm(train_dl, total=len(train_dl), position=1,
                    desc='Training', leave=True)
        for batch in pbar:
            batch.to(device)
            audio_data = zmuv_transform(std_transform(batch.audio_data))
            scores = model(audio_data, std_transform.compute_lengths(batch.lengths))
            optimizer.zero_grad()
            model.zero_grad()
            labels = torch.tensor([
                l % SETTINGS.training.num_labels for l in batch.labels.tolist()
            ]).to(device)
            loss = criterion(scores, labels)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(dict(loss=f'{loss.item():.3}'))
            writer.add_scalar('Training/Loss', loss.item(), epoch_idx)
        # Manual exponential LR decay applied once per epoch.
        for group in optimizer.param_groups:
            group['lr'] *= SETTINGS.training.lr_decay
        evaluate_accuracy(dev_dl, 'Dev', save=True)
        evaluate_accuracy(test_dl, 'Test')
def main():
    """Train and evaluate a wake-word detection model.

    Script entry point: parses CLI options, assembles train/dev/test wake-word
    datasets (optionally mixed with a noise dataset), fits a ZMUV (zero-mean,
    unit-variance) audio normalizer, then either evaluates a saved model
    (--eval) or trains for the configured number of epochs, logging to the
    workspace's TensorBoard writer.
    """

    def evaluate_engine(dataset: WakeWordDataset,
                        prefix: str,
                        save: bool = False,
                        positive_set: bool = False,
                        write_errors: bool = True,
                        mixer: DatasetMixer = None):
        """Run wake-word inference over ``dataset`` and log a confusion matrix.

        prefix       -- label used in logs, error file, and TensorBoard tags
        save         -- if True (and not in --eval mode), report true positives
                        to the workspace so it can keep the best model snapshot
        positive_set -- ground truth: whether every example contains the wake word
        write_errors -- append misclassified examples to errors.tsv
        mixer        -- optional noise mixer applied per example before inference
        """
        std_transform.eval()
        # Frame objective -> sliding-window engine; sequence (CTC) objective ->
        # whole-utterance engine that needs the blank label index.
        if use_frame:
            engine = FrameInferenceEngine(int(SETTINGS.training.max_window_size_seconds * 1000),
                                          int(SETTINGS.training.eval_stride_size_seconds * 1000),
                                          SETTINGS.audio.sample_rate,
                                          model,
                                          zmuv_transform,
                                          negative_label=ctx.negative_label,
                                          coloring=ctx.coloring)
        else:
            engine = SequenceInferenceEngine(SETTINGS.audio.sample_rate,
                                             model,
                                             zmuv_transform,
                                             negative_label=ctx.negative_label,
                                             coloring=ctx.coloring,
                                             blank_idx=ctx.blank_label)
        model.eval()
        conf_matrix = ConfusionMatrix()
        pbar = tqdm(dataset, desc=prefix)
        if write_errors:
            # Write a section header so errors from different runs are separable.
            with (ws.path / 'errors.tsv').open('a') as f:
                print(prefix, file=f)
        for idx, ex in enumerate(pbar):
            if mixer is not None:
                ex, = mixer([ex])
            audio_data = ex.audio_data.to(device)
            engine.reset()  # clear streaming state between independent utterances
            seq_present = engine.infer(audio_data)
            if seq_present != positive_set and write_errors:
                with (ws.path / 'errors.tsv').open('a') as f:
                    f.write(f'{ex.metadata.transcription}\t{int(seq_present)}\t{int(positive_set)}\t{ex.metadata.path}\n')
            conf_matrix.increment(seq_present, positive_set)
            pbar.set_postfix(dict(mcc=f'{conf_matrix.mcc}', c=f'{conf_matrix}'))
        logging.info(f'{conf_matrix}')
        if save and not args.eval:
            writer.add_scalar(f'{prefix}/Metric/tp', conf_matrix.tp, epoch_idx)
            ws.increment_model(model, conf_matrix.tp)
        if args.eval:
            # In eval mode, append one CSV row per dataset at the engine's threshold.
            threshold = engine.threshold
            with (ws.path / (str(round(threshold, 2)) + '_results.csv')).open('a') as f:
                f.write(f'{prefix},{threshold},{conf_matrix.tp},{conf_matrix.tn},{conf_matrix.fp},{conf_matrix.fn}\n')

    def do_evaluate():
        """Evaluate on dev and test splits, clean and (optionally) noise-mixed."""
        evaluate_engine(ww_dev_pos_ds, 'Dev positive', positive_set=True)
        evaluate_engine(ww_dev_neg_ds, 'Dev negative', positive_set=False)
        if SETTINGS.training.use_noise_dataset:
            evaluate_engine(ww_dev_pos_ds, 'Dev noisy positive', positive_set=True, mixer=dev_mixer)
            evaluate_engine(ww_dev_neg_ds, 'Dev noisy negative', positive_set=False, mixer=dev_mixer)
        evaluate_engine(ww_test_pos_ds, 'Test positive', positive_set=True)
        evaluate_engine(ww_test_neg_ds, 'Test negative', positive_set=False)
        if SETTINGS.training.use_noise_dataset:
            evaluate_engine(ww_test_pos_ds, 'Test noisy positive', positive_set=True, mixer=test_mixer)
            evaluate_engine(ww_test_neg_ds, 'Test noisy negative', positive_set=False, mixer=test_mixer)

    # ------------------------- CLI arguments ------------------------- #
    apb = ArgumentParserBuilder()
    apb.add_options(opt('--model', type=str, choices=RegisteredModel.registered_names(), default='las'),
                    opt('--workspace', type=str, default=str(Path('workspaces') / 'default')),
                    opt('--load-weights', action='store_true'),
                    opt('--load-last', action='store_true'),
                    opt('--no-dev-per-epoch', action='store_false', dest='dev_per_epoch'),
                    opt('--dataset-paths', '-i', type=str, nargs='+', default=[SETTINGS.dataset.dataset_path]),
                    opt('--eval', action='store_true'))
    args = apb.parser.parse_args()

    # ------------------- objective-dependent setup ------------------- #
    use_frame = SETTINGS.training.objective == 'frame'
    ctx = InferenceContext(SETTINGS.training.vocab, token_type=SETTINGS.training.token_type, use_blank=not use_frame)
    if use_frame:
        batchifier = WakeWordFrameBatchifier(ctx.negative_label,
                                             window_size_ms=int(SETTINGS.training.max_window_size_seconds * 1000))
        criterion = nn.CrossEntropyLoss()
    else:
        tokenizer = WakeWordTokenizer(ctx.vocab, ignore_oov=False)
        batchifier = AudioSequenceBatchifier(ctx.negative_label, tokenizer)
        criterion = nn.CTCLoss(ctx.blank_label)

    # NOTE(review): delete_existing wipes the workspace unless we're evaluating.
    ws = Workspace(Path(args.workspace), delete_existing=not args.eval)
    writer = ws.summary_writer
    set_seed(SETTINGS.training.seed)

    # --------------------------- datasets ---------------------------- #
    loader = WakeWordDatasetLoader()
    ds_kwargs = dict(sr=SETTINGS.audio.sample_rate, mono=SETTINGS.audio.use_mono, frame_labeler=ctx.labeler)
    ww_train_ds, ww_dev_ds, ww_test_ds = WakeWordDataset(metadata_list=[], set_type=DatasetType.TRAINING, **ds_kwargs), \
        WakeWordDataset(metadata_list=[], set_type=DatasetType.DEV, **ds_kwargs), \
        WakeWordDataset(metadata_list=[], set_type=DatasetType.TEST, **ds_kwargs)
    # Merge all provided dataset paths into the three running splits.
    for ds_path in args.dataset_paths:
        ds_path = Path(ds_path)
        train_ds, dev_ds, test_ds = loader.load_splits(ds_path, **ds_kwargs)
        ww_train_ds.extend(train_ds)
        ww_dev_ds.extend(dev_ds)
        ww_test_ds.extend(test_ds)
    print_stats(f'Wake word dataset', ww_train_ds, ww_dev_ds, ww_test_ds)
    # Split dev/test into positive (transcription matches the wake phrase) and
    # negative subsets for separate confusion-matrix evaluation.
    ww_dev_pos_ds = ww_dev_ds.filter(lambda x: ctx.searcher.search(x.transcription), clone=True)
    ww_dev_neg_ds = ww_dev_ds.filter(lambda x: not ctx.searcher.search(x.transcription), clone=True)
    ww_test_pos_ds = ww_test_ds.filter(lambda x: ctx.searcher.search(x.transcription), clone=True)
    ww_test_neg_ds = ww_test_ds.filter(lambda x: not ctx.searcher.search(x.transcription), clone=True)
    print_stats(f'Dev dataset', ww_dev_pos_ds, ww_dev_neg_ds)
    print_stats(f'Test dataset', ww_test_pos_ds, ww_test_neg_ds)

    # ------------------- transforms and noise mixing ------------------ #
    device = torch.device(SETTINGS.training.device)
    std_transform = StandardAudioTransform().to(device).eval()
    zmuv_transform = ZmuvTransform().to(device)
    train_comp = (NoiseTransform().train(), batchifier)
    if SETTINGS.training.use_noise_dataset:
        noise_ds = RecursiveNoiseDatasetLoader().load(Path(SETTINGS.raw_dataset.noise_dataset_path),
                                                      sr=SETTINGS.audio.sample_rate,
                                                      mono=SETTINGS.audio.use_mono)
        logging.info(f'Loaded {len(noise_ds.metadata_list)} noise files.')
        # 80/20 train/dev split, then dev split 50/50 into dev/test noise.
        noise_ds_train, noise_ds_dev = noise_ds.split(Sha256Splitter(80))
        noise_ds_dev, noise_ds_test = noise_ds_dev.split(Sha256Splitter(50))
        train_comp = (DatasetMixer(noise_ds_train).train(),) + train_comp
        dev_mixer = DatasetMixer(noise_ds_dev, seed=0, do_replace=False)
        test_mixer = DatasetMixer(noise_ds_test, seed=0, do_replace=False)
    train_comp = compose(*train_comp)

    # -------------------------- data loaders -------------------------- #
    prep_dl = StandardAudioDataLoaderBuilder(ww_train_ds, collate_fn=batchify).build(1)
    prep_dl.shuffle = True
    train_dl = StandardAudioDataLoaderBuilder(ww_train_ds, collate_fn=train_comp).build(SETTINGS.training.batch_size)

    # ----------------------- model and optimizer ---------------------- #
    model = RegisteredModel.find_registered_class(args.model)(ctx.num_labels).to(device).streaming()
    if SETTINGS.training.convert_static:
        model = ConvertedStaticModel(model, 40, 10)
    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = AdamW(params, SETTINGS.training.learning_rate, weight_decay=SETTINGS.training.weight_decay)
    logging.info(f'{sum(p.numel() for p in params)} parameters')

    # Fit (or reload) the ZMUV normalizer statistics over the training audio.
    if (ws.path / 'zmuv.pt.bin').exists():
        zmuv_transform.load_state_dict(torch.load(str(ws.path / 'zmuv.pt.bin')))
    else:
        for idx, batch in enumerate(tqdm(prep_dl, desc='Constructing ZMUV')):
            batch.to(device)
            zmuv_transform.update(std_transform(batch.audio_data))
            if idx == 2000:  # TODO: quick debugging, remove later
                break
        logging.info(dict(zmuv_mean=zmuv_transform.mean, zmuv_std=zmuv_transform.std))
    torch.save(zmuv_transform.state_dict(), str(ws.path / 'zmuv.pt.bin'))

    if args.load_weights:
        ws.load_model(model, best=not args.load_last)
    if args.eval:
        # Evaluation-only mode: load, evaluate, and exit before training.
        ws.load_model(model, best=not args.load_last)
        do_evaluate()
        return

    # --------------------------- training loop ------------------------ #
    ws.write_args(args)
    ws.write_settings(SETTINGS)
    writer.add_scalar('Meta/Parameters', sum(p.numel() for p in params))
    for epoch_idx in trange(SETTINGS.training.num_epochs, position=0, leave=True):
        model.train()
        std_transform.train()
        model.streaming_state = None  # reset streaming buffers at epoch start
        pbar = tqdm(train_dl, total=len(train_dl), position=1, desc='Training', leave=True)
        total_loss = torch.Tensor([0.0]).to(device)
        for batch in pbar:
            batch.to(device)
            if use_frame:
                scores = model(zmuv_transform(std_transform(batch.audio_data)),
                               std_transform.compute_lengths(batch.lengths))
                loss = criterion(scores, batch.labels)
            else:
                lengths = std_transform.compute_lengths(batch.audio_lengths)
                scores = model(zmuv_transform(std_transform(batch.audio_data)), lengths)
                scores = F.log_softmax(scores, -1)  # [num_frames x batch_size x num_labels]
                # CTC needs the model's output lengths, not the input lengths.
                lengths = torch.tensor([model.compute_length(x.item()) for x in lengths]).to(device)
                loss = criterion(scores, batch.labels, lengths, batch.label_lengths)
            optimizer.zero_grad()
            model.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix(dict(loss=f'{loss.item():.3}'))
            with torch.no_grad():
                total_loss += loss
        # Exponential learning-rate decay applied once per epoch.
        for group in optimizer.param_groups:
            group['lr'] *= SETTINGS.training.lr_decay
        mean = total_loss / len(train_dl)
        writer.add_scalar('Training/Loss', mean.item(), epoch_idx)
        # NOTE(review): 'group' is the loop variable leaked from the decay loop
        # above, i.e. the last param group's lr is what gets logged.
        writer.add_scalar('Training/LearningRate', group['lr'], epoch_idx)
        if args.dev_per_epoch:
            evaluate_engine(ww_dev_pos_ds, 'Dev positive', positive_set=True, save=True, write_errors=False)
    do_evaluate()
if args.encoder == 'bert': tokenizer = get_bert_tokenizer( args.pretrain_checkpoint, add_tokens=['[EOT]'] ) two_tower_model = build_biencoder_model( model_tokenizer=tokenizer, args=args, aggregation='cls' ) else: raise Exception optimizer = AdamW( [{'params': two_tower_model.parameters(), 'initial_lr': args.lr}], lr=args.lr ) if args.best_acc > 0: two_tower_model, likelihood_criterion = get_saved_model_and_optimizer( model=two_tower_model, optimizer=optimizer, checkpoint_path=os.path.join( args.model_checkpoint, "%s_acc_%.5f" % (args.task_name, args.best_acc) )) # two_tower_model.load_state_dict(torch.load(os.path.join( # args.model_checkpoint, "%s_acc_%.5f" % (args.task_name, # args.best_acc) # ))) if args.distributed and args.device == 'cuda':
def __init__(self, save_path, log_path, n_depth, d_features, d_classifier, d_output, threshold=None,
             stack='ShuffleSelfAttention', expansion_layer='ChannelWiseConvExpansion', mode='1d',
             optimizer=None, **kwargs):
    """Build the attention-stack model, classifier head, optimizer, and training bookkeeping.

    Arguments:
        save_path       -- model checkpoint path prefix
        log_path        -- log file path
        n_depth         -- number of stacked attention layers
        d_features      -- input feature dimension
        d_classifier    -- classifier hidden units
        d_output        -- output dimension (1 => regression with sigmoid + threshold)
        threshold       -- decision threshold for regression mode, else None
        stack           -- key into the stack registry below
        expansion_layer -- key into the expansion-layer registry below
        mode            -- '1d': flat output; '2d': (n_depth x d_features) output;
                           other values (e.g. 'residual', 'dense') get no classifier
        optimizer       -- optional pre-built optimizer; if None an AdamW is created
        **kwargs        -- forwarded to the stack (n_layers, n_head, n_channel,
                           n_vchannel, dropout, use_bottleneck, d_bottleneck)
    """
    super().__init__(save_path, log_path)
    self.d_output = d_output
    self.threshold = threshold

    # ----------------------------- Model ------------------------------ #
    # Registries mapping config strings to the classes they name.
    stack_dict = {
        'ReactionAttention': ReactionAttentionStack,
        'SelfAttention': SelfAttentionStack,
        'Alternate': AlternateStack,
        'Parallel': ParallelStack,
        'ShuffleSelfAttention': ShuffleSelfAttentionStack,
        'ShuffleSelfAttentionStackV2': ShuffleSelfAttentionStackV2,
    }
    expansion_dict = {
        'LinearExpansion': LinearExpansion,
        'ReduceParamLinearExpansion': ReduceParamLinearExpansion,
        'ConvExpansion': ConvExpansion,
        'LinearConvExpansion': LinearConvExpansion,
        'ShuffleConvExpansion': ShuffleConvExpansion,
        'ChannelWiseConvExpansion': ChannelWiseConvExpansion,
    }
    self.model = stack_dict[stack](expansion_dict[expansion_layer], n_depth=n_depth,
                                   d_features=d_features, mode=mode, **kwargs)

    # --------------------------- Classifier --------------------------- #
    # Classifier input width depends on the output mode of the stack.
    if mode == '1d':
        self.classifier = LinearClassifier(d_features, d_classifier, d_output)
    elif mode == '2d':
        self.classifier = LinearClassifier(n_depth * d_features, d_classifier, d_output)
    else:
        self.classifier = None

    # ------------------------------ CUDA ------------------------------ #
    # If GPU available, wrap in DataParallel across all visible devices.
    self.CUDA_AVAILABLE = self.check_cuda()
    if self.CUDA_AVAILABLE:
        device_ids = list(range(torch.cuda.device_count()))
        self.model = nn.DataParallel(self.model, device_ids)
        self.classifier = nn.DataParallel(self.classifier, device_ids)
        self.model.to('cuda')
        self.classifier.to('cuda')
        assert (next(self.model.parameters()).is_cuda)
        assert (next(self.classifier.parameters()).is_cuda)
    else:
        print('CUDA not found or not enabled, use CPU instead')

    # ---------------------------- Optimizer --------------------------- #
    self.parameters = list(self.model.parameters()) + list(self.classifier.parameters())
    if optimizer is None:
        self.optimizer = AdamW(self.parameters, lr=0.002, betas=(0.9, 0.999), weight_decay=0.001)
    else:
        # Bug fix: previously a caller-supplied optimizer was silently ignored
        # and self.optimizer was never set, breaking later .step() calls.
        self.optimizer = optimizer

    # ------------------------ training control ------------------------ #
    self.controller = TrainingControl(max_step=100000, evaluate_every_nstep=100, print_every_nstep=10)
    self.early_stopping = EarlyStopping(patience=50)

    # --------------------- logging and tensorboard -------------------- #
    self.set_logger()
    self.set_summary_writer()
class ShuffleSelfAttentionModel(Model):
    """Shuffle self-attention stack + linear classifier with built-in train/eval loops.

    Wraps an attention stack and a classifier head, handles multi-GPU placement,
    optimizer setup, step-based training control, early stopping, logging, and
    TensorBoard summaries.
    """

    def __init__(self, save_path, log_path, n_depth, d_features, d_classifier, d_output, threshold=None,
                 stack='ShuffleSelfAttention', expansion_layer='ChannelWiseConvExpansion', mode='1d',
                 optimizer=None, **kwargs):
        """Build model, classifier, optimizer, and training bookkeeping.

        Arguments:
            save_path       -- model checkpoint path prefix
            log_path        -- log file path
            n_depth         -- number of stacked attention layers
            d_features      -- input feature dimension
            d_classifier    -- classifier hidden units
            d_output        -- output dimension (1 => regression with threshold)
            threshold       -- decision threshold for regression mode, else None
            stack           -- key selecting the attention-stack class
            expansion_layer -- key selecting the expansion-layer class
            mode            -- '1d' / '2d' / 'residual' / 'dense' output shape
            optimizer       -- optional optimizer; NOTE(review): only the None
                               path assigns self.optimizer — a supplied value
                               is currently ignored (see below)
            **kwargs        -- forwarded to the stack: n_layers, n_head,
                               n_channel, n_vchannel, dropout, use_bottleneck,
                               d_bottleneck
        """
        super().__init__(save_path, log_path)
        self.d_output = d_output
        self.threshold = threshold

        # ----------------------------- Model ------------------------------ #
        # String-to-class registries resolved from the config arguments.
        stack_dict = {
            'ReactionAttention': ReactionAttentionStack,
            'SelfAttention': SelfAttentionStack,
            'Alternate': AlternateStack,
            'Parallel': ParallelStack,
            'ShuffleSelfAttention': ShuffleSelfAttentionStack,
            'ShuffleSelfAttentionStackV2': ShuffleSelfAttentionStackV2,
        }
        expansion_dict = {
            'LinearExpansion': LinearExpansion,
            'ReduceParamLinearExpansion': ReduceParamLinearExpansion,
            'ConvExpansion': ConvExpansion,
            'LinearConvExpansion': LinearConvExpansion,
            'ShuffleConvExpansion': ShuffleConvExpansion,
            'ChannelWiseConvExpansion': ChannelWiseConvExpansion,
        }
        self.model = stack_dict[stack](expansion_dict[expansion_layer], n_depth=n_depth,
                                       d_features=d_features, mode=mode, **kwargs)

        # --------------------------- Classifier --------------------------- #
        # Classifier input width follows the stack's output mode.
        if mode == '1d':
            self.classifier = LinearClassifier(d_features, d_classifier, d_output)
        elif mode == '2d':
            self.classifier = LinearClassifier(n_depth * d_features, d_classifier, d_output)
        else:
            self.classifier = None

        # ------------------------------ CUDA ------------------------------ #
        # If GPU available, move the graph to GPU(s) via DataParallel.
        self.CUDA_AVAILABLE = self.check_cuda()
        if self.CUDA_AVAILABLE:
            device_ids = list(range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids)
            self.classifier = nn.DataParallel(self.classifier, device_ids)
            self.model.to('cuda')
            self.classifier.to('cuda')
            assert (next(self.model.parameters()).is_cuda)
            assert (next(self.classifier.parameters()).is_cuda)
            pass
        else:
            print('CUDA not found or not enabled, use CPU instead')

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(
            self.classifier.parameters())
        # NOTE(review): if a non-None optimizer is passed, self.optimizer is
        # never assigned and later .step() calls would fail — confirm intent.
        if optimizer == None:
            self.optimizer = AdamW(self.parameters, lr=0.002, betas=(0.9, 0.999), weight_decay=0.001)

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000, evaluate_every_nstep=100, print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # --------------------- logging and tensorboard -------------------- #
        self.set_logger()
        self.set_summary_writer()
        # ---------------------------- END INIT ---------------------------- #

    def checkpoint(self, step):
        """Return a serializable checkpoint dict for the given global step."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'classifier_state_dict': self.classifier.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'global_step': step
        }
        return checkpoint

    def train_epoch(self, train_dataloader, eval_dataloader, device, smothing, earlystop):
        ''' Epoch operation in training phase'''
        # smothing (sic) -- label-smoothing flag passed to cross_entropy_loss;
        # earlystop      -- if True, stop the epoch when validation says so.
        if device == 'cuda':
            assert self.CUDA_AVAILABLE

        # Set model and classifier training mode
        self.model.train()
        self.classifier.train()

        total_loss = 0
        batch_counter = 0

        # update param per batch
        for batch in tqdm(train_dataloader, mininterval=1, desc=' - (Training) ', leave=False):
            # get data from dataloader
            feature_1, feature_2, y = parse_data(batch, device)
            batch_size = len(feature_1)

            # forward
            self.optimizer.zero_grad()
            logits, attn = self.model(feature_1, feature_2)
            logits = logits.view(batch_size, -1)
            logits = self.classifier(logits)

            # Judge if it's a regression problem
            if self.d_output == 1:
                pred = logits.sigmoid()
                loss = mse_loss(pred, y)
            else:
                pred = logits
                loss = cross_entropy_loss(pred, y, smoothing=smothing)

            # calculate gradients
            loss.backward()

            # update parameters
            self.optimizer.step()

            # get metrics for logging
            acc = accuracy(pred, y, threshold=self.threshold)
            precision, recall, precision_avg, recall_avg = precision_recall(
                pred, y, self.d_output, threshold=self.threshold)
            total_loss += loss.item()
            batch_counter += 1

            # training control: decides when to print/evaluate/stop by step count
            state_dict = self.controller(batch_counter)

            if state_dict['step_to_print']:
                self.train_logger.info(
                    '[TRAINING] - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f'
                    % (state_dict['step'], loss, acc, precision[1], recall[1]))
                self.summary_writer.add_scalar('loss/train', loss, state_dict['step'])
                self.summary_writer.add_scalar('acc/train', acc, state_dict['step'])
                self.summary_writer.add_scalar('precision/train', precision[1], state_dict['step'])
                self.summary_writer.add_scalar('recall/train', recall[1], state_dict['step'])

            if state_dict['step_to_evaluate']:
                stop = self.val_epoch(eval_dataloader, device, state_dict['step'])
                state_dict['step_to_stop'] = stop
                # NOTE(review): bitwise '&' on bools — equivalent to 'and' here.
                if earlystop & stop:
                    break

            if self.controller.current_step == self.controller.max_step:
                state_dict['step_to_stop'] = True
                break

        return state_dict

    def val_epoch(self, dataloader, device, step=0, plot=False):
        ''' Epoch operation in evaluation phase '''
        # Returns True when early stopping says training should break.
        if device == 'cuda':
            assert self.CUDA_AVAILABLE

        # Set model and classifier evaluation mode
        self.model.eval()
        self.classifier.eval()

        # use evaluator to calculate the average performance
        evaluator = Evaluator()

        pred_list = []
        real_list = []

        with torch.no_grad():
            for batch in tqdm(dataloader, mininterval=5, desc=' - (Evaluation) ', leave=False):
                # get data from dataloader
                feature_1, feature_2, y = parse_data(batch, device)
                batch_size = len(feature_1)

                # get logits
                logits, attn = self.model(feature_1, feature_2)
                logits = logits.view(batch_size, -1)
                logits = self.classifier(logits)

                if self.d_output == 1:
                    pred = logits.sigmoid()
                    loss = mse_loss(pred, y)
                else:
                    pred = logits
                    loss = cross_entropy_loss(pred, y, smoothing=False)

                acc = accuracy(pred, y, threshold=self.threshold)
                precision, recall, _, _ = precision_recall(
                    pred, y, self.d_output, threshold=self.threshold)

                # feed the metrics in the evaluator
                evaluator(loss.item(), acc.item(), precision[1].item(), recall[1].item())

                # append the results to the predict / real list for drawing
                # ROC or PR curve.
                if plot:
                    pred_list += pred.tolist()
                    real_list += y.tolist()

        if plot:
            area, precisions, recalls, thresholds = pr(pred_list, real_list)
            plot_pr_curve(recalls, precisions, auc=area)

        # get evaluation results from the evaluator
        loss_avg, acc_avg, pre_avg, rec_avg = evaluator.avg_results()

        self.eval_logger.info(
            '[EVALUATION] - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f'
            % (step, loss_avg, acc_avg, pre_avg, rec_avg))
        self.summary_writer.add_scalar('loss/eval', loss_avg, step)
        self.summary_writer.add_scalar('acc/eval', acc_avg, step)
        self.summary_writer.add_scalar('precision/eval', pre_avg, step)
        self.summary_writer.add_scalar('recall/eval', rec_avg, step)

        # Early stopping tracks validation loss; 'save' marks a new best.
        state_dict = self.early_stopping(loss_avg)

        if state_dict['save']:
            checkpoint = self.checkpoint(step)
            self.save_model(
                checkpoint,
                self.save_path + '-step-%d_loss-%.5f' % (step, loss_avg))

        return state_dict['break']

    def train(self, max_epoch, train_dataloader, eval_dataloader, device,
              smoothing=False, earlystop=False, save_mode='best'):
        """Run up to max_epoch epochs of training with periodic evaluation.

        NOTE(review): save_mode is validated but otherwise unused here.
        """
        assert save_mode in ['all', 'best']

        # train for n epoch
        for epoch_i in range(max_epoch):
            print('[ Epoch', epoch_i, ']')

            # set current epoch
            self.controller.set_epoch(epoch_i + 1)

            # train for on epoch
            state_dict = self.train_epoch(train_dataloader, eval_dataloader, device, smoothing, earlystop)

            # if state_dict['step_to_stop']:
            #     break

        # Save a final checkpoint at the last reached step.
        checkpoint = self.checkpoint(state_dict['step'])
        self.save_model(checkpoint, self.save_path + '-step-%d' % state_dict['step'])

        self.train_logger.info(
            '[INFO]: Finish Training, ends with %d epoch(s) and %d batches, in total %d training steps.'
            % (state_dict['epoch'] - 1, state_dict['batch'], state_dict['step']))

    def get_predictions(self, data_loader, device, max_batches=None, activation=None):
        """Collect (prediction, label) lists over a dataloader without gradients.

        max_batches -- optional cap on the number of batches consumed
        activation  -- optional callable applied to logits; defaults to softmax
        """
        pred_list = []
        real_list = []

        self.model.eval()
        self.classifier.eval()

        batch_counter = 0

        with torch.no_grad():
            for batch in tqdm(data_loader, desc=' - (Testing) ', leave=False):
                feature_1, feature_2, y = parse_data(batch, device)

                # get logits
                logits, attn = self.model(feature_1, feature_2)
                logits = logits.view(logits.shape[0], -1)
                logits = self.classifier(logits)

                # Whether to apply activation function
                if activation != None:
                    pred = activation(logits)
                else:
                    pred = logits.softmax(dim=-1)

                pred_list += pred.tolist()
                real_list += y.tolist()

                if max_batches != None:
                    batch_counter += 1
                    if batch_counter >= max_batches:
                        break

        return pred_list, real_list
def train(model_path: Optional[str] = None, data_path: Optional[str] = None,
          version: Optional[str] = None, save_to_dir: Optional[str] = None):
    """Train the proof-net parser: decoder pretraining, then joint training.

    model_path  -- optional checkpoint to resume from (restores step, optimizer
                   state, and starting epoch)
    data_path   -- optional dataset path forwarded to init()
    version     -- run identifier; subdirectory under save_to_dir for logs/models
    save_to_dir -- output root; defaults to './stored_models' when None

    # NOTE(review): decoder_epochs / mutual_epochs appear to be module-level
    # constants defined elsewhere in this file — confirm.
    """
    train_dl, val_dl, test_dl, nbatches, parser, version = init(
        data_path, version=version, save_to_dir=save_to_dir)

    # Base optimizer starts at lr=0; the Scheduler wrapper drives the actual lr.
    core_opt = AdamW(parser.parameters(), lr=0., betas=(0.9, 0.98), eps=1e-12, weight_decay=1e-02)
    # One schedule for decoder pretraining, another (with restarts) for the
    # joint ("mutual") phase.
    dec_schedule = make_cosine_schedule(max_lr=5e-04, warmup_steps=nbatches // 4,
                                        decay_over=decoder_epochs * nbatches)
    mutual_schedule = make_cosine_schedule_with_linear_restarts(
        max_lr=1e-04, warmup_steps=nbatches // 4,
        decay_over=mutual_epochs * nbatches, triangle_decay=40 * nbatches)

    # Supertagging loss ignores padding; Sinkhorn loss scores axiom links.
    st_loss = NormCrossEntropy(sep_id=parser.atom_tokenizer.sep_token_id,
                               ignore_index=parser.atom_tokenizer.pad_token_id)
    sh_loss = SinkhornLoss()

    if model_path is not None:
        # Resume: restore step counter, lr, optimizer state, and pick the
        # schedule matching the phase the checkpointed epoch belongs to.
        print('Loading checkpoint...')
        step_num, opt_dict, init_epoch = load_model(parser, model_path)
        if init_epoch < decoder_epochs:
            schedule = dec_schedule
        else:
            schedule = mutual_schedule
        opt = Scheduler(core_opt, schedule)
        opt.step_num = step_num
        opt.lr = opt.schedule(opt.step_num)
        opt.opt.load_state_dict(opt_dict)
        del opt_dict  # free the (potentially large) restored state dict
    else:
        opt = Scheduler(core_opt, dec_schedule)
        init_epoch = 0

    if save_to_dir is None:
        save_to_dir = './stored_models'

    for e in range(init_epoch, decoder_epochs + mutual_epochs):
        # Validate and checkpoint every 5th epoch (but not the resume epoch).
        validate = e % 5 == 0 and e != init_epoch
        save = e % 5 == 0 and e != init_epoch
        linking_weight = 0.33  # relative weight of the linking loss term

        if save:
            print('\tSaving')
            torch.save(
                {
                    'model_state_dict': parser.state_dict(),
                    'opt_state_dict': opt.opt.state_dict(),
                    'step': opt.step_num,
                    'epoch': e
                },
                f'{save_to_dir}/{version}/{e}.model')

        if e < decoder_epochs:
            # Phase 1: decoder pretraining.
            with open(f'{save_to_dir}/{version}/log.txt', 'a') as stream:
                logprint('=' * 64, [stream])
                logprint(f'Pre-epoch {e}', [stream])
                logprint(' ' * 50 + f'LR: {opt.lr}\t({opt.step_num})', [stream])
                supertagging_loss, linking_loss = parser.pretrain_decoder_epoch(
                    train_dl, st_loss, sh_loss, opt, linking_weight)
                logprint(f' Supertagging Loss:\t\t{supertagging_loss:5.2f}', [stream])
                logprint(f' Linking Loss:\t\t\t{linking_loss:5.2f}', [stream])
                if validate:
                    with open(f'{save_to_dir}/{version}/val_log.txt', 'a') as valstream:
                        logprint(f'Epoch {e}', [valstream])
                        logprint('-' * 64, [stream, valstream])
                        sentence_ac, atom_ac, link_ac = parser.preval_epoch(val_dl)
                        logprint(f' Sentence Accuracy:\t\t{(sentence_ac * 100):6.2f}',
                                 [stream, valstream])
                        logprint(f' Atom Accuracy:\t\t{(atom_ac * 100):6.2f}', [stream, valstream])
                        logprint(f' Link Accuracy:\t\t{(link_ac * 100):6.2f}', [stream, valstream])
            continue
        elif e == decoder_epochs:
            # Phase boundary: swap to the mutual-training schedule, keeping the
            # same underlying optimizer (and its state).
            opt = Scheduler(opt.opt, mutual_schedule)

        # Phase 2: joint training of tagging and linking.
        with open(f'{save_to_dir}/{version}/log.txt', 'a') as stream:
            logprint('=' * 64, [stream])
            logprint(f'Epoch {e}', [stream])
            logprint(' ' * 50 + f'LR: {opt.lr}\t({opt.step_num})', [stream])
            logprint(' ' * 50 + f'LW: {linking_weight}', [stream])
            logprint('-' * 64, [stream])
            supertagging_loss, linking_loss = parser.train_epoch(
                train_dl, st_loss, sh_loss, opt, linking_weight)
            logprint(f' Supertagging Loss:\t\t{supertagging_loss:5.2f}', [stream])
            logprint(f' Linking Loss:\t\t\t{linking_loss:5.2f}', [stream])
            if validate:
                with open(f'{save_to_dir}/{version}/val_log.txt', 'a') as valstream:
                    logprint(f'Epoch {e}', [valstream])
                    logprint('-' * 64, [stream, valstream])
                    sentence_ac, atom_ac, link_ac = parser.eval_epoch(val_dl, link=True)
                    logprint(f' Sentence Accuracy:\t\t{(sentence_ac * 100):6.2f}',
                             [stream, valstream])
                    logprint(f' Atom Accuracy:\t\t{(atom_ac * 100):6.2f}', [stream, valstream])
                    logprint(f' Link Accuracy:\t\t{(link_ac * 100):6.2f}', [stream, valstream])
            logprint('\n', [stream])