def init_trainer_network(self):
     self.reset_parameters()
     self.gen_params = list(self.generator.parameters())
     self.dsc_params = list(self.discriminator.parameters())
     self.language_params = list(self.language_model.parameters())
     self.style_params = list(self.style_model.parameters())
     self.content_params = list(self.content_model.parameters())
     self.optim_language = AdamW(self.language_params,
                                 lr=self.cfg.lang_lr,
                                 betas=self.cfg.lang_betas,
                                 weight_decay=self.cfg.lang_wd)
     self.optim_generator = AdamW(self.gen_params,
                                  lr=self.cfg.gen_lr,
                                  betas=self.cfg.gen_betas,
                                  weight_decay=self.cfg.gen_wd)
     self.optim_discriminator = AdamW(self.dsc_params,
                                      lr=self.cfg.dsc_lr,
                                      betas=self.cfg.dsc_betas,
                                      weight_decay=self.cfg.dsc_wd)
     self.optim_style = AdamW(self.style_params,
                              lr=self.cfg.lmf_lr,
                              betas=self.cfg.lmf_betas,
                              weight_decay=self.cfg.lmf_wd)
     self.optim_content = AdamW(self.content_params,
                                lr=self.cfg.lmf_lr,
                                betas=self.cfg.lmf_betas,
                                weight_decay=self.cfg.lmf_wd)
Example #2

    def init_trainer_network(self):
        images = self.forward_generator()
        _ = self.forward_discriminator(images, self.condition)
        self.reset_parameters()

        gen_params = list(self.generator.parameters())
        dsc_params = list(self.discriminators.parameters())
        self.optim_generator = AdamW(gen_params,
                                     lr=self.cfg.gen_lr,
                                     betas=self.cfg.gen_betas,
                                     weight_decay=self.cfg.gen_wd)

        self.optim_discriminator = AdamW(dsc_params,
                                         lr=self.cfg.dsc_lr,
                                         betas=self.cfg.dsc_betas,
                                         weight_decay=self.cfg.dsc_wd)

        # TODO add to yaml params for language optimizer
        feature_params = list(self.features_model.parameters())
        if self.language_model.trainable:
            feature_params = feature_params + list(
                self.language_model.parameters())
        self.optim_feature = AdamW(feature_params,
                                   lr=self.cfg.gen_lr,
                                   betas=self.cfg.gen_betas,
                                   weight_decay=self.cfg.gen_wd)
Example #3
    def __init__(
            self, save_path, log_path, d_features, d_meta, max_length, d_classifier, n_classes, threshold=None,
            optimizer=None, **kwargs):
        '''**kwargs: n_layers, n_head, dropout, use_bottleneck, d_bottleneck'''
        '''
            Arguments:
                save_path -- model file path
                log_path -- log file path
                d_features -- how many PMs
                d_meta -- how many facility types
                max_length -- input sequence length
                d_classifier -- classifier hidden unit
                n_classes -- output dim
                threshold -- if not None, n_classes should be 1 (regression).
        '''

        super().__init__(save_path, log_path)
        self.d_output = n_classes
        self.threshold = threshold
        self.max_length = max_length

        # ----------------------------- Model ------------------------------ #

        self.model = Encoder(TimeFacilityEncoding, d_features=d_features, max_seq_length=max_length,
                                       d_meta=d_meta, **kwargs)

        # --------------------------- Classifier --------------------------- #
        self.classifier = LinearClassifier(d_features * max_length, d_classifier, n_classes)

        # ------------------------------ CUDA ------------------------------ #
        # If GPU available, move the graph to GPU(s)
        self.CUDA_AVAILABLE = self.check_cuda()
        if self.CUDA_AVAILABLE:
            device_ids = list(range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids)
            self.classifier = nn.DataParallel(self.classifier, device_ids)
            self.model.to('cuda')
            self.classifier.to('cuda')
            assert (next(self.model.parameters()).is_cuda)
            assert (next(self.classifier.parameters()).is_cuda)
            pass

        else:
            print('CUDA not found or not enabled, use CPU instead')

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(self.classifier.parameters())
        if optimizer is None:
            self.optimizer = AdamW(self.parameters, lr=0.001, betas=(0.9, 0.999), weight_decay=0.001)

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000, evaluate_every_nstep=100, print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # --------------------- logging and tensorboard -------------------- #
        self.count_parameters()
        self.set_logger()
        self.set_summary_writer()
Example #4
    def configure_optimizers(self):
        opt_cfg = self.cfg['optimizer']
        lr = float(self.cfg['optimizer']['lr'])
        if opt_cfg['name'] == 'AdamW':
            optimizer = AdamW(self.model.parameters(), lr=lr)
        elif opt_cfg['name'] == 'Adam_GCC':
            optimizer = Adam_GCC(self.model.parameters(), lr=lr)
        elif opt_cfg['name'] == 'AdamW_GCC2':
            optimizer = AdamW_GCC2(self.model.parameters(), lr=lr)
        else:
            raise Exception('optimizer {} not supported'.format(opt_cfg['name']))

        if self.cfg['scheduler']['type'] == 'none':
            sched = None
        elif self.cfg['scheduler']['type'] == 'CosineAnnealingWarmRestarts':
            T_mult = self.cfg['scheduler']['T_mult']
            T_0 = self.cfg['scheduler']['T_0']
            eta_min = float(self.cfg['scheduler']['eta_min'])
            sched = CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=T_mult, eta_min=eta_min, last_epoch=-1)
        elif self.cfg['scheduler']['type'] == 'OneCycleLR':
            max_lr = float(self.cfg['scheduler']['max_lr'])
            steps_per_epoch = self.cfg['scheduler']['steps_per_epoch']
            epochs = self.cfg['scheduler']['epochs']
            sched = OneCycleLR(optimizer, max_lr=max_lr, steps_per_epoch=steps_per_epoch, epochs=epochs)
        else:
            raise Exception('scheduler {} not supported'.format(self.cfg['scheduler']['type']))
        if sched is not None:
            sched = {'scheduler': sched, 'name': self.cfg['scheduler']['type']}
            return [optimizer], [sched]
        return optimizer
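
The cfg layout this method expects is not shown in the listing. A minimal sketch of a matching config dict, with keys inferred from the lookups above and purely illustrative values:

cfg = {
    'optimizer': {'name': 'AdamW', 'lr': '3e-4'},
    'scheduler': {'type': 'OneCycleLR', 'max_lr': '1e-3',
                  'steps_per_epoch': 500, 'epochs': 30},
}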
Example #5
def train_end_to_end(args, graph_dataset, model):
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = ReduceLROnPlateau(optimizer, patience=5, cooldown=2, factor=0.5)
    early_stopper = EarlyStopper(args.supervised_model_patience)

    epoch = 0
    start_time = time.time()
    LOG.info('Started training supervised model.')
    while True:
        train_loss = end_to_end_train_epoch(graph_dataset, model, optimizer, args, epoch)
        val_loss, val_acc, val_f1 = end_to_end_val_epoch(graph_dataset, model, args)
        lr_scheduler.step(val_loss)

        epoch_end_time = round(time.time() - start_time, 2)
        LOG.info(f'Epoch {epoch + 1} [{epoch_end_time}s]: '
                 f'Training loss = {train_loss:.4f}, Validation loss = {val_loss:.4f}, '
                 f'Val accuracy = {val_acc:.4f}, Val f1 = {val_f1:.4f}')
        epoch += 1
        if early_stopper.should_stop(model, val_loss) or _STOP_TRAINING:
            break

    # Compute metrics.
    val_y_pred = predict(graph_dataset, model, args, mode='val')
    val_y_pred = raw_output_to_prediction(val_y_pred, args.loss_fn)
    val_y_true = [label for graph in graph_dataset.graphs for label in graph.y[graph.val_mask].cpu().numpy()]
    test_y_pred = predict(graph_dataset, model, args, mode='test')
    test_y_pred = raw_output_to_prediction(test_y_pred, args.loss_fn)
    test_y_true = [label for graph in graph_dataset.graphs for label in graph.y[graph.test_mask].cpu().numpy()]
    log_metrics(graph_dataset, val_y_true, val_y_pred.cpu().numpy(), test_y_true,
                test_y_pred.cpu().numpy(), args.show_detailed_metrics, args.attention_radius, args.max_neighbors)

    # Save trained end-to-end model.
    torch.save(model.state_dict(), os.path.join(args.model_dir, 'node_encoder_end_to_end.pt'))
Example #6
def get_optimizer(parameters, optim, learning_rate, lr_decay, amsgrad, weight_decay, warmup_steps):
    if optim == 'sgd':
        optimizer = SGD(parameters, lr=learning_rate, momentum=0.9, weight_decay=weight_decay, nesterov=True)
    else:
        optimizer = AdamW(parameters, lr=learning_rate, betas=(0.9, 0.999), eps=1e-8, amsgrad=amsgrad, weight_decay=weight_decay)
    init_lr = 1e-7
    scheduler = ExponentialScheduler(optimizer, lr_decay, warmup_steps, init_lr)
    return optimizer, scheduler
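
A hypothetical call of the helper above, assuming a model is in scope (ExponentialScheduler is project-specific and not part of torch.optim):

import torch.nn as nn

model = nn.Linear(16, 4)
optimizer, scheduler = get_optimizer(model.parameters(), optim='adamw',
                                     learning_rate=1e-3, lr_decay=0.999995,
                                     amsgrad=False, weight_decay=1e-2,
                                     warmup_steps=100)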
Example #7
    def __init__(self,
                 save_path,
                 log_path,
                 d_features,
                 d_out_list,
                 d_classifier,
                 d_output,
                 threshold=None,
                 optimizer=None,
                 **kwargs):
        '''*args: n_layers, n_head, n_channel, n_vchannel, dropout'''
        super().__init__(save_path, log_path)
        self.d_output = d_output
        self.threshold = threshold

        # ----------------------------- Model ------------------------------ #

        self.model = SelfAttentionFeatureSelection(d_features, d_out_list)

        # --------------------------- Classifier --------------------------- #

        self.classifier = LinearClassifier(d_features, d_classifier, d_output)

        # ------------------------------ CUDA ------------------------------ #
        self.data_parallel()

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(
            self.classifier.parameters())
        if optimizer is None:
            self.optimizer = AdamW(self.parameters,
                                   lr=0.002,
                                   betas=(0.9, 0.999),
                                   weight_decay=0.001)

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000,
                                          evaluate_every_nstep=100,
                                          print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # --------------------- logging and tensorboard -------------------- #
        self.set_logger()
        self.set_summary_writer()
Example #8
def train_single_epoch(model: RedisSingleDNN, trainDataloader: DataLoader,
                       optimizer: AdamW) -> Tuple[float, float]:
    train_loss = 0.0
    train_ACC = 0
    train_steps = 0
    model.train()
    for _, batch in enumerate(tqdm(trainDataloader, desc="Iteration")):
        optimizer.zero_grad()
        knobs_with_info = batch[0].to(DEVICE)
        targets = batch[1].to(DEVICE)
        outputs = model(knobs_with_info)
        loss = F.mse_loss(outputs, targets)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_steps += 1

    return train_loss / len(trainDataloader), train_ACC
Example #9
    def configure_optimizers(self):
        optimizer = AdamW(self.optimizer_params)

        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=int(round(self.config.kv['lr_t0'])),
            T_mult=int(round(self.config.kv['lr_f'])))
        scheduler = WarmupScheduler(optimizer,
                                    epochs=self.config.kv['lr_w'],
                                    next_scheduler=scheduler)

        return [optimizer], [scheduler]
Example #10
def get_optimizer(learning_rate, parameters, betas, eps, amsgrad, step_decay,
                  weight_decay, warmup_steps, init_lr):
    optimizer = AdamW(parameters,
                      lr=learning_rate,
                      betas=betas,
                      eps=eps,
                      amsgrad=amsgrad,
                      weight_decay=weight_decay)
    scheduler = ExponentialScheduler(optimizer, step_decay, warmup_steps,
                                     init_lr)
    return optimizer, scheduler
Example #11

def get_optimisers(G: "nn.Module",
                   args: "tupperware") -> "Union[optim, lr_scheduler]":

    g_optimizer = AdamW(G.parameters(),
                        lr=args.learning_rate,
                        betas=(args.beta_1, args.beta_2))

    g_lr_scheduler = CosineAnnealingWarmRestarts(optimizer=g_optimizer,
                                                 T_0=args.T_0,
                                                 T_mult=args.T_mult,
                                                 eta_min=2e-10)

    return g_optimizer, g_lr_scheduler
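
A sketch of invoking get_optimisers with a stand-in args namespace; the field names simply mirror the attributes accessed above, and the imports used by the helper are assumed to be in scope:

from argparse import Namespace
import torch.nn as nn

args = Namespace(learning_rate=3e-4, beta_1=0.9, beta_2=0.999, T_0=64, T_mult=2)
G = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                  nn.Conv2d(16, 3, 3, padding=1))
g_optimizer, g_lr_scheduler = get_optimisers(G, args)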
Example #12

def split_optimizer(model: nn.Module, cfg: dict):
    param_weight_decay, param_bias, param_other = split_params(model)
    if cfg['optimizer'] == 'Adam':
        optimizer = Adam(param_other, lr=cfg['lr'])
    elif cfg['optimizer'] == 'SGD':
        optimizer = SGD(param_other, lr=cfg['lr'], momentum=cfg['momentum'])
    elif cfg['optimizer'] == "AdamW":
        optimizer = AdamW(param_other, lr=cfg['lr'])
    else:
        raise NotImplementedError("optimizer {:s} is not supported!".format(cfg['optimizer']))
    optimizer.add_param_group(
        {'params': param_weight_decay, 'weight_decay': cfg['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': param_bias})
    return optimizer
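
split_params is not included in this listing; a minimal sketch of one plausible grouping (decayed conv/linear weights, their biases, and everything else such as normalization parameters) might look like:

import torch.nn as nn

def split_params(model: nn.Module):
    # Hypothetical helper: conv/linear weights receive weight decay, their biases
    # form a separate group, and remaining parameters (e.g. BatchNorm) get no decay.
    param_weight_decay, param_bias, param_other = [], [], []
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            param_weight_decay.append(module.weight)
            if module.bias is not None:
                param_bias.append(module.bias)
        else:
            param_other.extend(module.parameters(recurse=False))
    return param_weight_decay, param_bias, param_other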
Example #13

    def load_parameters(self, filename, ctx=None, strict=True):
        '''
        Load parameters from a file.

        Args:
            filename:       File name.
            ctx:            Context-manager that changes the selected device.
            strict:         Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: `True`.
        '''
        checkpoint = torch.load(filename)
        self.load_state_dict(checkpoint['model_state_dict'], strict=strict)

        try:
            if self.optimizer is None:
                if self.optimizer_f is not None:
                    self.optimizer = self.optimizer_f(
                        self.parameters()
                    )
                else:
                    self.optimizer = AdamW(
                        self.parameters(),
                        lr=self.learning_rate,
                        weight_decay=self.weight_decay
                    )

            self.optimizer.load_state_dict(
                checkpoint['optimizer_state_dict']
            )
        except ValueError as e:
            self.logger.debug(e)
            self.logger.debug("The state of the optimizer in `TransformerModel` was not updated.")

        self.epoch = checkpoint['epoch']
        self.__loss_list = checkpoint['loss'].tolist()
        if ctx is not None:
            self.to(ctx)
            self.__ctx = ctx
Example #14
    def __init__(
        self,
        config: TrainConfig,
        model: NSMCModel,
        train_data_loader: DataLoader,
        dev_data_loader: DataLoader,
        test_data_loader: DataLoader,
        logger: Logger,
        summary_writer: SummaryWriter,
    ):
        self.config = config
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.model = model
        self.model.to(self.device)

        self.train_data_loader = train_data_loader
        self.dev_data_loader = dev_data_loader
        self.test_data_loader = test_data_loader
        self.logger = logger
        self.summary_writer = summary_writer

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = AdamW(model.parameters(), lr=config.learning_rate)

        # compute total steps
        self.steps_per_epoch = len(train_data_loader)
        self.total_steps = self.steps_per_epoch * config.num_epochs
        self.warmup_steps = config.warmup_step_ratio * self.total_steps

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.total_steps)
        self.global_step = 0
Example #15
def train_stage_two(dataset, best_model_file, model_file):
    bestaccuracy = 0.9
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net = ResNet(BasicBlock, [3, 3, 4, 3]).to(device)  # [2,2,2,2]
    net.train()
    for parameter in net.parameters():
        if len(parameter.shape) > 1:
            torch.nn.init.xavier_uniform_(parameter)
    if isfile(best_model_file):
        net.load_state_dict(torch.load(best_model_file))
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
    optimizer = AdamW(net.parameters(), lr=0.0001)
    scheduler = CyclicLR(optimizer,
                         0.000001,
                         0.0001,
                         step_size_up=200,
                         mode='triangular2',
                         cycle_momentum=False,
                         last_epoch=-1)
    L1 = torch.nn.L1Loss()
    BCE = torch.nn.BCEWithLogitsLoss()

    for epoch in range(50):
        running_accuracy = []
        for (images, targets) in tqdm(train_loader):
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(images)
            clsloss = BCE(outputs[:, 0], targets[:, 0])
            regloss = L1(outputs[:, 1:], targets[:, 1:])
            loss = clsloss + regloss
            cls_preds = np.greater(outputs[:, 0].cpu().detach().numpy(), 0)
            cls_truth = targets[:, 0].cpu().detach().numpy()
            correctness = np.equal(cls_preds, cls_truth).astype(int)
            accuracy = sum(correctness) / len(correctness)
            running_accuracy.append(accuracy)
            running_accuracy = running_accuracy[-10:]
            print(' clsloss ' + str(clsloss.cpu().detach().numpy())[:4] +
                  ' regloss ' + str(regloss.cpu().detach().numpy())[:4] +
                  ' accuracy ' + str(np.mean(running_accuracy)),
                  end='\r')
            if np.mean(running_accuracy) > bestaccuracy:
                bestaccuracy = np.mean(running_accuracy)
                torch.save(net.state_dict(), best_model_file)
                # print('totalloss', str(loss.detach().numpy())[:4], 'saved!', end = '\n')
            else:
                pass
                # print('totalloss', str(loss.detach().numpy())[:4]+' ', end = '\n')
            loss.backward()
            optimizer.step()
            scheduler.step(None)
            # if idx%5==0:
            #    print('\n', outputs[0].cpu().detach().numpy(), targets[0].cpu().detach().numpy(), '\n')
            # idx+=1
        torch.save(net.state_dict(), model_file)
        print(epoch)
Example #16
def train_twice_epoch(model: RedisTwiceDNN, trainDataloader: DataLoader,
                      optimizer: AdamW) -> Tuple[float, float]:
    train_loss = 0.0
    train_ACC = 0
    train_steps = 0
    model.train()
    weight = [0.6, 0.4]
    for _, batch in enumerate(tqdm(trainDataloader, desc="Iteration")):
        optimizer.zero_grad()
        knobs_with_info = batch[0].to(DEVICE)
        targets = batch[1].to(DEVICE)

        outputs = model(knobs_with_info)

        loss = 0.
        for i, output in enumerate(outputs):
            loss += weight[i] * F.mse_loss(output.squeeze(1), targets[:, i])
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_steps += 1

    return train_loss / len(trainDataloader), train_ACC
Example #17
def get_optimizer(opt, learning_rate, parameters, hyper1, hyper2, eps, rebound,
                  lr_decay, decay_rate, milestone, weight_decay, weight_decay_type,
                  warmup_updates, init_lr, last_lr, num_epochs, world_size):
    if opt == 'sgd':
        optimizer = SGD(parameters, lr=learning_rate, momentum=hyper1, weight_decay=weight_decay, nesterov=True)
        opt = 'momentum=%.1f, ' % (hyper1)
        weight_decay_type = 'L2'
    elif opt == 'radamw':
        optimizer = RAdamW(parameters, lr=learning_rate, betas=(hyper1, hyper2), eps=eps, weight_decay=weight_decay)
        opt = 'betas=(%.1f, %.3f), eps=%.1e, ' % (hyper1, hyper2, eps)
        weight_decay_type = 'decoupled'
    elif opt == 'adamw':
        optimizer = AdamW(parameters, lr=learning_rate, betas=(hyper1, hyper2), eps=eps, weight_decay=weight_decay)
        opt = 'betas=(%.1f, %.3f), eps=%.1e, ' % (hyper1, hyper2, eps)
        weight_decay_type = 'decoupled'
    elif opt == 'adabelief':
        optimizer = AdaBelief(parameters, lr=learning_rate, betas=(hyper1, hyper2), eps=eps,
                              weight_decay=weight_decay, weight_decay_type=weight_decay_type)
        opt = 'betas=(%.1f, %.3f), eps=%.1e, ' % (hyper1, hyper2, eps)
    elif opt == 'apollo':
        optimizer = Apollo(parameters, lr=learning_rate, beta=hyper1, eps=eps, rebound=rebound,
                           warmup=warmup_updates, init_lr=init_lr, weight_decay=weight_decay,
                           weight_decay_type=weight_decay_type)
        opt = 'beta=%.1f, eps=%.1e, rebound=%s, ' % (hyper1, eps, rebound)
    elif opt == 'adahessian':
        optimizer = AdaHessian(parameters, lr=learning_rate, betas=(hyper1, hyper2), eps=eps,
                               warmup=warmup_updates, init_lr=init_lr, weight_decay=weight_decay, num_threads=world_size)
        opt = 'betas=(%.1f, %.3f), eps=%.1e, ' % (hyper1, hyper2, eps)
        weight_decay_type = 'decoupled'
    else:
        raise ValueError('unknown optimizer: {}'.format(opt))

    if lr_decay == 'exp':
        opt = opt + 'lr decay={}, decay rate={:.3f}, '.format(lr_decay, decay_rate)
        scheduler = ExponentialLR(optimizer, decay_rate)
    elif lr_decay == 'milestone':
        opt = opt + 'lr decay={} {}, decay rate={:.3f}, '.format(lr_decay, milestone, decay_rate)
        scheduler = MultiStepLR(optimizer, milestones=milestone, gamma=decay_rate)
    elif lr_decay == 'cosine':
        opt = opt + 'lr decay={}, lr_min={}, '.format(lr_decay, last_lr)
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=last_lr)
    else:
        raise ValueError('unknown lr decay: {}'.format(lr_decay))

    opt += 'warmup={}, init_lr={:.1e}, wd={:.1e} ({})'.format(warmup_updates, init_lr, weight_decay, weight_decay_type)
    return optimizer, scheduler, opt
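
A hypothetical invocation of the helper above for the plain AdamW plus cosine-decay path (RAdamW, AdaBelief, Apollo, AdaHessian and the schedulers are assumed to be importable as in the original project):

import torch.nn as nn

model = nn.Linear(32, 10)
optimizer, scheduler, opt_desc = get_optimizer(
    'adamw', 1e-3, model.parameters(), 0.9, 0.999, 1e-8, rebound=None,
    lr_decay='cosine', decay_rate=0.1, milestone=[30, 60], weight_decay=1e-2,
    weight_decay_type=None, warmup_updates=0, init_lr=1e-7, last_lr=1e-6,
    num_epochs=100, world_size=1)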
Example #18
def train(sqlite_file, model_selection_file, model_file):
    """

    Selects the best model/training parameter combination from a model selection and
    trains a final image classifier on the entire annotated part of the image database.

    SQLITE_FILE: An annotated image database (see create-database).
    MODEL_SELECTION_FILE: Results of a model selection created by the "model-selection" tool.
    MODEL_FILE: Store the finally trained image classification model in this file.

    """

    X, class_to_label, label_to_class = load_ground_truth(sqlite_file)

    X['file'] = X['file'].astype(str)
    y = X['class'].astype(int)

    batch_size, decrease_epochs, decrease_factor, epochs, model_name, num_trained, start_lr = load_model_selection(
        model_selection_file)

    model, device, fit_transform, predict_transform, logits_func = \
        load_pretrained_model(model_name, len(label_to_class), num_trained)

    # optimizer = optim.SGD(model_ft.parameters(), lr=start_lr, momentum=momentum)
    optimizer = AdamW(model.parameters(), lr=start_lr)

    sched = lr_scheduler.StepLR(optimizer,
                                step_size=decrease_epochs,
                                gamma=decrease_factor)

    estimator = ImageClassifier(model=model,
                                model_weights=copy.deepcopy(
                                    model.state_dict()),
                                device=device,
                                criterion=nn.CrossEntropyLoss(),
                                optimizer=optimizer,
                                scheduler=sched,
                                fit_transform=fit_transform,
                                predict_transform=predict_transform,
                                batch_size=batch_size,
                                logits_func=logits_func)

    for _ in range(0, epochs):
        estimator.fit(X, y)

    torch.save(model.state_dict(), model_file)
Example #19
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        if self.lr_warmup_steps is None:
            return optimizer

        scheduler = InverseSquareRootLR(optimizer, self.lr_warmup_steps)
        return (
            [optimizer],
            [
                {
                    'scheduler': scheduler,
                    'interval': 'step',
                    'frequency': 1,
                    'reduce_on_plateau': False,
                    'monitor': 'label_smoothed_val_loss' if self.optimize_on_smoothed_loss else 'val_loss',
                }
            ]
        )
Example #20
def train_unsupervised(args, graph_dataset, model):
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = ReduceLROnPlateau(optimizer, patience=5, cooldown=2, factor=0.5)
    early_stopper = EarlyStopper(args.unsupervised_model_patience)

    start_time = time.time()
    LOG.info("Started training unsupervised model.")
    for epoch in range(args.epochs):
        train_loss, train_count = node_encoder_train_epoch(graph_dataset, model, optimizer, epoch, args)
        val_loss, val_count = node_encoder_val_epoch(graph_dataset, model, epoch, args)
        lr_scheduler.step(val_loss)

        epoch_end_time = round(time.time() - start_time, 2)
        LOG.info(f'Epoch {epoch + 1} [{epoch_end_time}s]: Training loss [over {train_count} nodes] = {train_loss:.4f}, '
                 f'Validation loss [over {val_count} nodes] = {val_loss:.4f}')

        if early_stopper.should_stop(model, val_loss) or _STOP_TRAINING:
            break

    LOG.info("Unsupervised training complete.")
    torch.save(model.state_dict(), os.path.join(args.model_dir, 'node_encoder_unsupervised.pt'))
Example #21
    def _fit(self, modules: nn.ModuleDict, train_dl: DeviceDataLoader,
             valid_dl: DeviceDataLoader):
        r""" Fits \p modules' learners to the training and validation \p DataLoader objects """
        self._configure_fit_vars(modules)

        for mod_name, module in modules.items():
            lr = config.get_learner_val(mod_name,
                                        LearnerParams.Attribute.LEARNING_RATE)
            wd = config.get_learner_val(mod_name,
                                        LearnerParams.Attribute.WEIGHT_DECAY)
            is_lin_ff = config.DATASET.is_synthetic(
            ) and module.module.num_hidden_layers == 0
            if is_lin_ff:
                module.optim = LBFGS(module.parameters(), lr=lr)
            else:
                module.optim = AdamW(module.parameters(),
                                     lr=lr,
                                     weight_decay=wd,
                                     amsgrad=True)
            logging.debug(
                f"{mod_name} Optimizer: {module.optim.__class__.__name__}")

        for ep in range(1, config.NUM_EPOCH + 1):
            # noinspection PyUnresolvedReferences
            for _, module in modules.items():
                module.epoch_start()

            for batch in train_dl:
                for _, module in modules.items():
                    module.process_batch(batch)

            for _, module in modules.items():
                module.calc_valid_loss(valid_dl)
            self._log_epoch(ep, modules)
        self._restore_best_model(modules)
        self.eval()
Example #22
    def build_optimizer(self, args: ClassifierArgs, **kwargs):
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
        return optimizer
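
A small stand-alone illustration (not from the original) of how this grouping behaves; the LayerNorm attribute name mirrors the HuggingFace-style naming that the no_decay list assumes:

import torch.nn as nn
from torch.optim import AdamW

class TinyHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)

model = TinyHead()
no_decay = ['bias', 'LayerNorm.weight']
grouped = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],  # dense.weight
     'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],  # dense.bias, LayerNorm.weight, LayerNorm.bias
     'weight_decay': 0.0},
]
optimizer = AdamW(grouped, lr=5e-5, eps=1e-8)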
Example #23
def create_optim(params, args=None, optim_name=None, lr=None):
    if args is None:
        assert optim_name is not None and lr is not None
    else:
        assert optim_name is None and lr is None
        optim_name = args.optim
        lr = args.lr
    if optim_name == 'sgd':
        return SGD(params,
                   lr=lr,
                   momentum=0.9,
                   weight_decay=5e-4,
                   nesterov=True)
    elif optim_name == 'adam':
        return Adam(params, lr=lr, weight_decay=0)
    elif optim_name == 'adamw':
        return AdamW(params, lr=lr, weight_decay=1e-2)
    elif optim_name == 'amsgrad':
        return Adam(params, lr=lr, weight_decay=0, amsgrad=True)
    elif optim_name == 'rmsprop':
        return RMSprop(params, lr=lr, momentum=0.9, weight_decay=0)
    else:
        raise NotImplementedError(
            'Unsupported optimizer_memory: {}'.format(optim_name))
Example #24
    def configure_optimizers(self):
        return AdamW(params=self.model.parameters(),
                     lr=2e-5,
                     eps=1e-6,
                     correct_bias=False)
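
Note that correct_bias is not an argument of torch.optim.AdamW; this snippet presumably uses the AdamW implementation from the transformers package, where correct_bias=False disables Adam's bias correction to match the original BERT training setup.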
Example #25
def main():
    def evaluate_accuracy(data_loader, prefix: str, save: bool = False):
        std_transform.eval()
        model.eval()
        pbar = tqdm(data_loader,
                    desc=prefix,
                    leave=True,
                    total=len(data_loader))
        num_corr = 0
        num_tot = 0
        for idx, batch in enumerate(pbar):
            batch = batch.to(device)
            scores = model(zmuv_transform(std_transform(batch.audio_data)),
                           std_transform.compute_lengths(batch.lengths))
            num_tot += scores.size(0)
            labels = torch.tensor([
                l % SETTINGS.training.num_labels
                for l in batch.labels.tolist()
            ]).to(device)
            num_corr += (scores.max(1)[1] == labels).float().sum().item()
            acc = num_corr / num_tot
            pbar.set_postfix(accuracy=f'{acc:.4}')
        if save and not args.eval:
            writer.add_scalar(f'{prefix}/Metric/acc', acc, epoch_idx)
            ws.increment_model(model, acc / 10)

    apb = ArgumentParserBuilder()
    apb.add_options(
        opt('--model', type=str, choices=model_names(), default='las'),
        opt('--workspace',
            type=str,
            default=str(Path('workspaces') / 'default')),
        opt('--load-weights', action='store_true'),
        opt('--eval', action='store_true'))
    args = apb.parser.parse_args()

    ws = Workspace(Path(args.workspace), delete_existing=not args.eval)
    writer = ws.summary_writer
    set_seed(SETTINGS.training.seed)
    loader = GoogleSpeechCommandsDatasetLoader()
    sr = SETTINGS.audio.sample_rate
    ds_kwargs = dict(sr=sr, mono=SETTINGS.audio.use_mono)
    train_ds, dev_ds, test_ds = loader.load_splits(
        SETTINGS.dataset.dataset_path, **ds_kwargs)

    sr = SETTINGS.audio.sample_rate
    device = torch.device(SETTINGS.training.device)
    std_transform = StandardAudioTransform().to(device).eval()
    zmuv_transform = ZmuvTransform().to(device)
    batchifier = partial(batchify, label_provider=lambda x: x.label)
    truncater = partial(truncate_length,
                        length=int(SETTINGS.training.max_window_size_seconds *
                                   sr))
    train_comp = compose(truncater,
                         TimeshiftTransform().train(),
                         NoiseTransform().train(), batchifier)
    prep_dl = StandardAudioDataLoaderBuilder(train_ds,
                                             collate_fn=batchifier).build(1)
    prep_dl.shuffle = True
    train_dl = StandardAudioDataLoaderBuilder(
        train_ds, collate_fn=train_comp).build(SETTINGS.training.batch_size)
    dev_dl = StandardAudioDataLoaderBuilder(
        dev_ds,
        collate_fn=compose(truncater,
                           batchifier)).build(SETTINGS.training.batch_size)
    test_dl = StandardAudioDataLoaderBuilder(
        test_ds,
        collate_fn=compose(truncater,
                           batchifier)).build(SETTINGS.training.batch_size)

    model = find_model(args.model)().to(device)
    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = AdamW(params,
                      SETTINGS.training.learning_rate,
                      weight_decay=SETTINGS.training.weight_decay)
    logging.info(f'{sum(p.numel() for p in params)} parameters')
    criterion = nn.CrossEntropyLoss()

    if (ws.path / 'zmuv.pt.bin').exists():
        zmuv_transform.load_state_dict(torch.load(str(ws.path /
                                                      'zmuv.pt.bin')))
    else:
        for idx, batch in enumerate(tqdm(prep_dl, desc='Constructing ZMUV')):
            batch.to(device)
            zmuv_transform.update(std_transform(batch.audio_data))
            if idx == 2000:  # TODO: quick debugging, remove later
                break
        logging.info(
            dict(zmuv_mean=zmuv_transform.mean, zmuv_std=zmuv_transform.std))
    torch.save(zmuv_transform.state_dict(), str(ws.path / 'zmuv.pt.bin'))

    if args.load_weights:
        ws.load_model(model, best=True)
    if args.eval:
        ws.load_model(model, best=True)
        evaluate_accuracy(dev_dl, 'Dev')
        evaluate_accuracy(test_dl, 'Test')
        return

    ws.write_args(args)
    ws.write_setting(SETTINGS)
    writer.add_scalar('Meta/Parameters', sum(p.numel() for p in params))
    for epoch_idx in trange(SETTINGS.training.num_epochs,
                            position=0,
                            leave=True):
        model.train()
        std_transform.train()
        pbar = tqdm(train_dl,
                    total=len(train_dl),
                    position=1,
                    desc='Training',
                    leave=True)
        for batch in pbar:
            batch.to(device)
            audio_data = zmuv_transform(std_transform(batch.audio_data))
            scores = model(audio_data,
                           std_transform.compute_lengths(batch.lengths))
            optimizer.zero_grad()
            model.zero_grad()
            labels = torch.tensor([
                l % SETTINGS.training.num_labels
                for l in batch.labels.tolist()
            ]).to(device)
            loss = criterion(scores, labels)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(dict(loss=f'{loss.item():.3}'))
            writer.add_scalar('Training/Loss', loss.item(), epoch_idx)

        for group in optimizer.param_groups:
            group['lr'] *= SETTINGS.training.lr_decay
        evaluate_accuracy(dev_dl, 'Dev', save=True)
    evaluate_accuracy(test_dl, 'Test')
Example #26
def main():
    def evaluate_engine(dataset: WakeWordDataset,
                        prefix: str,
                        save: bool = False,
                        positive_set: bool = False,
                        write_errors: bool = True,
                        mixer: DatasetMixer = None):
        std_transform.eval()

        if use_frame:
            engine = FrameInferenceEngine(int(SETTINGS.training.max_window_size_seconds * 1000),
                                          int(SETTINGS.training.eval_stride_size_seconds * 1000),
                                          SETTINGS.audio.sample_rate,
                                          model,
                                          zmuv_transform,
                                          negative_label=ctx.negative_label,
                                          coloring=ctx.coloring)
        else:
            engine = SequenceInferenceEngine(SETTINGS.audio.sample_rate,
                                             model,
                                             zmuv_transform,
                                             negative_label=ctx.negative_label,
                                             coloring=ctx.coloring,
                                             blank_idx=ctx.blank_label)
        model.eval()
        conf_matrix = ConfusionMatrix()
        pbar = tqdm(dataset, desc=prefix)
        if write_errors:
            with (ws.path / 'errors.tsv').open('a') as f:
                print(prefix, file=f)
        for idx, ex in enumerate(pbar):
            if mixer is not None:
                ex, = mixer([ex])
            audio_data = ex.audio_data.to(device)
            engine.reset()
            seq_present = engine.infer(audio_data)
            if seq_present != positive_set and write_errors:
                with (ws.path / 'errors.tsv').open('a') as f:
                    f.write(f'{ex.metadata.transcription}\t{int(seq_present)}\t{int(positive_set)}\t{ex.metadata.path}\n')
            conf_matrix.increment(seq_present, positive_set)
            pbar.set_postfix(dict(mcc=f'{conf_matrix.mcc}', c=f'{conf_matrix}'))

        logging.info(f'{conf_matrix}')
        if save and not args.eval:
            writer.add_scalar(f'{prefix}/Metric/tp', conf_matrix.tp, epoch_idx)
            ws.increment_model(model, conf_matrix.tp)
        if args.eval:
            threshold = engine.threshold
            with (ws.path / (str(round(threshold, 2)) + '_results.csv') ).open('a') as f:
                f.write(f'{prefix},{threshold},{conf_matrix.tp},{conf_matrix.tn},{conf_matrix.fp},{conf_matrix.fn}\n')

    def do_evaluate():
        evaluate_engine(ww_dev_pos_ds, 'Dev positive', positive_set=True)
        evaluate_engine(ww_dev_neg_ds, 'Dev negative', positive_set=False)
        if SETTINGS.training.use_noise_dataset:
            evaluate_engine(ww_dev_pos_ds, 'Dev noisy positive', positive_set=True, mixer=dev_mixer)
            evaluate_engine(ww_dev_neg_ds, 'Dev noisy negative', positive_set=False, mixer=dev_mixer)
        evaluate_engine(ww_test_pos_ds, 'Test positive', positive_set=True)
        evaluate_engine(ww_test_neg_ds, 'Test negative', positive_set=False)
        if SETTINGS.training.use_noise_dataset:
            evaluate_engine(ww_test_pos_ds, 'Test noisy positive', positive_set=True, mixer=test_mixer)
            evaluate_engine(ww_test_neg_ds, 'Test noisy negative', positive_set=False, mixer=test_mixer)

    apb = ArgumentParserBuilder()
    apb.add_options(opt('--model', type=str, choices=RegisteredModel.registered_names(), default='las'),
                    opt('--workspace', type=str, default=str(Path('workspaces') / 'default')),
                    opt('--load-weights', action='store_true'),
                    opt('--load-last', action='store_true'),
                    opt('--no-dev-per-epoch', action='store_false', dest='dev_per_epoch'),
                    opt('--dataset-paths', '-i', type=str, nargs='+', default=[SETTINGS.dataset.dataset_path]),
                    opt('--eval', action='store_true'))
    args = apb.parser.parse_args()

    use_frame = SETTINGS.training.objective == 'frame'
    ctx = InferenceContext(SETTINGS.training.vocab, token_type=SETTINGS.training.token_type, use_blank=not use_frame)
    if use_frame:
        batchifier = WakeWordFrameBatchifier(ctx.negative_label,
                                             window_size_ms=int(SETTINGS.training.max_window_size_seconds * 1000))
        criterion = nn.CrossEntropyLoss()
    else:
        tokenizer = WakeWordTokenizer(ctx.vocab, ignore_oov=False)
        batchifier = AudioSequenceBatchifier(ctx.negative_label, tokenizer)
        criterion = nn.CTCLoss(ctx.blank_label)

    ws = Workspace(Path(args.workspace), delete_existing=not args.eval)
    writer = ws.summary_writer
    set_seed(SETTINGS.training.seed)
    loader = WakeWordDatasetLoader()
    ds_kwargs = dict(sr=SETTINGS.audio.sample_rate, mono=SETTINGS.audio.use_mono, frame_labeler=ctx.labeler)

    ww_train_ds, ww_dev_ds, ww_test_ds = WakeWordDataset(metadata_list=[], set_type=DatasetType.TRAINING, **ds_kwargs), \
                                         WakeWordDataset(metadata_list=[], set_type=DatasetType.DEV, **ds_kwargs), \
                                         WakeWordDataset(metadata_list=[], set_type=DatasetType.TEST, **ds_kwargs)
    for ds_path in args.dataset_paths:
        ds_path = Path(ds_path)
        train_ds, dev_ds, test_ds = loader.load_splits(ds_path, **ds_kwargs)
        ww_train_ds.extend(train_ds)
        ww_dev_ds.extend(dev_ds)
        ww_test_ds.extend(test_ds)
    print_stats(f'Wake word dataset', ww_train_ds, ww_dev_ds, ww_test_ds)

    ww_dev_pos_ds = ww_dev_ds.filter(lambda x: ctx.searcher.search(x.transcription), clone=True)
    ww_dev_neg_ds = ww_dev_ds.filter(lambda x: not ctx.searcher.search(x.transcription), clone=True)
    ww_test_pos_ds = ww_test_ds.filter(lambda x: ctx.searcher.search(x.transcription), clone=True)
    ww_test_neg_ds = ww_test_ds.filter(lambda x: not ctx.searcher.search(x.transcription), clone=True)

    print_stats(f'Dev dataset', ww_dev_pos_ds, ww_dev_neg_ds)
    print_stats(f'Test dataset', ww_test_pos_ds, ww_test_neg_ds)
    device = torch.device(SETTINGS.training.device)
    std_transform = StandardAudioTransform().to(device).eval()
    zmuv_transform = ZmuvTransform().to(device)

    train_comp = (NoiseTransform().train(), batchifier)

    if SETTINGS.training.use_noise_dataset:
        noise_ds = RecursiveNoiseDatasetLoader().load(Path(SETTINGS.raw_dataset.noise_dataset_path),
                                                      sr=SETTINGS.audio.sample_rate,
                                                      mono=SETTINGS.audio.use_mono)
        logging.info(f'Loaded {len(noise_ds.metadata_list)} noise files.')
        noise_ds_train, noise_ds_dev = noise_ds.split(Sha256Splitter(80))
        noise_ds_dev, noise_ds_test = noise_ds_dev.split(Sha256Splitter(50))
        train_comp = (DatasetMixer(noise_ds_train).train(),) + train_comp
        dev_mixer = DatasetMixer(noise_ds_dev, seed=0, do_replace=False)
        test_mixer = DatasetMixer(noise_ds_test, seed=0, do_replace=False)
    train_comp = compose(*train_comp)

    prep_dl = StandardAudioDataLoaderBuilder(ww_train_ds, collate_fn=batchify).build(1)
    prep_dl.shuffle = True
    train_dl = StandardAudioDataLoaderBuilder(ww_train_ds, collate_fn=train_comp).build(SETTINGS.training.batch_size)

    model = RegisteredModel.find_registered_class(args.model)(ctx.num_labels).to(device).streaming()
    if SETTINGS.training.convert_static:
        model = ConvertedStaticModel(model, 40, 10)
    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = AdamW(params, SETTINGS.training.learning_rate, weight_decay=SETTINGS.training.weight_decay)
    logging.info(f'{sum(p.numel() for p in params)} parameters')

    if (ws.path / 'zmuv.pt.bin').exists():
        zmuv_transform.load_state_dict(torch.load(str(ws.path / 'zmuv.pt.bin')))
    else:
        for idx, batch in enumerate(tqdm(prep_dl, desc='Constructing ZMUV')):
            batch.to(device)
            zmuv_transform.update(std_transform(batch.audio_data))
            if idx == 2000:  # TODO: quick debugging, remove later
                break
        logging.info(dict(zmuv_mean=zmuv_transform.mean, zmuv_std=zmuv_transform.std))
    torch.save(zmuv_transform.state_dict(), str(ws.path / 'zmuv.pt.bin'))

    if args.load_weights:
        ws.load_model(model, best=not args.load_last)
    if args.eval:
        ws.load_model(model, best=not args.load_last)
        do_evaluate()
        return

    ws.write_args(args)
    ws.write_settings(SETTINGS)
    writer.add_scalar('Meta/Parameters', sum(p.numel() for p in params))
    for epoch_idx in trange(SETTINGS.training.num_epochs, position=0, leave=True):
        model.train()
        std_transform.train()
        model.streaming_state = None
        pbar = tqdm(train_dl,
                    total=len(train_dl),
                    position=1,
                    desc='Training',
                    leave=True)
        total_loss = torch.Tensor([0.0]).to(device)
        for batch in pbar:
            batch.to(device)
            if use_frame:
                scores = model(zmuv_transform(std_transform(batch.audio_data)),
                               std_transform.compute_lengths(batch.lengths))
                loss = criterion(scores, batch.labels)
            else:
                lengths = std_transform.compute_lengths(batch.audio_lengths)
                scores = model(zmuv_transform(std_transform(batch.audio_data)), lengths)
                scores = F.log_softmax(scores, -1)  # [num_frames x batch_size x num_labels]
                lengths = torch.tensor([model.compute_length(x.item()) for x in lengths]).to(device)
                loss = criterion(scores, batch.labels, lengths, batch.label_lengths)
            optimizer.zero_grad()
            model.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix(dict(loss=f'{loss.item():.3}'))
            with torch.no_grad():
                total_loss += loss

        for group in optimizer.param_groups:
            group['lr'] *= SETTINGS.training.lr_decay

        mean = total_loss / len(train_dl)
        writer.add_scalar('Training/Loss', mean.item(), epoch_idx)
        writer.add_scalar('Training/LearningRate', group['lr'], epoch_idx)

        if args.dev_per_epoch:
            evaluate_engine(ww_dev_pos_ds, 'Dev positive', positive_set=True, save=True, write_errors=False)

    do_evaluate()
Example #27
    if args.encoder == 'bert':
        tokenizer = get_bert_tokenizer(
            args.pretrain_checkpoint,
            add_tokens=['[EOT]']
        )
        two_tower_model = build_biencoder_model(
            model_tokenizer=tokenizer,
            args=args, aggregation='cls'
        )

    else:
        raise Exception

    optimizer = AdamW(
        [{'params': two_tower_model.parameters(), 'initial_lr': args.lr}],
        lr=args.lr
    )
    if args.best_acc > 0:
        two_tower_model, likelihood_criterion = get_saved_model_and_optimizer(
            model=two_tower_model,
            optimizer=optimizer,
            checkpoint_path=os.path.join(
                args.model_checkpoint,
                "%s_acc_%.5f" % (args.task_name, args.best_acc)
            ))
        # two_tower_model.load_state_dict(torch.load(os.path.join(
        #     args.model_checkpoint, "%s_acc_%.5f" % (args.task_name,
        #                                             args.best_acc)
        # )))

    if args.distributed and args.device == 'cuda':
Example #28
    def __init__(self,
                 save_path,
                 log_path,
                 n_depth,
                 d_features,
                 d_classifier,
                 d_output,
                 threshold=None,
                 stack='ShuffleSelfAttention',
                 expansion_layer='ChannelWiseConvExpansion',
                 mode='1d',
                 optimizer=None,
                 **kwargs):
        '''*args: n_layers, n_head, n_channel, n_vchannel, dropout, use_bottleneck, d_bottleneck'''
        '''
            Arguments:
                mode:   1d:         1d output
                        2d:         2d output
                        residual:   residual output
                        dense:      dense net
        
        '''

        super().__init__(save_path, log_path)
        self.d_output = d_output
        self.threshold = threshold

        # ----------------------------- Model ------------------------------ #
        stack_dict = {
            'ReactionAttention': ReactionAttentionStack,
            'SelfAttention': SelfAttentionStack,
            'Alternate': AlternateStack,
            'Parallel': ParallelStack,
            'ShuffleSelfAttention': ShuffleSelfAttentionStack,
            'ShuffleSelfAttentionStackV2': ShuffleSelfAttentionStackV2,
        }
        expansion_dict = {
            'LinearExpansion': LinearExpansion,
            'ReduceParamLinearExpansion': ReduceParamLinearExpansion,
            'ConvExpansion': ConvExpansion,
            'LinearConvExpansion': LinearConvExpansion,
            'ShuffleConvExpansion': ShuffleConvExpansion,
            'ChannelWiseConvExpansion': ChannelWiseConvExpansion,
        }

        self.model = stack_dict[stack](expansion_dict[expansion_layer],
                                       n_depth=n_depth,
                                       d_features=d_features,
                                       mode=mode,
                                       **kwargs)

        # --------------------------- Classifier --------------------------- #
        if mode == '1d':
            self.classifier = LinearClassifier(d_features, d_classifier,
                                               d_output)
        elif mode == '2d':
            self.classifier = LinearClassifier(n_depth * d_features,
                                               d_classifier, d_output)
        else:
            self.classifier = None

        # ------------------------------ CUDA ------------------------------ #
        # If GPU available, move the graph to GPU(s)

        self.CUDA_AVAILABLE = self.check_cuda()
        if self.CUDA_AVAILABLE:
            # self.model.cuda()
            # self.classifier.cuda()

            device_ids = list(range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids)
            self.classifier = nn.DataParallel(self.classifier, device_ids)
            self.model.to('cuda')
            self.classifier.to('cuda')
            assert (next(self.model.parameters()).is_cuda)
            assert (next(self.classifier.parameters()).is_cuda)
            pass

        else:
            print('CUDA not found or not enabled, use CPU instead')

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(
            self.classifier.parameters())
        if optimizer is None:
            self.optimizer = AdamW(self.parameters,
                                   lr=0.002,
                                   betas=(0.9, 0.999),
                                   weight_decay=0.001)

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000,
                                          evaluate_every_nstep=100,
                                          print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # --------------------- logging and tensorboard -------------------- #
        self.set_logger()
        self.set_summary_writer()
Example #29
class ShuffleSelfAttentionModel(Model):
    def __init__(self,
                 save_path,
                 log_path,
                 n_depth,
                 d_features,
                 d_classifier,
                 d_output,
                 threshold=None,
                 stack='ShuffleSelfAttention',
                 expansion_layer='ChannelWiseConvExpansion',
                 mode='1d',
                 optimizer=None,
                 **kwargs):
        '''*args: n_layers, n_head, n_channel, n_vchannel, dropout, use_bottleneck, d_bottleneck'''
        '''
            Arguments:
                mode:   1d:         1d output
                        2d:         2d output
                        residual:   residual output
                        dense:      dense net
        
        '''

        super().__init__(save_path, log_path)
        self.d_output = d_output
        self.threshold = threshold

        # ----------------------------- Model ------------------------------ #
        stack_dict = {
            'ReactionAttention': ReactionAttentionStack,
            'SelfAttention': SelfAttentionStack,
            'Alternate': AlternateStack,
            'Parallel': ParallelStack,
            'ShuffleSelfAttention': ShuffleSelfAttentionStack,
            'ShuffleSelfAttentionStackV2': ShuffleSelfAttentionStackV2,
        }
        expansion_dict = {
            'LinearExpansion': LinearExpansion,
            'ReduceParamLinearExpansion': ReduceParamLinearExpansion,
            'ConvExpansion': ConvExpansion,
            'LinearConvExpansion': LinearConvExpansion,
            'ShuffleConvExpansion': ShuffleConvExpansion,
            'ChannelWiseConvExpansion': ChannelWiseConvExpansion,
        }

        self.model = stack_dict[stack](expansion_dict[expansion_layer],
                                       n_depth=n_depth,
                                       d_features=d_features,
                                       mode=mode,
                                       **kwargs)

        # --------------------------- Classifier --------------------------- #
        if mode == '1d':
            self.classifier = LinearClassifier(d_features, d_classifier,
                                               d_output)
        elif mode == '2d':
            self.classifier = LinearClassifier(n_depth * d_features,
                                               d_classifier, d_output)
        else:
            self.classifier = None

        # ------------------------------ CUDA ------------------------------ #
        # If GPU available, move the graph to GPU(s)

        self.CUDA_AVAILABLE = self.check_cuda()
        if self.CUDA_AVAILABLE:
            # self.model.cuda()
            # self.classifier.cuda()

            device_ids = list(range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids)
            self.classifier = nn.DataParallel(self.classifier, device_ids)
            self.model.to('cuda')
            self.classifier.to('cuda')
            assert (next(self.model.parameters()).is_cuda)
            assert (next(self.classifier.parameters()).is_cuda)
            pass

        else:
            print('CUDA not found or not enabled, use CPU instead')

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(
            self.classifier.parameters())
        if optimizer is None:
            self.optimizer = AdamW(self.parameters,
                                   lr=0.002,
                                   betas=(0.9, 0.999),
                                   weight_decay=0.001)

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000,
                                          evaluate_every_nstep=100,
                                          print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # --------------------- logging and tensorboard -------------------- #
        self.set_logger()
        self.set_summary_writer()
        # ---------------------------- END INIT ---------------------------- #

    def checkpoint(self, step):
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'classifier_state_dict': self.classifier.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'global_step': step
        }
        return checkpoint

    def train_epoch(self, train_dataloader, eval_dataloader, device, smoothing,
                    earlystop):
        ''' Epoch operation in training phase'''

        if device == 'cuda':
            assert self.CUDA_AVAILABLE
        # Set model and classifier training mode
        self.model.train()
        self.classifier.train()

        total_loss = 0
        batch_counter = 0

        # update param per batch
        for batch in tqdm(train_dataloader,
                          mininterval=1,
                          desc='  - (Training)   ',
                          leave=False):  # train_dataloader should be an iterable

            # get data from dataloader

            feature_1, feature_2, y = parse_data(batch, device)

            batch_size = len(feature_1)

            # forward
            self.optimizer.zero_grad()
            logits, attn = self.model(feature_1, feature_2)
            logits = logits.view(batch_size, -1)
            logits = self.classifier(logits)

            # Single-output head: treat as regression (sigmoid + MSE); otherwise classification
            if self.d_output == 1:
                pred = logits.sigmoid()
                loss = mse_loss(pred, y)

            else:
                pred = logits
                loss = cross_entropy_loss(pred, y, smoothing=smoothing)

            # calculate gradients
            loss.backward()

            # update parameters
            self.optimizer.step()

            # get metrics for logging
            acc = accuracy(pred, y, threshold=self.threshold)
            precision, recall, precision_avg, recall_avg = precision_recall(
                pred, y, self.d_output, threshold=self.threshold)
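            # (note) precision and recall are returned per class; index 1, logged
            # below, presumably corresponds to the positive class in a binary task.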
            total_loss += loss.item()
            batch_counter += 1

            # training control
            state_dict = self.controller(batch_counter)

            if state_dict['step_to_print']:
                self.train_logger.info(
                    '[TRAINING]   - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f'
                    % (state_dict['step'], loss, acc, precision[1], recall[1]))
                self.summary_writer.add_scalar('loss/train', loss,
                                               state_dict['step'])
                self.summary_writer.add_scalar('acc/train', acc,
                                               state_dict['step'])
                self.summary_writer.add_scalar('precision/train', precision[1],
                                               state_dict['step'])
                self.summary_writer.add_scalar('recall/train', recall[1],
                                               state_dict['step'])

            if state_dict['step_to_evaluate']:
                stop = self.val_epoch(eval_dataloader, device,
                                      state_dict['step'])
                state_dict['step_to_stop'] = stop

                if earlystop and stop:
                    break

            if self.controller.current_step == self.controller.max_step:
                state_dict['step_to_stop'] = True
                break

        return state_dict

    def val_epoch(self, dataloader, device, step=0, plot=False):
        ''' Epoch operation in evaluation phase '''
        if device == 'cuda':
            assert self.CUDA_AVAILABLE

        # Set model and classifier to evaluation mode
        self.model.eval()
        self.classifier.eval()

        # use evaluator to calculate the average performance
        evaluator = Evaluator()

        pred_list = []
        real_list = []

        with torch.no_grad():

            for batch in tqdm(
                    dataloader,
                    mininterval=5,
                    desc='  - (Evaluation)   ',
                    leave=False):  # dataloader should be an iterable

                # get data from dataloader
                feature_1, feature_2, y = parse_data(batch, device)

                batch_size = len(feature_1)

                # get logits
                logits, attn = self.model(feature_1, feature_2)
                logits = logits.view(batch_size, -1)
                logits = self.classifier(logits)

                if self.d_output == 1:
                    pred = logits.sigmoid()
                    loss = mse_loss(pred, y)

                else:
                    pred = logits
                    loss = cross_entropy_loss(pred, y, smoothing=False)

                acc = accuracy(pred, y, threshold=self.threshold)
                precision, recall, _, _ = precision_recall(
                    pred, y, self.d_output, threshold=self.threshold)

                # feed the metrics in the evaluator
                evaluator(loss.item(), acc.item(), precision[1].item(),
                          recall[1].item())
                # Append results to the prediction / ground-truth lists for drawing a ROC or PR curve
                if plot:
                    pred_list += pred.tolist()
                    real_list += y.tolist()

            if plot:
                area, precisions, recalls, thresholds = pr(
                    pred_list, real_list)
                plot_pr_curve(recalls, precisions, auc=area)

            # get evaluation results from the evaluator
            loss_avg, acc_avg, pre_avg, rec_avg = evaluator.avg_results()

            self.eval_logger.info(
                '[EVALUATION] - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f'
                % (step, loss_avg, acc_avg, pre_avg, rec_avg))
            self.summary_writer.add_scalar('loss/eval', loss_avg, step)
            self.summary_writer.add_scalar('acc/eval', acc_avg, step)
            self.summary_writer.add_scalar('precision/eval', pre_avg, step)
            self.summary_writer.add_scalar('recall/eval', rec_avg, step)

            state_dict = self.early_stopping(loss_avg)

            if state_dict['save']:
                checkpoint = self.checkpoint(step)
                self.save_model(
                    checkpoint,
                    self.save_path + '-step-%d_loss-%.5f' % (step, loss_avg))

            return state_dict['break']

    def train(self,
              max_epoch,
              train_dataloader,
              eval_dataloader,
              device,
              smoothing=False,
              earlystop=False,
              save_mode='best'):

        assert save_mode in ['all', 'best']
        # train for max_epoch epochs
        for epoch_i in range(max_epoch):
            print('[ Epoch', epoch_i, ']')
            # set current epoch
            self.controller.set_epoch(epoch_i + 1)
            # train for one epoch
            state_dict = self.train_epoch(train_dataloader, eval_dataloader,
                                          device, smoothing, earlystop)

            # if state_dict['step_to_stop']:
            #     break

        checkpoint = self.checkpoint(state_dict['step'])

        self.save_model(checkpoint,
                        self.save_path + '-step-%d' % state_dict['step'])

        self.train_logger.info(
            '[INFO]: Finished training after %d epoch(s) and %d batches, %d training steps in total.'
            % (state_dict['epoch'] - 1, state_dict['batch'], state_dict['step']))

    def get_predictions(self,
                        data_loader,
                        device,
                        max_batches=None,
                        activation=None):

        pred_list = []
        real_list = []

        self.model.eval()
        self.classifier.eval()

        batch_counter = 0

        with torch.no_grad():
            for batch in tqdm(data_loader,
                              desc='  - (Testing)   ',
                              leave=False):

                feature_1, feature_2, y = parse_data(batch, device)

                # get logits
                logits, attn = self.model(feature_1, feature_2)
                logits = logits.view(logits.shape[0], -1)
                logits = self.classifier(logits)

                # Apply the given activation if provided, otherwise default to softmax
                if activation is not None:
                    pred = activation(logits)
                else:
                    pred = logits.softmax(dim=-1)
                pred_list += pred.tolist()
                real_list += y.tolist()

                if max_batches is not None:
                    batch_counter += 1
                    if batch_counter >= max_batches:
                        break

        return pred_list, real_list
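
# (added) Usage sketch for the trainer above; not part of the original example.
# Constructor arguments are omitted because the full constructor is not shown here,
# and `trainer`, `train_loader`, `eval_loader`, `test_loader` are hypothetical names:
#
#   trainer.train(max_epoch=10,
#                 train_dataloader=train_loader,
#                 eval_dataloader=eval_loader,
#                 device='cuda',
#                 smoothing=False,
#                 earlystop=True,
#                 save_mode='best')
#   preds, labels = trainer.get_predictions(test_loader, device='cuda')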
Example #30
def train(model_path: Optional[str] = None,
          data_path: Optional[str] = None,
          version: Optional[str] = None,
          save_to_dir: Optional[str] = None):

    train_dl, val_dl, test_dl, nbatches, parser, version = init(
        data_path, version=version, save_to_dir=save_to_dir)
    core_opt = AdamW(parser.parameters(),
                     lr=0.,
                     betas=(0.9, 0.98),
                     eps=1e-12,
                     weight_decay=1e-02)
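    # (note) Cosine LR schedules: warm up over roughly a quarter of an epoch
    # (nbatches // 4 steps), then decay over the corresponding training phase; the
    # "mutual" variant presumably restarts every 40 epochs, given triangle_decay.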
    dec_schedule = make_cosine_schedule(max_lr=5e-04,
                                        warmup_steps=nbatches // 4,
                                        decay_over=decoder_epochs * nbatches)
    mutual_schedule = make_cosine_schedule_with_linear_restarts(
        max_lr=1e-04,
        warmup_steps=nbatches // 4,
        decay_over=mutual_epochs * nbatches,
        triangle_decay=40 * nbatches)
    st_loss = NormCrossEntropy(sep_id=parser.atom_tokenizer.sep_token_id,
                               ignore_index=parser.atom_tokenizer.pad_token_id)
    sh_loss = SinkhornLoss()

    if model_path is not None:
        print('Loading checkpoint...')
        step_num, opt_dict, init_epoch = load_model(parser, model_path)
        if init_epoch < decoder_epochs:
            schedule = dec_schedule
        else:
            schedule = mutual_schedule

        opt = Scheduler(core_opt, schedule)
        opt.step_num = step_num
        opt.lr = opt.schedule(opt.step_num)
        opt.opt.load_state_dict(opt_dict)
        del opt_dict
    else:
        opt = Scheduler(core_opt, dec_schedule)
        init_epoch = 0

    if save_to_dir is None:
        save_to_dir = './stored_models'

    for e in range(init_epoch, decoder_epochs + mutual_epochs):
        validate = e % 5 == 0 and e != init_epoch
        save = e % 5 == 0 and e != init_epoch
        linking_weight = 0.33

        if save:
            print('\tSaving')
            torch.save(
                {
                    'model_state_dict': parser.state_dict(),
                    'opt_state_dict': opt.opt.state_dict(),
                    'step': opt.step_num,
                    'epoch': e
                }, f'{save_to_dir}/{version}/{e}.model')

        if e < decoder_epochs:
            with open(f'{save_to_dir}/{version}/log.txt', 'a') as stream:
                logprint('=' * 64, [stream])
                logprint(f'Pre-epoch {e}', [stream])
                logprint(' ' * 50 + f'LR: {opt.lr}\t({opt.step_num})',
                         [stream])
                supertagging_loss, linking_loss = parser.pretrain_decoder_epoch(
                    train_dl, st_loss, sh_loss, opt, linking_weight)
                logprint(f' Supertagging Loss:\t\t{supertagging_loss:5.2f}',
                         [stream])
                logprint(f' Linking Loss:\t\t\t{linking_loss:5.2f}', [stream])
                if validate:
                    with open(f'{save_to_dir}/{version}/val_log.txt',
                              'a') as valstream:
                        logprint(f'Epoch {e}', [valstream])
                        logprint('-' * 64, [stream, valstream])
                        sentence_ac, atom_ac, link_ac = parser.preval_epoch(
                            val_dl)
                        logprint(
                            f' Sentence Accuracy:\t\t{(sentence_ac * 100):6.2f}',
                            [stream, valstream])
                        logprint(f' Atom Accuracy:\t\t{(atom_ac * 100):6.2f}',
                                 [stream, valstream])
                        logprint(f' Link Accuracy:\t\t{(link_ac * 100):6.2f}',
                                 [stream, valstream])
                continue
        elif e == decoder_epochs:
            opt = Scheduler(opt.opt, mutual_schedule)

        with open(f'{save_to_dir}/{version}/log.txt', 'a') as stream:
            logprint('=' * 64, [stream])
            logprint(f'Epoch {e}', [stream])
            logprint(' ' * 50 + f'LR: {opt.lr}\t({opt.step_num})', [stream])
            logprint(' ' * 50 + f'LW: {linking_weight}', [stream])
            logprint('-' * 64, [stream])
            supertagging_loss, linking_loss = parser.train_epoch(
                train_dl, st_loss, sh_loss, opt, linking_weight)
            logprint(f' Supertagging Loss:\t\t{supertagging_loss:5.2f}',
                     [stream])
            logprint(f' Linking Loss:\t\t\t{linking_loss:5.2f}', [stream])
            if validate:
                with open(f'{save_to_dir}/{version}/val_log.txt',
                          'a') as valstream:
                    logprint(f'Epoch {e}', [valstream])
                    logprint('-' * 64, [stream, valstream])
                    sentence_ac, atom_ac, link_ac = parser.eval_epoch(
                        val_dl, link=True)
                    logprint(
                        f' Sentence Accuracy:\t\t{(sentence_ac * 100):6.2f}',
                        [stream, valstream])
                    logprint(f' Atom Accuracy:\t\t{(atom_ac * 100):6.2f}',
                             [stream, valstream])
                    logprint(f' Link Accuracy:\t\t{(link_ac * 100):6.2f}',
                             [stream, valstream])
            logprint('\n', [stream])
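
# (added) Invocation sketch; not part of the original example. The paths are
# hypothetical, and the module-level constants `decoder_epochs` / `mutual_epochs`
# are assumed to be defined elsewhere in the project:
#
#   train(model_path=None,
#         data_path='./data/dataset.pt',
#         version='v1',
#         save_to_dir='./stored_models')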