Example #1
def train_stage_two(dataset, best_model_file, model_file):
    bestaccuracy = 0.9
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net = ResNet(BasicBlock, [3, 3, 4, 3]).to(device)  # [2,2,2,2]
    net.train()
    for parameter in net.parameters():
        if len(parameter.shape) > 1:
            torch.nn.init.xavier_uniform_(parameter)
    if isfile(best_model_file):
        net.load_state_dict(torch.load(best_model_file))
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
    optimizer = AdamW(net.parameters(), lr=0.0001)
    scheduler = CyclicLR(optimizer,
                         0.000001,
                         0.0001,
                         step_size_up=200,
                         mode='triangular2',
                         cycle_momentum=False,
                         last_epoch=-1)
    L1 = torch.nn.L1Loss()
    BCE = torch.nn.BCEWithLogitsLoss()

    for epoch in range(50):
        running_accuracy = []
        for (images, targets) in tqdm(train_loader):
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(images)
            clsloss = BCE(outputs[:, 0], targets[:, 0])
            regloss = L1(outputs[:, 1:], targets[:, 1:])
            loss = clsloss + regloss
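            # a raw logit > 0 corresponds to sigmoid(logit) > 0.5, so the
            # classification threshold below is 0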
            cls_preds = np.greater(outputs[:, 0].cpu().detach().numpy(), 0)
            cls_truth = targets[:, 0].cpu().detach().numpy()
            correctness = np.equal(cls_preds, cls_truth).astype(int)
            accuracy = np.mean(correctness)  # average over the actual batch size, not a hard-coded 64
            running_accuracy.append(accuracy)
            running_accuracy = running_accuracy[-10:]
            print(' clsloss ' + str(clsloss.cpu().detach().numpy())[:4] +
                  ' regloss ' + str(regloss.cpu().detach().numpy())[:4] +
                  ' accuracy ' + str(np.mean(running_accuracy)),
                  end='\r')
            if np.mean(running_accuracy) > bestaccuracy:
                bestaccuracy = np.mean(running_accuracy)
                torch.save(net.state_dict(), best_model_file)
                # print('totalloss', str(loss.detach().numpy())[:4], 'saved!', end = '\n')
            else:
                pass
                # print('totalloss', str(loss.detach().numpy())[:4]+' ', end = '\n')
            loss.backward()
            optimizer.step()
            scheduler.step()
            # if idx%5==0:
            #    print('\n', outputs[0].cpu().detach().numpy(), targets[0].cpu().detach().numpy(), '\n')
            # idx+=1
        torch.save(net.state_dict(), model_file)
        print(epoch)
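# A minimal, self-contained sketch of the AdamW + CyclicLR pattern from the
# example above (dummy model and data; the point is the step ordering:
# optimizer.step() first, then scheduler.step() once per batch):
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CyclicLR

net = torch.nn.Linear(10, 1)
optimizer = AdamW(net.parameters(), lr=0.0001)
scheduler = CyclicLR(optimizer, base_lr=0.000001, max_lr=0.0001,
                     step_size_up=200, mode='triangular2', cycle_momentum=False)
for x, y in [(torch.randn(8, 10), torch.randn(8, 1))] * 3:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(net(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the cyclic schedule once per batch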
Example #2
def train_single_epoch(model: RedisSingleDNN, trainDataloader: DataLoader,
                       optimizer: AdamW) -> Tuple[float, float]:
    train_loss = 0.0
    train_ACC = 0
    train_steps = 0
    model.train()
    for _, batch in enumerate(tqdm(trainDataloader, desc="Iteration")):
        optimizer.zero_grad()
        knobs_with_info = batch[0].to(DEVICE)
        targets = batch[1].to(DEVICE)
        outputs = model(knobs_with_info)
        loss = F.mse_loss(outputs, targets)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_steps += 1

    return train_loss / len(trainDataloader), train_ACC
Example #3
def train_twice_epoch(model: RedisTwiceDNN, trainDataloader: DataLoader,
                      optimizer: AdamW) -> Tuple[float, float]:
    train_loss = 0.0
    train_ACC = 0
    train_steps = 0
    model.train()
    weight = [0.6, 0.4]
    for _, batch in enumerate(tqdm(trainDataloader, desc="Iteration")):
        optimizer.zero_grad()
        knobs_with_info = batch[0].to(DEVICE)
        targets = batch[1].to(DEVICE)

        outputs = model(knobs_with_info)

        loss = 0.
        for i, output in enumerate(outputs):
            loss += weight[i] * F.mse_loss(output.squeeze(1), targets[:, i])
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_steps += 1

    return train_loss / len(trainDataloader), train_ACC
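# The weighted two-head loss above, sketched in isolation (assumed shapes:
# each head outputs [batch, 1] and targets are [batch, 2]):
import torch
import torch.nn.functional as F

outputs = [torch.randn(16, 1), torch.randn(16, 1)]
targets = torch.randn(16, 2)
weight = [0.6, 0.4]
loss = sum(w * F.mse_loss(o.squeeze(1), targets[:, i])
           for i, (w, o) in enumerate(zip(weight, outputs)))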
Example #4
def main():
    def evaluate_accuracy(data_loader, prefix: str, save: bool = False):
        std_transform.eval()
        model.eval()
        pbar = tqdm(data_loader,
                    desc=prefix,
                    leave=True,
                    total=len(data_loader))
        num_corr = 0
        num_tot = 0
        for idx, batch in enumerate(pbar):
            batch = batch.to(device)
            scores = model(zmuv_transform(std_transform(batch.audio_data)),
                           std_transform.compute_lengths(batch.lengths))
            num_tot += scores.size(0)
            labels = torch.tensor([
                l % SETTINGS.training.num_labels
                for l in batch.labels.tolist()
            ]).to(device)
            num_corr += (scores.max(1)[1] == labels).float().sum().item()
            acc = num_corr / num_tot
            pbar.set_postfix(accuracy=f'{acc:.4}')
        if save and not args.eval:
            writer.add_scalar(f'{prefix}/Metric/acc', acc, epoch_idx)
            ws.increment_model(model, acc / 10)

    apb = ArgumentParserBuilder()
    apb.add_options(
        opt('--model', type=str, choices=model_names(), default='las'),
        opt('--workspace',
            type=str,
            default=str(Path('workspaces') / 'default')),
        opt('--load-weights', action='store_true'),
        opt('--eval', action='store_true'))
    args = apb.parser.parse_args()

    ws = Workspace(Path(args.workspace), delete_existing=not args.eval)
    writer = ws.summary_writer
    set_seed(SETTINGS.training.seed)
    loader = GoogleSpeechCommandsDatasetLoader()
    sr = SETTINGS.audio.sample_rate
    ds_kwargs = dict(sr=sr, mono=SETTINGS.audio.use_mono)
    train_ds, dev_ds, test_ds = loader.load_splits(
        SETTINGS.dataset.dataset_path, **ds_kwargs)

    device = torch.device(SETTINGS.training.device)
    std_transform = StandardAudioTransform().to(device).eval()
    zmuv_transform = ZmuvTransform().to(device)
    batchifier = partial(batchify, label_provider=lambda x: x.label)
    truncater = partial(truncate_length,
                        length=int(SETTINGS.training.max_window_size_seconds *
                                   sr))
    train_comp = compose(truncater,
                         TimeshiftTransform().train(),
                         NoiseTransform().train(), batchifier)
    prep_dl = StandardAudioDataLoaderBuilder(train_ds,
                                             collate_fn=batchifier).build(1)
    prep_dl.shuffle = True
    train_dl = StandardAudioDataLoaderBuilder(
        train_ds, collate_fn=train_comp).build(SETTINGS.training.batch_size)
    dev_dl = StandardAudioDataLoaderBuilder(
        dev_ds,
        collate_fn=compose(truncater,
                           batchifier)).build(SETTINGS.training.batch_size)
    test_dl = StandardAudioDataLoaderBuilder(
        test_ds,
        collate_fn=compose(truncater,
                           batchifier)).build(SETTINGS.training.batch_size)

    model = find_model(args.model)().to(device)
    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = AdamW(params,
                      SETTINGS.training.learning_rate,
                      weight_decay=SETTINGS.training.weight_decay)
    logging.info(f'{sum(p.numel() for p in params)} parameters')
    criterion = nn.CrossEntropyLoss()

    if (ws.path / 'zmuv.pt.bin').exists():
        zmuv_transform.load_state_dict(torch.load(str(ws.path /
                                                      'zmuv.pt.bin')))
    else:
        for idx, batch in enumerate(tqdm(prep_dl, desc='Constructing ZMUV')):
            batch.to(device)
            zmuv_transform.update(std_transform(batch.audio_data))
            if idx == 2000:  # TODO: quick debugging, remove later
                break
        logging.info(
            dict(zmuv_mean=zmuv_transform.mean, zmuv_std=zmuv_transform.std))
    torch.save(zmuv_transform.state_dict(), str(ws.path / 'zmuv.pt.bin'))

    if args.load_weights:
        ws.load_model(model, best=True)
    if args.eval:
        ws.load_model(model, best=True)
        evaluate_accuracy(dev_dl, 'Dev')
        evaluate_accuracy(test_dl, 'Test')
        return

    ws.write_args(args)
    ws.write_setting(SETTINGS)
    writer.add_scalar('Meta/Parameters', sum(p.numel() for p in params))
    for epoch_idx in trange(SETTINGS.training.num_epochs,
                            position=0,
                            leave=True):
        model.train()
        std_transform.train()
        pbar = tqdm(train_dl,
                    total=len(train_dl),
                    position=1,
                    desc='Training',
                    leave=True)
        for batch in pbar:
            batch.to(device)
            audio_data = zmuv_transform(std_transform(batch.audio_data))
            scores = model(audio_data,
                           std_transform.compute_lengths(batch.lengths))
            optimizer.zero_grad()
            model.zero_grad()
            labels = torch.tensor([
                l % SETTINGS.training.num_labels
                for l in batch.labels.tolist()
            ]).to(device)
            loss = criterion(scores, labels)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(dict(loss=f'{loss.item():.3}'))
            writer.add_scalar('Training/Loss', loss.item(), epoch_idx)

        for group in optimizer.param_groups:
            group['lr'] *= SETTINGS.training.lr_decay
        evaluate_accuracy(dev_dl, 'Dev', save=True)
    evaluate_accuracy(test_dl, 'Test')
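# The manual per-epoch decay at the end of the epoch loop above
# (group['lr'] *= lr_decay) is what torch.optim.lr_scheduler.ExponentialLR
# does; an equivalent sketch:
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR

opt = AdamW(torch.nn.Linear(4, 2).parameters(), lr=0.001)
sched = ExponentialLR(opt, gamma=0.9)  # multiplies every group's lr by gamma
for _ in range(3):
    opt.step()
    sched.step()  # once per epoch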
Example #5
def main():
    def evaluate_engine(dataset: WakeWordDataset,
                        prefix: str,
                        save: bool = False,
                        positive_set: bool = False,
                        write_errors: bool = True,
                        mixer: DatasetMixer = None):
        std_transform.eval()

        if use_frame:
            engine = FrameInferenceEngine(int(SETTINGS.training.max_window_size_seconds * 1000),
                                          int(SETTINGS.training.eval_stride_size_seconds * 1000),
                                          SETTINGS.audio.sample_rate,
                                          model,
                                          zmuv_transform,
                                          negative_label=ctx.negative_label,
                                          coloring=ctx.coloring)
        else:
            engine = SequenceInferenceEngine(SETTINGS.audio.sample_rate,
                                             model,
                                             zmuv_transform,
                                             negative_label=ctx.negative_label,
                                             coloring=ctx.coloring,
                                             blank_idx=ctx.blank_label)
        model.eval()
        conf_matrix = ConfusionMatrix()
        pbar = tqdm(dataset, desc=prefix)
        if write_errors:
            with (ws.path / 'errors.tsv').open('a') as f:
                print(prefix, file=f)
        for idx, ex in enumerate(pbar):
            if mixer is not None:
                ex, = mixer([ex])
            audio_data = ex.audio_data.to(device)
            engine.reset()
            seq_present = engine.infer(audio_data)
            if seq_present != positive_set and write_errors:
                with (ws.path / 'errors.tsv').open('a') as f:
                    f.write(f'{ex.metadata.transcription}\t{int(seq_present)}\t{int(positive_set)}\t{ex.metadata.path}\n')
            conf_matrix.increment(seq_present, positive_set)
            pbar.set_postfix(dict(mcc=f'{conf_matrix.mcc}', c=f'{conf_matrix}'))

        logging.info(f'{conf_matrix}')
        if save and not args.eval:
            writer.add_scalar(f'{prefix}/Metric/tp', conf_matrix.tp, epoch_idx)
            ws.increment_model(model, conf_matrix.tp)
        if args.eval:
            threshold = engine.threshold
            with (ws.path / (str(round(threshold, 2)) + '_results.csv')).open('a') as f:
                f.write(f'{prefix},{threshold},{conf_matrix.tp},{conf_matrix.tn},{conf_matrix.fp},{conf_matrix.fn}\n')

    def do_evaluate():
        evaluate_engine(ww_dev_pos_ds, 'Dev positive', positive_set=True)
        evaluate_engine(ww_dev_neg_ds, 'Dev negative', positive_set=False)
        if SETTINGS.training.use_noise_dataset:
            evaluate_engine(ww_dev_pos_ds, 'Dev noisy positive', positive_set=True, mixer=dev_mixer)
            evaluate_engine(ww_dev_neg_ds, 'Dev noisy negative', positive_set=False, mixer=dev_mixer)
        evaluate_engine(ww_test_pos_ds, 'Test positive', positive_set=True)
        evaluate_engine(ww_test_neg_ds, 'Test negative', positive_set=False)
        if SETTINGS.training.use_noise_dataset:
            evaluate_engine(ww_test_pos_ds, 'Test noisy positive', positive_set=True, mixer=test_mixer)
            evaluate_engine(ww_test_neg_ds, 'Test noisy negative', positive_set=False, mixer=test_mixer)

    apb = ArgumentParserBuilder()
    apb.add_options(opt('--model', type=str, choices=RegisteredModel.registered_names(), default='las'),
                    opt('--workspace', type=str, default=str(Path('workspaces') / 'default')),
                    opt('--load-weights', action='store_true'),
                    opt('--load-last', action='store_true'),
                    opt('--no-dev-per-epoch', action='store_false', dest='dev_per_epoch'),
                    opt('--dataset-paths', '-i', type=str, nargs='+', default=[SETTINGS.dataset.dataset_path]),
                    opt('--eval', action='store_true'))
    args = apb.parser.parse_args()

    use_frame = SETTINGS.training.objective == 'frame'
    ctx = InferenceContext(SETTINGS.training.vocab, token_type=SETTINGS.training.token_type, use_blank=not use_frame)
    if use_frame:
        batchifier = WakeWordFrameBatchifier(ctx.negative_label,
                                             window_size_ms=int(SETTINGS.training.max_window_size_seconds * 1000))
        criterion = nn.CrossEntropyLoss()
    else:
        tokenizer = WakeWordTokenizer(ctx.vocab, ignore_oov=False)
        batchifier = AudioSequenceBatchifier(ctx.negative_label, tokenizer)
        criterion = nn.CTCLoss(ctx.blank_label)

    ws = Workspace(Path(args.workspace), delete_existing=not args.eval)
    writer = ws.summary_writer
    set_seed(SETTINGS.training.seed)
    loader = WakeWordDatasetLoader()
    ds_kwargs = dict(sr=SETTINGS.audio.sample_rate, mono=SETTINGS.audio.use_mono, frame_labeler=ctx.labeler)

    ww_train_ds, ww_dev_ds, ww_test_ds = WakeWordDataset(metadata_list=[], set_type=DatasetType.TRAINING, **ds_kwargs), \
                                         WakeWordDataset(metadata_list=[], set_type=DatasetType.DEV, **ds_kwargs), \
                                         WakeWordDataset(metadata_list=[], set_type=DatasetType.TEST, **ds_kwargs)
    for ds_path in args.dataset_paths:
        ds_path = Path(ds_path)
        train_ds, dev_ds, test_ds = loader.load_splits(ds_path, **ds_kwargs)
        ww_train_ds.extend(train_ds)
        ww_dev_ds.extend(dev_ds)
        ww_test_ds.extend(test_ds)
    print_stats('Wake word dataset', ww_train_ds, ww_dev_ds, ww_test_ds)

    ww_dev_pos_ds = ww_dev_ds.filter(lambda x: ctx.searcher.search(x.transcription), clone=True)
    ww_dev_neg_ds = ww_dev_ds.filter(lambda x: not ctx.searcher.search(x.transcription), clone=True)
    ww_test_pos_ds = ww_test_ds.filter(lambda x: ctx.searcher.search(x.transcription), clone=True)
    ww_test_neg_ds = ww_test_ds.filter(lambda x: not ctx.searcher.search(x.transcription), clone=True)

    print_stats('Dev dataset', ww_dev_pos_ds, ww_dev_neg_ds)
    print_stats('Test dataset', ww_test_pos_ds, ww_test_neg_ds)
    device = torch.device(SETTINGS.training.device)
    std_transform = StandardAudioTransform().to(device).eval()
    zmuv_transform = ZmuvTransform().to(device)

    train_comp = (NoiseTransform().train(), batchifier)

    if SETTINGS.training.use_noise_dataset:
        noise_ds = RecursiveNoiseDatasetLoader().load(Path(SETTINGS.raw_dataset.noise_dataset_path),
                                                      sr=SETTINGS.audio.sample_rate,
                                                      mono=SETTINGS.audio.use_mono)
        logging.info(f'Loaded {len(noise_ds.metadata_list)} noise files.')
        noise_ds_train, noise_ds_dev = noise_ds.split(Sha256Splitter(80))
        noise_ds_dev, noise_ds_test = noise_ds_dev.split(Sha256Splitter(50))
        train_comp = (DatasetMixer(noise_ds_train).train(),) + train_comp
        dev_mixer = DatasetMixer(noise_ds_dev, seed=0, do_replace=False)
        test_mixer = DatasetMixer(noise_ds_test, seed=0, do_replace=False)
    train_comp = compose(*train_comp)

    prep_dl = StandardAudioDataLoaderBuilder(ww_train_ds, collate_fn=batchify).build(1)
    prep_dl.shuffle = True
    train_dl = StandardAudioDataLoaderBuilder(ww_train_ds, collate_fn=train_comp).build(SETTINGS.training.batch_size)

    model = RegisteredModel.find_registered_class(args.model)(ctx.num_labels).to(device).streaming()
    if SETTINGS.training.convert_static:
        model = ConvertedStaticModel(model, 40, 10)
    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = AdamW(params, SETTINGS.training.learning_rate, weight_decay=SETTINGS.training.weight_decay)
    logging.info(f'{sum(p.numel() for p in params)} parameters')

    if (ws.path / 'zmuv.pt.bin').exists():
        zmuv_transform.load_state_dict(torch.load(str(ws.path / 'zmuv.pt.bin')))
    else:
        for idx, batch in enumerate(tqdm(prep_dl, desc='Constructing ZMUV')):
            batch.to(device)
            zmuv_transform.update(std_transform(batch.audio_data))
            if idx == 2000:  # TODO: quick debugging, remove later
                break
        logging.info(dict(zmuv_mean=zmuv_transform.mean, zmuv_std=zmuv_transform.std))
    torch.save(zmuv_transform.state_dict(), str(ws.path / 'zmuv.pt.bin'))

    if args.load_weights:
        ws.load_model(model, best=not args.load_last)
    if args.eval:
        ws.load_model(model, best=not args.load_last)
        do_evaluate()
        return

    ws.write_args(args)
    ws.write_settings(SETTINGS)
    writer.add_scalar('Meta/Parameters', sum(p.numel() for p in params))
    for epoch_idx in trange(SETTINGS.training.num_epochs, position=0, leave=True):
        model.train()
        std_transform.train()
        model.streaming_state = None
        pbar = tqdm(train_dl,
                    total=len(train_dl),
                    position=1,
                    desc='Training',
                    leave=True)
        total_loss = torch.Tensor([0.0]).to(device)
        for batch in pbar:
            batch.to(device)
            if use_frame:
                scores = model(zmuv_transform(std_transform(batch.audio_data)),
                               std_transform.compute_lengths(batch.lengths))
                loss = criterion(scores, batch.labels)
            else:
                lengths = std_transform.compute_lengths(batch.audio_lengths)
                scores = model(zmuv_transform(std_transform(batch.audio_data)), lengths)
                scores = F.log_softmax(scores, -1)  # [num_frames x batch_size x num_labels]
                lengths = torch.tensor([model.compute_length(x.item()) for x in lengths]).to(device)
                loss = criterion(scores, batch.labels, lengths, batch.label_lengths)
            optimizer.zero_grad()
            model.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix(dict(loss=f'{loss.item():.3}'))
            with torch.no_grad():
                total_loss += loss

        for group in optimizer.param_groups:
            group['lr'] *= SETTINGS.training.lr_decay

        mean = total_loss / len(train_dl)
        writer.add_scalar('Training/Loss', mean.item(), epoch_idx)
        writer.add_scalar('Training/LearningRate', group['lr'], epoch_idx)

        if args.dev_per_epoch:
            evaluate_engine(ww_dev_pos_ds, 'Dev positive', positive_set=True, save=True, write_errors=False)

    do_evaluate()
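# Shape sketch for the CTC branch above: nn.CTCLoss expects [T, N, C]
# log-probs, padded label sequences, and per-sample lengths (dummy values;
# blank index 0 is an assumption here, the example uses ctx.blank_label):
import torch
import torch.nn.functional as F

T, N, C = 50, 4, 10  # frames, batch size, number of labels
scores = F.log_softmax(torch.randn(T, N, C), dim=-1)
labels = torch.randint(1, C, (N, 12))
input_lengths = torch.full((N,), T, dtype=torch.long)
label_lengths = torch.full((N,), 12, dtype=torch.long)
loss = torch.nn.CTCLoss(blank=0)(scores, labels, input_lengths, label_lengths)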
Example #6
class ShuffleSelfAttentionModel(Model):
    def __init__(self,
                 save_path,
                 log_path,
                 n_depth,
                 d_features,
                 d_classifier,
                 d_output,
                 threshold=None,
                 stack='ShuffleSelfAttention',
                 expansion_layer='ChannelWiseConvExpansion',
                 mode='1d',
                 optimizer=None,
                 **kwargs):
        '''
            Arguments (via **kwargs): n_layers, n_head, n_channel, n_vchannel,
                dropout, use_bottleneck, d_bottleneck
                mode:   1d:         1d output
                        2d:         2d output
                        residual:   residual output
                        dense:      dense net
        '''

        super().__init__(save_path, log_path)
        self.d_output = d_output
        self.threshold = threshold

        # ----------------------------- Model ------------------------------ #
        stack_dict = {
            'ReactionAttention': ReactionAttentionStack,
            'SelfAttention': SelfAttentionStack,
            'Alternate': AlternateStack,
            'Parallel': ParallelStack,
            'ShuffleSelfAttention': ShuffleSelfAttentionStack,
            'ShuffleSelfAttentionStackV2': ShuffleSelfAttentionStackV2,
        }
        expansion_dict = {
            'LinearExpansion': LinearExpansion,
            'ReduceParamLinearExpansion': ReduceParamLinearExpansion,
            'ConvExpansion': ConvExpansion,
            'LinearConvExpansion': LinearConvExpansion,
            'ShuffleConvExpansion': ShuffleConvExpansion,
            'ChannelWiseConvExpansion': ChannelWiseConvExpansion,
        }

        self.model = stack_dict[stack](expansion_dict[expansion_layer],
                                       n_depth=n_depth,
                                       d_features=d_features,
                                       mode=mode,
                                       **kwargs)

        # --------------------------- Classifier --------------------------- #
        if mode == '1d':
            self.classifier = LinearClassifier(d_features, d_classifier,
                                               d_output)
        elif mode == '2d':
            self.classifier = LinearClassifier(n_depth * d_features,
                                               d_classifier, d_output)
        else:
            self.classifier = None

        # ------------------------------ CUDA ------------------------------ #
        # If GPU available, move the graph to GPU(s)

        self.CUDA_AVAILABLE = self.check_cuda()
        if self.CUDA_AVAILABLE:
            # self.model.cuda()
            # self.classifier.cuda()

            device_ids = list(range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids)
            self.classifier = nn.DataParallel(self.classifier, device_ids)
            self.model.to('cuda')
            self.classifier.to('cuda')
            assert (next(self.model.parameters()).is_cuda)
            assert (next(self.classifier.parameters()).is_cuda)

        else:
            print('CUDA not found or not enabled, use CPU instead')

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(
            self.classifier.parameters())
        if optimizer is None:
            self.optimizer = AdamW(self.parameters,
                                   lr=0.002,
                                   betas=(0.9, 0.999),
                                   weight_decay=0.001)
        else:
            self.optimizer = optimizer  # honor a caller-supplied optimizer

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000,
                                          evaluate_every_nstep=100,
                                          print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # --------------------- logging and tensorboard -------------------- #
        self.set_logger()
        self.set_summary_writer()
        # ---------------------------- END INIT ---------------------------- #

    def checkpoint(self, step):
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'classifier_state_dict': self.classifier.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'global_step': step
        }
        return checkpoint

    def train_epoch(self, train_dataloader, eval_dataloader, device, smoothing,
                    earlystop):
        ''' Epoch operation in training phase'''

        if device == 'cuda':
            assert self.CUDA_AVAILABLE
        # Set model and classifier training mode
        self.model.train()
        self.classifier.train()

        total_loss = 0
        batch_counter = 0

        # update param per batch
        for batch in tqdm(train_dataloader,
                          mininterval=1,
                          desc='  - (Training)   ',
                          leave=False):  # train_dataloader should be an iterable

            # get data from dataloader

            feature_1, feature_2, y = parse_data(batch, device)

            batch_size = len(feature_1)

            # forward
            self.optimizer.zero_grad()
            logits, attn = self.model(feature_1, feature_2)
            logits = logits.view(batch_size, -1)
            logits = self.classifier(logits)

            # Judge if it's a regression problem
            if self.d_output == 1:
                pred = logits.sigmoid()
                loss = mse_loss(pred, y)

            else:
                pred = logits
                loss = cross_entropy_loss(pred, y, smoothing=smoothing)

            # calculate gradients
            loss.backward()

            # update parameters
            self.optimizer.step()

            # get metrics for logging
            acc = accuracy(pred, y, threshold=self.threshold)
            precision, recall, precision_avg, recall_avg = precision_recall(
                pred, y, self.d_output, threshold=self.threshold)
            total_loss += loss.item()
            batch_counter += 1

            # training control
            state_dict = self.controller(batch_counter)

            if state_dict['step_to_print']:
                self.train_logger.info(
                    '[TRAINING]   - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f'
                    % (state_dict['step'], loss, acc, precision[1], recall[1]))
                self.summary_writer.add_scalar('loss/train', loss,
                                               state_dict['step'])
                self.summary_writer.add_scalar('acc/train', acc,
                                               state_dict['step'])
                self.summary_writer.add_scalar('precision/train', precision[1],
                                               state_dict['step'])
                self.summary_writer.add_scalar('recall/train', recall[1],
                                               state_dict['step'])

            if state_dict['step_to_evaluate']:
                stop = self.val_epoch(eval_dataloader, device,
                                      state_dict['step'])
                state_dict['step_to_stop'] = stop

                if earlystop and stop:
                    break

            if self.controller.current_step == self.controller.max_step:
                state_dict['step_to_stop'] = True
                break

        return state_dict

    def val_epoch(self, dataloader, device, step=0, plot=False):
        ''' Epoch operation in evaluation phase '''
        if device == 'cuda':
            assert self.CUDA_AVAILABLE

        # Set model and classifier training mode
        self.model.eval()
        self.classifier.eval()

        # use evaluator to calculate the average performance
        evaluator = Evaluator()

        pred_list = []
        real_list = []

        with torch.no_grad():

            for batch in tqdm(
                    dataloader,
                    mininterval=5,
                    desc='  - (Evaluation)   ',
                    leave=False):  # dataloader should be an iterable

                # get data from dataloader
                feature_1, feature_2, y = parse_data(batch, device)

                batch_size = len(feature_1)

                # get logits
                logits, attn = self.model(feature_1, feature_2)
                logits = logits.view(batch_size, -1)
                logits = self.classifier(logits)

                if self.d_output == 1:
                    pred = logits.sigmoid()
                    loss = mse_loss(pred, y)

                else:
                    pred = logits
                    loss = cross_entropy_loss(pred, y, smoothing=False)

                acc = accuracy(pred, y, threshold=self.threshold)
                precision, recall, _, _ = precision_recall(
                    pred, y, self.d_output, threshold=self.threshold)

                # feed the metrics in the evaluator
                evaluator(loss.item(), acc.item(), precision[1].item(),
                          recall[1].item())
                '''append the results to the predict / real list for drawing ROC or PR curve.'''
                if plot:
                    pred_list += pred.tolist()
                    real_list += y.tolist()

            if plot:
                area, precisions, recalls, thresholds = pr(
                    pred_list, real_list)
                plot_pr_curve(recalls, precisions, auc=area)

            # get evaluation results from the evaluator
            loss_avg, acc_avg, pre_avg, rec_avg = evaluator.avg_results()

            self.eval_logger.info(
                '[EVALUATION] - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f'
                % (step, loss_avg, acc_avg, pre_avg, rec_avg))
            self.summary_writer.add_scalar('loss/eval', loss_avg, step)
            self.summary_writer.add_scalar('acc/eval', acc_avg, step)
            self.summary_writer.add_scalar('precision/eval', pre_avg, step)
            self.summary_writer.add_scalar('recall/eval', rec_avg, step)

            state_dict = self.early_stopping(loss_avg)

            if state_dict['save']:
                checkpoint = self.checkpoint(step)
                self.save_model(
                    checkpoint,
                    self.save_path + '-step-%d_loss-%.5f' % (step, loss_avg))

            return state_dict['break']

    def train(self,
              max_epoch,
              train_dataloader,
              eval_dataloader,
              device,
              smoothing=False,
              earlystop=False,
              save_mode='best'):

        assert save_mode in ['all', 'best']
        # train for n epoch
        for epoch_i in range(max_epoch):
            print('[ Epoch', epoch_i, ']')
            # set current epoch
            self.controller.set_epoch(epoch_i + 1)
            # train for one epoch
            state_dict = self.train_epoch(train_dataloader, eval_dataloader,
                                          device, smoothing, earlystop)

            # if state_dict['step_to_stop']:
            #     break

        checkpoint = self.checkpoint(state_dict['step'])

        self.save_model(checkpoint,
                        self.save_path + '-step-%d' % state_dict['step'])

        self.train_logger.info(
            '[INFO]: Finish Training, ends with %d epoch(s) and %d batches, in total %d training steps.'
            %
            (state_dict['epoch'] - 1, state_dict['batch'], state_dict['step']))

    def get_predictions(self,
                        data_loader,
                        device,
                        max_batches=None,
                        activation=None):

        pred_list = []
        real_list = []

        self.model.eval()
        self.classifier.eval()

        batch_counter = 0

        with torch.no_grad():
            for batch in tqdm(data_loader,
                              desc='  - (Testing)   ',
                              leave=False):

                feature_1, feature_2, y = parse_data(batch, device)

                # get logits
                logits, attn = self.model(feature_1, feature_2)
                logits = logits.view(logits.shape[0], -1)
                logits = self.classifier(logits)

                # Whether to apply activation function
                if activation is not None:
                    pred = activation(logits)
                else:
                    pred = logits.softmax(dim=-1)
                pred_list += pred.tolist()
                real_list += y.tolist()

                if max_batches is not None:
                    batch_counter += 1
                    if batch_counter >= max_batches:
                        break

        return pred_list, real_list
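# cross_entropy_loss(..., smoothing=...) above is a project-specific helper; a
# common label-smoothing formulation it plausibly resembles (an assumption,
# not the repository's actual implementation):
import torch
import torch.nn.functional as F

def smoothed_cross_entropy(logits, target, eps=0.1):
    # (1 - eps) * NLL of the true class + eps * cross entropy against uniform
    log_probs = F.log_softmax(logits, dim=-1)
    nll = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
    return ((1 - eps) * nll - eps * log_probs.mean(dim=-1)).mean()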
Example #7
def train_stage_three(dataset, best_model_file, model_file):
    bestaccuracy = 0.9
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net = MyUNet(3, device).to(device)
    net.train()
    for parameter in net.parameters():
        if len(parameter.shape) > 1:
            torch.nn.init.xavier_uniform_(parameter)
    if isfile(best_model_file):
        net.load_state_dict(torch.load(best_model_file))
    train_loader = DataLoader(dataset, batch_size=4, shuffle=True)
    optimizer = AdamW(net.parameters(), lr=0.0001)
    scheduler = CyclicLR(optimizer,
                         0.000001,
                         0.0001,
                         step_size_up=200,
                         mode='triangular2',
                         cycle_momentum=False,
                         last_epoch=-1)
    L1 = torch.nn.L1Loss(reduction='sum')  # size_average=False is deprecated; reduction='sum' is equivalent

    for epoch in range(50):
        for (images, targets, out_masks) in tqdm(train_loader):
            images = images.to(device)
            targets = targets.to(device)
            out_masks = out_masks.to(device)
            optimizer.zero_grad()
            outputs = net(images)
            loss = L1(outputs * out_masks, targets * out_masks) / images.size(0)  # normalize by actual batch size
            outputs = (outputs * out_masks).cpu().detach().numpy()
            targets = (targets * out_masks).cpu().detach().numpy()
            if np.mean(np.linalg.norm(targets * 100, axis=1)) > 0:
                truth_norm = np.linalg.norm(targets * 100, axis=1).flatten()
                error_norm = np.linalg.norm(outputs * 100 - targets * 100,
                                            axis=1).flatten()
                truth_norm, error_norm = truth_norm[
                    truth_norm > 0], error_norm[error_norm > 0]
                accuracy = sum(
                    (error_norm / truth_norm) < 0.1) / len(error_norm)
                print('mean error',
                      np.mean(error_norm / truth_norm),
                      'accuracy',
                      accuracy,
                      end='\t')
            else:
                accuracy = 0.0
            print('L1loss', loss.cpu().detach().numpy(), end='\r')
            if accuracy > bestaccuracy:
                bestaccuracy = accuracy
                torch.save(net.state_dict(), best_model_file)
            else:
                pass
                # print('totalloss', str(loss.detach().numpy())[:4]+' ', end = '\n')
            loss.backward()
            optimizer.step()
            scheduler.step()
            # if idx%5==0:
            #    print('\n', outputs[0].cpu().detach().numpy(), targets[0].cpu().detach().numpy(), '\n')
            # idx+=1
        torch.save(net.state_dict(), model_file)
        print(epoch)
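# The masked L1 objective above, in isolation (the mask zeroes out ignored
# elements; dividing the summed error by the batch size gives a per-sample
# sum of absolute errors):
import torch

l1 = torch.nn.L1Loss(reduction='sum')
outputs, targets, mask = torch.randn(4, 3), torch.randn(4, 3), torch.ones(4, 3)
loss = l1(outputs * mask, targets * mask) / outputs.size(0)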
Example #8
def train():
    global writer
    # For parsing commandline arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_root",
        type=str,
        required=True,
        help='path to dataset folder containing train-test-validation folders')
    parser.add_argument("--checkpoint_dir",
                        type=str,
                        required=True,
                        help='path to folder for saving checkpoints')
    parser.add_argument("--checkpoint",
                        type=str,
                        help='path of checkpoint for pretrained model')
    parser.add_argument(
        "--train_continue",
        action='store_true',  # argparse's type=bool treats any non-empty string as True; a flag is safer
        help=
        'pass this flag and set `checkpoint` to resume from a checkpoint. Default: False.'
    )
    parser.add_argument("--epochs",
                        type=int,
                        default=200,
                        help='number of epochs to train. Default: 200.')
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=3,
                        help='batch size for training. Default: 3.')
    parser.add_argument("--validation_batch_size",
                        type=int,
                        default=6,
                        help='batch size for validation. Default: 6.')
    parser.add_argument("--init_learning_rate",
                        type=float,
                        default=0.0001,
                        help='set initial learning rate. Default: 0.0001.')
    parser.add_argument(
        "--milestones",
        type=int,
        nargs='+',  # argparse's type=list would split a string into characters
        default=[25, 50],
        help=
        'UNUSED NOW: epoch values at which to decrease learning rate by a factor of 0.1. Default: [25, 50]'
    )
    parser.add_argument(
        "--progress_iter",
        type=int,
        default=200,
        help=
        'frequency of reporting progress and validation. N: after every N iterations. Default: 200.'
    )
    parser.add_argument(
        "--checkpoint_epoch",
        type=int,
        default=5,
        help=
        'checkpoint saving frequency. N: after every N epochs. Each checkpoint is roughly 151 MB. Default: 5.'
    )
    args = parser.parse_args()

    ##[TensorboardX](https://github.com/lanpa/tensorboardX)
    ### For visualizing loss and interpolated frames

    ###Initialize flow computation and arbitrary-time flow interpolation CNNs.

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(device)
    flowComp = model.UNet(6, 4)
    flowComp.to(device)
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)

    ###Initialize backward warpers for train and validation datasets

    train_W_dim = 352
    train_H_dim = 352

    trainFlowBackWarp = model.backWarp(train_W_dim, train_H_dim, device)
    trainFlowBackWarp = trainFlowBackWarp.to(device)
    validationFlowBackWarp = model.backWarp(train_W_dim * 2, train_H_dim,
                                            device)
    validationFlowBackWarp = validationFlowBackWarp.to(device)

    ###Load Datasets

    # Channel wise mean calculated on custom training dataset
    # mean = [0.43702903766008444, 0.43715053433990597, 0.40436416782660994]
    mean = [0.5] * 3
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = dataloader.SuperSloMo(root=args.dataset_root + '/train',
                                     randomCropSize=(train_W_dim, train_H_dim),
                                     transform=transform,
                                     train=True)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.train_batch_size,
                                              shuffle=True,
                                              num_workers=2,
                                              pin_memory=True)

    validationset = dataloader.SuperSloMo(
        root=args.dataset_root + '/validation',
        transform=transform,
        randomCropSize=(2 * train_W_dim, train_H_dim),
        train=False)
    validationloader = torch.utils.data.DataLoader(
        validationset,
        batch_size=args.validation_batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True)

    print(trainset, validationset)

    ###Create transform to display image from tensor

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)
    TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    ###Utils

    def get_lr(optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']

    ###Loss and Optimizer

    L1_lossFn = nn.L1Loss()
    MSE_LossFn = nn.MSELoss()

    if args.train_continue:
        dict1 = torch.load(args.checkpoint)
        last_epoch = dict1['epoch'] * len(trainloader)
    else:
        last_epoch = -1

    params = list(ArbTimeFlowIntrp.parameters()) + list(flowComp.parameters())

    optimizer = AdamW(params, lr=args.init_learning_rate, amsgrad=True)
    # optimizer = optim.SGD(params, lr=args.init_learning_rate, momentum=0.9, nesterov=True)

    # scheduler to decrease learning rate by a factor of 10 at milestones.
    # Patience suggested value:
    # patience = number of items in train dataset / train_batch_size * (number of epochs of patience)
    # It does say epoch, but in this case, the number of progress iterations is what's really being worked with.
    # As such, each epoch will be given by the above formula (roughly, if using a rough dataset count)
    # If the model seems to equalize fast, reduce the number of epochs accordingly.
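    # Worked example of that formula (hypothetical numbers): with 12,000
    # training clips and train_batch_size=3 there are 4,000 iterations per
    # epoch, so a 3-epoch patience is 12000 / 3 * 3 = 12,000 scheduler steps,
    # matching patience=len(trainloader) * 3 below.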

    # scheduler = optim.lr_scheduler.CyclicLR(optimizer,
    #                                         base_lr=1e-8,
    #                                         max_lr=9.0e-3,
    #                                         step_size_up=3500,
    #                                         mode='triangular2',
    #                                         cycle_momentum=False,
    #                                         last_epoch=last_epoch)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=len(trainloader) * 3,
        cooldown=len(trainloader) * 2,
        verbose=True,
        min_lr=1e-8)

    # Changed to use this to ensure a more adaptive model.
    # The changed model used here seems to converge or plateau faster with more rapid swings over time.
    # As such letting the model deal with stagnation more proactively than at a set stage seems more useful.

    ###Initializing VGG16 model for perceptual loss

    vgg16 = torchvision.models.vgg16(pretrained=True)
    vgg16_conv_4_3 = nn.Sequential(*list(vgg16.children())[0][:22])
    vgg16_conv_4_3.to(device)

    for param in vgg16_conv_4_3.parameters():
        param.requires_grad = False

    # Validation function

    def validate():
        # For details see training.
        psnr = 0
        tloss = 0
        flag = 1
        with torch.no_grad():
            for validationIndex, (validationData,
                                  validationFrameIndex) in enumerate(
                                      validationloader, 0):
                frame0, frameT, frame1 = validationData

                I0 = frame0.to(device)
                I1 = frame1.to(device)
                IFrame = frameT.to(device)

                torch.cuda.empty_cache()
                flowOut = flowComp(torch.cat((I0, I1), dim=1))
                F_0_1 = flowOut[:, :2, :, :]
                F_1_0 = flowOut[:, 2:, :, :]

                fCoeff = model.getFlowCoeff(validationFrameIndex, device)
                torch.cuda.empty_cache()
                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)
                torch.cuda.empty_cache()
                intrpOut = ArbTimeFlowIntrp(
                    torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                               g_I0_F_t_0),
                              dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0
                # torch.cuda.empty_cache()
                g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

                wCoeff = model.getWarpCoeff(validationFrameIndex, device)
                torch.cuda.empty_cache()
                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                        g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # For tensorboard
                if (flag):
                    retImg = torchvision.utils.make_grid([
                        revNormalize(frame0[0]),
                        revNormalize(frameT[0]),
                        revNormalize(Ft_p.cpu()[0]),
                        revNormalize(frame1[0])
                    ],
                                                         padding=10)
                    flag = 0

                # loss
                recnLoss = L1_lossFn(Ft_p, IFrame)
                # torch.cuda.empty_cache()
                prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p),
                                      vgg16_conv_4_3(IFrame))

                warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                    g_I1_F_t_1, IFrame) + L1_lossFn(
                        validationFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                            validationFlowBackWarp(I1, F_0_1), I0)
                torch.cuda.empty_cache()
                loss_smooth_1_0 = torch.mean(
                    torch.abs(F_1_0[:, :, :, :-1] -
                              F_1_0[:, :, :, 1:])) + torch.mean(
                                  torch.abs(F_1_0[:, :, :-1, :] -
                                            F_1_0[:, :, 1:, :]))
                loss_smooth_0_1 = torch.mean(
                    torch.abs(F_0_1[:, :, :, :-1] -
                              F_0_1[:, :, :, 1:])) + torch.mean(
                                  torch.abs(F_0_1[:, :, :-1, :] -
                                            F_0_1[:, :, 1:, :]))
                loss_smooth = loss_smooth_1_0 + loss_smooth_0_1

                # torch.cuda.empty_cache()
                loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth
                tloss += loss.item()

                # psnr
                MSE_val = MSE_LossFn(Ft_p, IFrame)
                psnr += (10 * log10(1 / MSE_val.item()))
                torch.cuda.empty_cache()

        return (psnr / len(validationloader)), (tloss /
                                                len(validationloader)), retImg

    ### Initialization

    if args.train_continue:
        ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
        flowComp.load_state_dict(dict1['state_dictFC'])

        optimizer.load_state_dict(dict1.get('state_optimizer', {}))
        scheduler.load_state_dict(dict1.get('state_scheduler', {}))

        for param_group in optimizer.param_groups:
            param_group['lr'] = dict1.get('learningRate',
                                          args.init_learning_rate)

    else:
        dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}

    ### Training

    import time

    start = time.time()
    cLoss = dict1['loss']
    valLoss = dict1['valLoss']
    valPSNR = dict1['valPSNR']
    checkpoint_counter = 0

    ### Main training loop

    optimizer.step()
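    # The bare optimizer.step() above performs no parameter update (no
    # gradients exist yet); it is presumably there to silence PyTorch's
    # "scheduler.step() called before optimizer.step()" warning (assumed intent).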

    for epoch in range(dict1['epoch'] + 1, args.epochs):
        print("Epoch: ", epoch)

        # Append and reset
        cLoss.append([])
        valLoss.append([])
        valPSNR.append([])
        iLoss = 0

        for trainIndex, (trainData,
                         trainFrameIndex) in enumerate(trainloader, 0):

            ## Getting the input and the target from the training set
            frame0, frameT, frame1 = trainData

            I0 = frame0.to(device)
            I1 = frame1.to(device)
            IFrame = frameT.to(device)
            optimizer.zero_grad()
            # torch.cuda.empty_cache()
            # Calculate flow between reference frames I0 and I1
            flowOut = flowComp(torch.cat((I0, I1), dim=1))

            # Extracting flows between I0 and I1 - F_0_1 and F_1_0
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            fCoeff = model.getFlowCoeff(trainFrameIndex, device)

            # Calculate intermediate flows
            F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
            F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

            # Get intermediate frames from the intermediate flows
            g_I0_F_t_0 = trainFlowBackWarp(I0, F_t_0)
            g_I1_F_t_1 = trainFlowBackWarp(I1, F_t_1)
            torch.cuda.empty_cache()
            # Calculate optical flow residuals and visibility maps
            intrpOut = ArbTimeFlowIntrp(
                torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                           g_I0_F_t_0),
                          dim=1))

            # Extract optical flow residuals and visibility maps
            F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
            F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
            V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
            V_t_1 = 1 - V_t_0
            # torch.cuda.empty_cache()
            # Get intermediate frames from the intermediate flows
            g_I0_F_t_0_f = trainFlowBackWarp(I0, F_t_0_f)
            g_I1_F_t_1_f = trainFlowBackWarp(I1, F_t_1_f)
            # torch.cuda.empty_cache()
            wCoeff = model.getWarpCoeff(trainFrameIndex, device)
            torch.cuda.empty_cache()
            # Calculate final intermediate frame
            Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                    g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

            # Loss
            recnLoss = L1_lossFn(Ft_p, IFrame)
            # torch.cuda.empty_cache()

            prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))
            # torch.cuda.empty_cache()
            warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                g_I1_F_t_1, IFrame) + L1_lossFn(
                    trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                        trainFlowBackWarp(I1, F_0_1), I0)

            loss_smooth_1_0 = torch.mean(
                torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))
            loss_smooth_0_1 = torch.mean(
                torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))
            loss_smooth = loss_smooth_1_0 + loss_smooth_0_1
            # torch.cuda.empty_cache()
            # Total Loss - Coefficients 204 and 102 are used instead of 0.8 and 0.4
            # since the loss in paper is calculated for input pixels in range 0-255
            # and the input to our network is in range 0-1
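            # i.e. 0.8 * 255 = 204 and 0.4 * 255 = 102, folding the 0-255
            # rescaling into the loss weights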
            loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth

            # Backpropagate

            loss.backward()
            optimizer.step()
            scheduler.step(loss.item())

            iLoss += loss.item()
            torch.cuda.empty_cache()
            # Validation and progress every `args.progress_iter` iterations
            if ((trainIndex % args.progress_iter) == args.progress_iter - 1):
                # Increment scheduler count
                scheduler.step(iLoss / args.progress_iter)

                end = time.time()

                psnr, vLoss, valImg = validate()
                optimizer.zero_grad()
                # torch.cuda.empty_cache()
                valPSNR[epoch].append(psnr)
                valLoss[epoch].append(vLoss)

                # Tensorboard
                itr = trainIndex + epoch * (len(trainloader))

                writer.add_scalars(
                    'Loss', {
                        'trainLoss': iLoss / args.progress_iter,
                        'validationLoss': vLoss
                    }, itr)
                writer.add_scalar('PSNR', psnr, itr)

                writer.add_image('Validation', valImg, itr)
                #####

                endVal = time.time()

                print(
                    " Loss: %0.6f  Iterations: %4d/%4d  TrainExecTime: %0.1f  ValLoss:%0.6f  ValPSNR: %0.4f  ValEvalTime: %0.2f LearningRate: %.1e"
                    % (iLoss / args.progress_iter, trainIndex,
                       len(trainloader), end - start, vLoss, psnr,
                       endVal - end, get_lr(optimizer)))

                # torch.cuda.empty_cache()
                cLoss[epoch].append(iLoss / args.progress_iter)
                iLoss = 0
                start = time.time()

        # Create checkpoint after every `args.checkpoint_epoch` epochs
        if (epoch % args.checkpoint_epoch) == args.checkpoint_epoch - 1:
            dict1 = {
                'Detail': "End to end Super SloMo.",
                'epoch': epoch,
                'timestamp': datetime.datetime.now(),
                'trainBatchSz': args.train_batch_size,
                'validationBatchSz': args.validation_batch_size,
                'learningRate': get_lr(optimizer),
                'loss': cLoss,
                'valLoss': valLoss,
                'valPSNR': valPSNR,
                'state_dictFC': flowComp.state_dict(),
                'state_dictAT': ArbTimeFlowIntrp.state_dict(),
                'state_optimizer': optimizer.state_dict(),
                'state_scheduler': scheduler.state_dict()
            }
            torch.save(
                dict1, args.checkpoint_dir + "/SuperSloMo" +
                str(checkpoint_counter) + ".ckpt")
            checkpoint_counter += 1
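# Round-trip sketch of the checkpoint dict pattern above, with dummy modules
# and a hypothetical path (model and optimizer state restore exactly):
import torch

net = torch.nn.Linear(4, 2)
opt = torch.optim.AdamW(net.parameters(), lr=0.0001)
torch.save({'state_dictFC': net.state_dict(),
            'state_optimizer': opt.state_dict(),
            'epoch': 0}, '/tmp/demo.ckpt')
ckpt = torch.load('/tmp/demo.ckpt')
net.load_state_dict(ckpt['state_dictFC'])
opt.load_state_dict(ckpt['state_optimizer'])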
Example #9
def main():
    def evaluate():
        model.eval()
        num_correct = 0
        num_total = 0
        for inputs, labels in tqdm(test_dl, total=len(test_dl)):
            inputs = inputs.to(device)
            labels = torch.tensor([x % 10 for x in labels.tolist()])
            labels = labels.to(device)
            scores = model(inputs, None)
            num_correct += (scores.max(1)[1] == labels).float().sum().item()
            num_total += scores.size(0)
        logging.info(f'{num_correct / num_total}')
        ws.increment_model(model, num_correct / num_total / 100)

    apb = ArgumentParserBuilder()
    apb.add_options(
        opt('--model', type=str, choices=model_names(), default='las'),
        opt('--workspace',
            type=str,
            default=str(Path('workspaces') / 'default')),
        opt('--load-weights', action='store_true'))
    args = apb.parser.parse_args()

    ws = Workspace(Path(args.workspace))
    writer = ws.summary_writer
    set_seed(SETTINGS.training.seed)

    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset1 = torchvision.datasets.CIFAR10(root='./data',
                                             train=True,
                                             download=True,
                                             transform=transform)
    testset1 = torchvision.datasets.CIFAR10(root='./data',
                                            train=False,
                                            download=True,
                                            transform=test_transform)

    trainset2 = torchvision.datasets.CIFAR100(root='./data',
                                              train=True,
                                              download=True,
                                              transform=transform)
    testset2 = torchvision.datasets.CIFAR100(root='./data',
                                             train=False,
                                             download=True,
                                             transform=test_transform)

    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
        # FashionMNIST is single-channel; expand to 3 channels so batches from
        # the concatenated datasets share one shape.
        transforms.Lambda(lambda x: x.expand(3, -1, -1)),
    ])
    test_transform = transforms.Compose([
        transforms.Pad((2, 2)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
        transforms.Lambda(lambda x: x.expand(3, -1, -1)),
    ])
    trainset3 = torchvision.datasets.FashionMNIST(root='./data',
                                                  train=True,
                                                  download=True,
                                                  transform=transform)
    testset3 = torchvision.datasets.FashionMNIST(root='./data',
                                                 train=False,
                                                 download=True,
                                                 transform=test_transform)
    train_dl = tud.DataLoader(tud.ConcatDataset(
        [trainset1, trainset2, trainset3]),
                              batch_size=SETTINGS.training.batch_size,
                              shuffle=True)
    test_dl = tud.DataLoader(tud.ConcatDataset([testset1, testset2, testset3]),
                             batch_size=SETTINGS.training.batch_size,
                             shuffle=False)

    device = torch.device(SETTINGS.training.device)
    model = find_model(args.model)().to(device)
    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = AdamW(params,
                      SETTINGS.training.learning_rate,
                      weight_decay=SETTINGS.training.weight_decay)
    logging.info(f'{sum(p.numel() for p in params)} parameters')
    criterion = nn.CrossEntropyLoss()

    ws.write_args(args)
    ws.write_setting(SETTINGS)
    writer.add_scalar('Meta/Parameters', sum(p.numel() for p in params))
    for epoch_idx in trange(SETTINGS.training.num_epochs,
                            position=0,
                            leave=True):
        model.train()
        pbar = tqdm(train_dl,
                    total=len(train_dl),
                    position=1,
                    desc='Training',
                    leave=True)
        for inputs, labels in pbar:
            optimizer.zero_grad()
            model.zero_grad()
            labels = torch.tensor([x % 10 for x in labels.tolist()])
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = model(inputs, None)
            loss = criterion(scores, labels)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(dict(loss=f'{loss.item():.3}'))
            writer.add_scalar('Training/Loss', loss.item(), epoch_idx)

        for group in optimizer.param_groups:
            group['lr'] *= 0.9
        evaluate()
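
The manual decay at the end of each epoch (`group['lr'] *= 0.9`) matches torch's built-in exponential schedule; a sketch of the equivalent wiring, assuming the same `optimizer`:

from torch.optim.lr_scheduler import ExponentialLR

scheduler = ExponentialLR(optimizer, gamma=0.9)
for epoch_idx in range(SETTINGS.training.num_epochs):
    ...  # inner training loop as above
    scheduler.step()  # multiplies every param group's lr by 0.9
    evaluate()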
Example No. 10
            pred_id = preds.argmax(2)
            acc = (pred_id == target).float().mean().cpu()
            running_acc = running_acc * 0.95 + acc * 0.05 if running_acc is not None else acc
            # After ten consecutive steps with running accuracy above 10%,
            # lengthen the target summary once per streak.
            if running_acc > .1:
                increase_sum_len += 1
                if increase_sum_len == 10:
                    adaptive_summary_len += 1
            else:
                increase_sum_len = 0

            # Calculate gradients
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           Config.max_grad_norm)

            optimizer.step()

            # Warmup scheduler
            scheduler.step()

            # Update progress bar
            progress_bar.update(loss=loss.item())
            progress_bar.progress()

            # Show attention plot
            if progress_bar.count % 200 == 0:
                evaluate_and_show_attention(model,
                                            test_text,
                                            tokenizer,
                                            iteration=epoch +
                                            progress_bar.count / epoch_steps)
Example No. 11
    def train(self, train_source, train_target, dev_source, dev_target):
        if os.path.exists(self.args.output_dir) is True:
            shutil.rmtree(self.args.output_dir)

        train_dataloader = create_batch_iter(mode='train', X=train_source, y=train_target, batch_size=self.args.BATCH)
        dev_dataloader = create_batch_iter(mode='dev', X=dev_source, y=dev_target, batch_size=self.args.BATCH)

        self.model.to(DEVICE)

        # Prepare the optimizer: no weight decay on bias and LayerNorm parameters.
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer
                        if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}]

        optimizer = AdamW(params=optimizer_grouped_parameters, lr=self.args.learning_rate)

        total_size = math.ceil(len(train_source) / self.args.BATCH)

        best_acc = 0
        for epoch in range(self.args.EPOCHS):
            for train_step, train_batch in enumerate(tqdm(train_dataloader, desc='Train_Iteration')):
                self.model.train()
                self.model.zero_grad()

                train_batch = tuple(t.to(DEVICE) for t in train_batch)
                t_input_ids, t_input_mask, t_labels, t_out_masks = train_batch

                t_bert_encode = self.model(t_input_ids, t_input_mask)
                loss = self.model.loss_fn(bert_encode=t_bert_encode, tags=t_labels, output_mask=t_out_masks)
                loss.backward()

                # Gradient clipping (currently disabled)
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                optimizer.step()

                if train_step % 10 == 0:
                    self.model.eval()
                    eval_loss = 0

                    for dev_step, dev_batch in enumerate(dev_dataloader):
                        dev_batch = tuple(t.to(DEVICE) for t in dev_batch)
                        d_input_ids, d_input_mask, d_label_ids, d_output_mask = dev_batch

                        with torch.no_grad():
                            d_bert_encode = self.model(d_input_ids, d_input_mask)
                        eval_loss += self.model.loss_fn(bert_encode=d_bert_encode, tags=d_label_ids,
                                                        output_mask=d_output_mask)
                        predicts = self.model.predict(d_bert_encode, d_output_mask)

                        d_label_ids = d_label_ids.view(1, -1)
                        d_label_ids = d_label_ids[d_label_ids != -1]

                        eval_acc, eval_f1 = self.model.acc_f1(predicts, d_label_ids)

                        if eval_acc > best_acc:
                            best_acc = eval_acc
                            save_model(self.model, self.args.output_dir)

                        self.model.class_report(predicts, d_label_ids)

                    logger.info("\n>step {}".format(train_step))
                    logger.info("\n>epoch [{}] {}/{}\n\tloss {:.2f}".format(epoch, train_step, total_size, loss.item()))
        # Save a final model if no best checkpoint was ever written.
        if not os.path.exists(self.args.output_dir):
            save_model(self.model, self.args.output_dir)
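
The bias/LayerNorm split above is the usual AdamW grouping: decoupled weight decay is applied to weight matrices but not to biases or LayerNorm parameters. A reusable sketch of the same pattern:

def group_decay_params(model, weight_decay=0.01,
                       no_decay=('bias', 'LayerNorm.bias', 'LayerNorm.weight')):
    # Split parameters into a decayed group and an undecayed group.
    decay, skip = [], []
    for name, param in model.named_parameters():
        (skip if any(nd in name for nd in no_decay) else decay).append(param)
    return [{'params': decay, 'weight_decay': weight_decay},
            {'params': skip, 'weight_decay': 0.0}]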
Example No. 12
class Trainer:
    def __init__(
        self,
        config: TrainConfig,
        model: NSMCModel,
        train_data_loader: DataLoader,
        dev_data_loader: DataLoader,
        test_data_loader: DataLoader,
        logger: Logger,
        summary_writer: SummaryWriter,
    ):
        self.config = config
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.model = model
        self.model.to(self.device)

        self.train_data_loader = train_data_loader
        self.dev_data_loader = dev_data_loader
        self.test_data_loader = test_data_loader
        self.logger = logger
        self.summary_writer = summary_writer

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = AdamW(model.parameters(), lr=config.learning_rate)

        # compute the total number of training steps
        self.steps_per_epoch = len(train_data_loader)
        self.total_steps = self.steps_per_epoch * config.num_epochs
        self.warmup_steps = config.warmup_step_ratio * self.total_steps

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.total_steps)
        self.global_step = 0

    def train(self):
        # train
        self.logger.info("========== train ==========")
        self.logger.info(f"device                : {self.device}")
        self.logger.info(
            f"dataset length/ train : {len(self.train_data_loader.dataset)}")
        self.logger.info(
            f"dataset length/ dev   : {len(self.dev_data_loader.dataset)}")
        self.logger.info(
            f"dataset length/ test  : {len(self.test_data_loader.dataset)}")
        self.logger.info(f"batch size            : {self.config.batch_size}")
        self.logger.info(
            f"learning rate         : {self.config.learning_rate}")
        self.logger.info(f"dropout prob          : {self.config.dropout_prob}")
        self.logger.info(f"total epoch           : {self.config.num_epochs}")
        self.logger.info(f"steps per epoch       : {self.steps_per_epoch}")
        self.logger.info(f"total steps           : {self.total_steps}")
        self.logger.info(f"warmup steps          : {self.warmup_steps}\n")

        for epoch in range(self.config.num_epochs):
            running_loss = 0.0
            train_targets = []
            train_predictions = []

            for step, data in enumerate(tqdm(self.train_data_loader)):
                self.model.train()

                self.global_step += 1

                input_token_ids = data[0].to(self.device)
                attention_mask = data[1].to(self.device)
                token_type_ids = data[2].to(self.device)
                labels = data[3].to(self.device)

                loss, outputs = self._train_step(input_token_ids,
                                                 attention_mask,
                                                 token_type_ids, labels)

                running_loss += loss
                train_targets.extend(labels.tolist())
                train_predictions.extend(outputs.argmax(-1).tolist())

                if (step + 1) % self.config.logging_interval == 0:
                    train_loss = running_loss / self.config.logging_interval
                    train_acc = accuracy_score(train_targets,
                                               train_predictions)
                    self.logger.info(
                        f"Epoch {epoch}, Step {step + 1}\t| Loss {train_loss:.4f}  Acc {train_acc:.4f}"
                    )

                    self.summary_writer.add_scalar("nsmc/train/loss",
                                                   train_loss,
                                                   self.global_step)
                    self.summary_writer.add_scalar("nsmc/train/accuracy",
                                                   train_acc, self.global_step)

                    running_loss = 0.0
                    train_targets = []
                    train_predictions = []

            # dev every epoch
            dev_loss, dev_targets, dev_predictions = self._validation(
                self.dev_data_loader)
            dev_report = classification_report(dev_targets,
                                               dev_predictions,
                                               digits=4)
            self.logger.info(f"######### DEV REPORT #EP{epoch} #########")
            self.logger.info(f"Loss {dev_loss:.4f}")
            self.logger.info(f"\n{dev_report}")

            dev_acc = accuracy_score(dev_targets, dev_predictions)
            self.summary_writer.add_scalar("nsmc/dev/loss", dev_loss,
                                           self.global_step)
            self.summary_writer.add_scalar("nsmc/dev/accuracy", dev_acc,
                                           self.global_step)

            # test every epoch
            test_loss, test_targets, test_predictions = self._validation(
                self.test_data_loader)
            test_report = classification_report(test_targets,
                                                test_predictions,
                                                digits=4)
            self.logger.info(f"######### TEST REPORT #EP{epoch} #########")
            self.logger.info(f"Loss {test_loss:.4f}")
            self.logger.info(f"\n{test_report}")

            test_acc = accuracy_score(test_targets, test_predictions)
            self.summary_writer.add_scalar("nsmc/test/loss", test_loss,
                                           self.global_step)
            self.summary_writer.add_scalar("nsmc/test/accuracy", test_acc,
                                           self.global_step)

            # output_path = os.path.join(self.config.checkpoint_dir, f"model-epoch-{epoch}.pth")
            # torch.save(self.model.state_dict(), output_path)
            # self.logger.info(f"MODEL IS SAVED AT {output_path}\n")

    def _train_step(self, input_token_ids, attention_mask, token_type_ids,
                    labels):
        self.optimizer.zero_grad()

        outputs = self.model(input_token_ids, attention_mask, token_type_ids)

        loss = self.criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

        self.optimizer.step()
        self.scheduler.step()

        return loss.item(), outputs

    def _validation(self, data_loader):
        self.model.eval()

        running_loss = 0.0
        targets = []
        predictions = []

        with torch.no_grad():
            for data in data_loader:
                input_token_ids = data[0].to(self.device)
                attention_mask = data[1].to(self.device)
                token_type_ids = data[2].to(self.device)
                labels = data[3].to(self.device)

                outputs = self.model(input_token_ids, attention_mask,
                                     token_type_ids)

                loss = self.criterion(outputs, labels)

                running_loss += loss.item()
                targets.extend(labels.tolist())
                predictions.extend(outputs.argmax(-1).tolist())

        assert len(targets) == len(predictions)

        mean_loss = running_loss / len(data_loader)

        return mean_loss, targets, predictions
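
A note on the warmup schedule above: `get_linear_schedule_with_warmup` ramps the learning rate linearly from zero over `num_warmup_steps` and then decays it linearly back to zero at `num_training_steps`. A minimal sketch of the multiplier it applies to the base learning rate:

def linear_warmup_lambda(step, warmup_steps, total_steps):
    # Linear ramp up, then linear decay to zero.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))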
Example No. 13
class TransformingAutoEncoderController(nn.Module, ControllableModel):
    '''
    Transforming Auto-Encoder.

    References:
        - Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
        - Floridi, L., & Chiriatti, M. (2020). GPT-3: Its nature, scope, limits, and consequences. Minds and Machines, 30(4), 681-694.
        - Miller, A., Fisch, A., Dodge, J., Karimi, A. H., Bordes, A., & Weston, J. (2016). Key-value memory networks for directly reading documents. arXiv preprint arXiv:1606.03126.
        - Radford, A., Narasimhan, K., Salimans, T., & Sutskever, I. (2018) Improving Language Understanding by Generative Pre-Training. OpenAI (URL: https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
        - Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9.
        - Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., & Polosukhin, I. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.

    '''
    __loaded_filename = None
    __loaded_ctx = None

    # `bool` that means initialization in this class will be deferred or not.
    __init_deferred_flag = False

    def __init__(
        self,
        computable_loss=None,
        optimizer_f=None,
        encoder=None,
        decoder=None,
        reconstructor=None,
        layer_n=3,
        head_n=3,
        seq_len=5,
        depth_dim=100,
        hidden_dim=100,
        self_attention_activation_list=[],
        multi_head_attention_activation_list=[],
        fc_activation_list=[],
        learning_rate=1e-05,
        weight_decay=0.01,
        dropout_rate=0.5,
        ctx="cpu",
        regularizatable_data_list=[],
        not_init_flag=False,
    ):
        '''
        Init.

        Args:
            computable_loss:                is-a `ComputableLoss` or a torch loss module.
            optimizer_f:                    callable that builds an optimizer from `self.parameters()`. If `None`, `AdamW` is used.
            encoder:                        is-a `TransformerModel`.
            decoder:                        is-a `TransformerModel`.
            reconstructor:                  is-a `TransformerModel`.
            layer_n:                        `int` of the number of layers.
            head_n:                         `int` of the number of heads for multi-head attention model.
            seq_len:                        `int` of the length of sequences.
            depth_dim:                      `int` of dimension of dense layer.
            hidden_dim:                     `int` of dimension of hidden (encoder) layer.
            self_attention_activation_list: `list` of `str` of activation function for self-attention model.
            multi_head_attention_activation_list:   `list` of `str` of activation function for multi-head attention model.
            fc_activation_list:             `list` of `str` of activation function in fully-connected layers.
            learning_rate:                  `float` of learning rate.
            weight_decay:                   `float` of weight decay for `AdamW`.
            dropout_rate:                   `float` of dropout rate.
            ctx:                            device context, e.g. `"cpu"` or `"cuda:0"`.
            regularizatable_data_list:      `list` of `RegularizatableData`.
            not_init_flag:                  `bool` that means the optimizer will not be initialized, or not.

        '''
        super(TransformingAutoEncoderController, self).__init__()

        if computable_loss is None:
            computable_loss = nn.CrossEntropyLoss()
        self.__computable_loss = computable_loss

        if encoder is None:
            if hidden_dim is None or hidden_dim == depth_dim:
                encoder = TransformerEncoder(
                    depth_dim=depth_dim,
                    layer_n=layer_n,
                    head_n=head_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    dropout_rate=dropout_rate,
                    ctx=ctx)
            else:
                encoder = TransformerEncoder(
                    depth_dim=hidden_dim,
                    layer_n=layer_n,
                    head_n=head_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    dropout_rate=dropout_rate,
                    ctx=ctx)
            encoder.embedding_flag = False
        else:
            if isinstance(encoder, TransformerModel) is False:
                raise TypeError(
                    "The type of `encoder` must be `TransformerModel`.")

        if decoder is None:
            if hidden_dim is None or hidden_dim == depth_dim:
                decoder = TransformerDecoder(
                    head_n=head_n,
                    depth_dim=depth_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    multi_head_attention_activation_list=
                    multi_head_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx)
            else:
                decoder = TransformerDecoder(
                    head_n=head_n,
                    depth_dim=hidden_dim,
                    output_dim=hidden_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    multi_head_attention_activation_list=
                    multi_head_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx)
            decoder.embedding_flag = False
        else:
            if isinstance(decoder, TransformerModel) is False:
                raise TypeError(
                    "The type of `decoder` must be `TransformerModel`.")

        if reconstructor is None:
            if hidden_dim is None or hidden_dim == depth_dim:
                reconstructor = TransformerReconstructor(
                    head_n=head_n,
                    depth_dim=depth_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx,
                )
            else:
                reconstructor = TransformerReconstructor(
                    head_n=head_n,
                    depth_dim=hidden_dim,
                    output_dim=depth_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=
                    self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx,
                )
            reconstructor.embedding_flag = False
        else:
            if isinstance(reconstructor, TransformerModel) is False:
                raise TypeError(
                    "The type of `reconstructor` must be `TransformerModel`.")

        logger = getLogger("accelbrainbase")
        self.logger = logger

        self.encoder = encoder
        self.decoder = decoder
        self.reconstructor = reconstructor

        if hidden_dim is not None and hidden_dim != depth_dim:
            self.encoder_hidden_fc = nn.Linear(
                depth_dim,
                hidden_dim,
                bias=True,
            )
            self.decoder_hidden_fc = nn.Linear(
                depth_dim,
                hidden_dim,
                bias=True,
            )
            init_flag = True
        else:
            self.encoder_hidden_fc = None
            self.decoder_hidden_fc = None
            init_flag = False

        self.__ctx = ctx
        self.to(self.__ctx)

        if self.init_deferred_flag is False:
            if not_init_flag is False:
                if optimizer_f is not None:
                    if init_flag is True:
                        self.optimizer = optimizer_f(self.parameters())
                else:
                    if init_flag is True:
                        self.optimizer = AdamW(self.parameters(),
                                               lr=learning_rate,
                                               weight_decay=weight_decay)

        for v in regularizatable_data_list:
            if isinstance(v, RegularizatableData) is False:
                raise TypeError(
                    "The type of values of `regularizatable_data_list` must be `RegularizatableData`."
                )
        self.__regularizatable_data_list = regularizatable_data_list

        self.__learning_rate = learning_rate
        self.__weight_decay = weight_decay
        self.seq_len = seq_len
        self.epoch = 0

    def learn(self, iteratable_data):
        '''
        Learn samples drawn by `IteratableData.generate_learned_samples()`.

        Args:
            iteratable_data:     is-a `TransformerIterator`.
        '''
        if isinstance(iteratable_data, TransformerIterator) is False:
            raise TypeError(
                "The type of `iteratable_data` must be `TransformerIterator`.")

        self.__loss_list = []
        learning_rate = self.__learning_rate

        try:
            epoch = self.epoch
            iter_n = 0
            for encoded_observed_arr, decoded_observed_arr, encoded_mask_arr, decoded_mask_arr, test_encoded_observed_arr, test_decoded_observed_arr, test_encoded_mask_arr, test_decoded_mask_arr, training_target_arr, test_target_arr in iteratable_data.generate_learned_samples(
            ):
                self.epoch = epoch
                # If the sub-optimizers already exist, zero all gradients
                # before the forward pass.
                if self.encoder.optimizer is not None and self.decoder.optimizer is not None:
                    optimizer_setup_flag = True
                    self.encoder.optimizer.zero_grad()
                    self.decoder.optimizer.zero_grad()
                    self.optimizer.zero_grad()
                else:
                    optimizer_setup_flag = False

                pred_arr = self.inference(encoded_observed_arr,
                                          decoded_observed_arr,
                                          encoded_mask_arr, decoded_mask_arr)
                loss = self.compute_loss(pred_arr, training_target_arr)
                if optimizer_setup_flag is False:
                    # The first forward pass set up the deferred optimizers;
                    # zero their gradients and recompute the forward pass.
                    self.encoder.optimizer.zero_grad()
                    self.decoder.optimizer.zero_grad()
                    self.optimizer.zero_grad()
                    pred_arr = self.inference(encoded_observed_arr,
                                              decoded_observed_arr,
                                              encoded_mask_arr,
                                              decoded_mask_arr)
                    loss = self.compute_loss(pred_arr, training_target_arr)

                loss.backward()
                self.optimizer.step()
                self.decoder.optimizer.step()
                self.encoder.optimizer.step()
                self.regularize()
                self.decoder.regularize()
                self.encoder.regularize()

                if (iter_n + 1) % int(
                        iteratable_data.iter_n / iteratable_data.epochs) == 0:
                    with torch.inference_mode():
                        test_pred_arr = self.inference(
                            test_encoded_observed_arr,
                            test_decoded_observed_arr, test_encoded_mask_arr,
                            test_decoded_mask_arr)

                        test_loss = self.compute_loss(test_pred_arr,
                                                      test_target_arr)

                    _loss = loss.to('cpu').detach().numpy().copy()
                    _test_loss = test_loss.to('cpu').detach().numpy().copy()

                    self.__loss_list.append((_loss, _test_loss))
                    self.logger.debug("Epochs: " + str(epoch + 1) +
                                      " Train loss: " + str(_loss) +
                                      " Test loss: " + str(_test_loss))
                    epoch += 1
                iter_n += 1

        except KeyboardInterrupt:
            self.logger.debug("Interrupt.")

        self.logger.debug("end. ")
        self.epoch = epoch

    def inference(
        self,
        encoded_observed_arr,
        decoded_observed_arr,
        encoded_mask_arr=None,
        decoded_mask_arr=None,
    ):
        '''
        Inference samples drawn by `IteratableData.generate_inferenced_samples()`.

        Args:
            encoded_observed_arr:   rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            decoded_observed_arr:   rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            encoded_mask_arr:       rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            decoded_mask_arr:       rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

        Returns:
            `torch.Tensor` of inferenced feature points.
        '''
        return self(
            encoded_observed_arr,
            decoded_observed_arr,
            encoded_mask_arr,
            decoded_mask_arr,
        )

    def compute_loss(self, pred_arr, labeled_arr):
        '''
        Compute loss.

        Args:
            pred_arr:       `torch.Tensor`.
            labeled_arr:    `torch.Tensor`.

        Returns:
            loss.
        '''
        return self.__computable_loss(pred_arr, labeled_arr)

    def forward(
        self,
        encoded_observed_arr,
        decoded_observed_arr,
        encoded_mask_arr=None,
        decoded_mask_arr=None,
    ):
        '''
        Forward.

        Args:
            encoded_observed_arr:   rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            decoded_observed_arr:   rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            encoded_mask_arr:       rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

            decoded_mask_arr:       rank-3 Array like or sparse matrix as the observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)

        Returns:
            `torch.Tensor` of inferenced feature points.
        '''
        if self.__loaded_filename is not None:
            loaded_filename = self.__loaded_filename
            self.__loaded_filename = None
            init_encoded_observed_arr = encoded_observed_arr.detach()
            init_decoded_observed_arr = decoded_observed_arr.detach()
            if encoded_mask_arr is not None:
                init_encoded_mask_arr = encoded_mask_arr.detach()
            else:
                init_encoded_mask_arr = None
            if decoded_mask_arr is not None:
                init_decoded_mask_arr = decoded_mask_arr.detach()
            else:
                init_decoded_mask_arr = decoded_mask_arr

            _ = self.forward(
                init_encoded_observed_arr,
                init_decoded_observed_arr,
                init_encoded_mask_arr,
                init_decoded_mask_arr,
            )
            self.load_parameters(loaded_filename, ctx=self.__loaded_ctx)
            self.__loaded_ctx = None

        if encoded_mask_arr is None:
            encoded_mask_arr = torch.ones(
                (encoded_observed_arr.shape[0], 1, 1, 1), )
            encoded_mask_arr = encoded_mask_arr.to(encoded_observed_arr.device)
        if decoded_mask_arr is None:
            decoded_mask_arr = torch.ones(
                (decoded_observed_arr.shape[0], 1, 1, 1), )
            decoded_mask_arr = decoded_mask_arr.to(decoded_observed_arr.device)

        if self.encoder_hidden_fc is not None:
            encoded_observed_arr = self.encoder_hidden_fc(encoded_observed_arr)
        if self.decoder_hidden_fc is not None:
            decoded_observed_arr = self.decoder_hidden_fc(decoded_observed_arr)

        encoded_arr = self.encoder(encoded_observed_arr, encoded_mask_arr)
        decoded_arr = self.decoder(
            decoded_observed_arr,
            encoded_arr,
            decoded_mask_arr,
            encoded_mask_arr,
        )

        self.feature_points_arr = decoded_arr

        reconstructed_arr = self.reconstructor(decoded_arr, None)

        return reconstructed_arr

    def extract_learned_dict(self):
        '''
        Extract (pre-) learned parameters.

        Returns:
            `dict` of the parameters.
        '''
        params_dict = {}
        for k in self.state_dict().keys():
            params_dict.setdefault(k, self.state_dict()[k])

        return params_dict

    def regularize(self):
        '''
        Regularization.
        '''
        if len(self.__regularizatable_data_list) > 0:
            params_dict = self.extract_learned_dict()
            for regularizatable in self.__regularizatable_data_list:
                params_dict = regularizatable.regularize(params_dict)

            for k, params in params_dict.items():
                self.load_state_dict({k: params}, strict=False)

    def __rename_file(self, filename):
        filename_list = filename.split(".")
        _format = filename_list[-1]
        encoder_filename = filename.replace("." + _format,
                                            "_encoder." + _format)
        decoder_filename = filename.replace("." + _format,
                                            "_decoder." + _format)
        reconstructor_filename = filename.replace("." + _format,
                                                  "_reconstructor." + _format)
        return encoder_filename, decoder_filename, reconstructor_filename

    def save_parameters(self, filename):
        '''
        Save parameters to files.

        Args:
            filename:       File name.
        '''
        encoder_filename, decoder_filename, reconstructor_filename = self.__rename_file(
            filename)

        self.encoder.epoch = self.epoch
        self.encoder.loss_arr = self.loss_arr
        self.encoder.save_parameters(encoder_filename)

        self.decoder.epoch = self.epoch
        self.decoder.loss_arr = self.loss_arr
        self.decoder.save_parameters(decoder_filename)

        self.reconstructor.epoch = self.epoch
        self.reconstructor.loss_arr = self.loss_arr
        self.reconstructor.save_parameters(reconstructor_filename)

        torch.save(
            {
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'epoch': self.epoch,
                'loss': self.loss_arr,
            }, filename)

    def load_parameters(self, filename, ctx=None, strict=True):
        '''
        Load parameters to files.

        Args:
            filename:       File name.
            ctx:            Context-manager that changes the selected device.
            strict:         Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: `True`.
        '''
        try:
            encoder_filename, decoder_filename, reconstructor_filename = self.__rename_file(
                filename)
            self.encoder.load_parameters(encoder_filename,
                                         ctx=ctx,
                                         strict=strict)
            self.decoder.load_parameters(decoder_filename,
                                         ctx=ctx,
                                         strict=strict)
            self.reconstructor.load_parameters(reconstructor_filename,
                                               ctx=ctx,
                                               strict=strict)

            checkpoint = torch.load(filename)
            self.load_state_dict(checkpoint['model_state_dict'], strict=strict)
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epoch = checkpoint['epoch']
            self.__loss_list = checkpoint['loss'].tolist()
        except RuntimeError:
            self.__loaded_filename = filename
            self.__loaded_ctx = ctx

        if ctx is not None:
            self.to(ctx)
            self.encoder.to(ctx)
            self.decoder.to(ctx)
            self.reconstructor.to(ctx)
            self.__ctx = ctx

    def set_readonly(self, value):
        ''' setter '''
        raise TypeError("This property must be read-only.")

    def get_init_deferred_flag(self):
        ''' getter for `bool` that means initialization in this class will be deferred or not.'''
        return self.__init_deferred_flag

    def set_init_deferred_flag(self, value):
        ''' setter for `bool` that means initialization in this class will be deferred or not.'''
        self.__init_deferred_flag = value

    init_deferred_flag = property(get_init_deferred_flag,
                                  set_init_deferred_flag)

    __loss_list = []

    def get_loss_arr(self):
        ''' getter '''
        return np.array(self.__loss_list)

    def set_loss_arr(self, value):
        ''' setter '''
        raise TypeError("This property must be read-only.")

    loss_arr = property(get_loss_arr, set_loss_arr)
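
A minimal construction sketch for the controller above, under the documented defaults (the dimensions here are illustrative assumptions, and the `TransformerIterator` wiring for `learn()` is omitted):

controller = TransformingAutoEncoderController(
    depth_dim=100,   # feature dimension of the observed sequences
    hidden_dim=64,   # bottleneck; adds the hidden FC projections
    seq_len=5,
    layer_n=3,
    head_n=3,
    learning_rate=1e-05,
    ctx="cuda:0" if torch.cuda.is_available() else "cpu",
)
# controller.learn(iteratable_data) then expects a TransformerIterator.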
Example No. 14
class Distiller:
    def __init__(self, params: dict, dataset: LmSeqsDataset,
                 token_probs: torch.tensor, student: nn.Module,
                 teacher: nn.Module):
        logger.info("Initializing Distiller")
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

        if params.group_by_size:
            groups = create_lengths_groups(lengths=dataset.lengths,
                                           k=params.max_model_input_size)
            sampler = GroupedBatchSampler(sampler=sampler,
                                          group_ids=groups,
                                          batch_size=params.batch_size)
        else:
            sampler = BatchSampler(sampler=sampler,
                                   batch_size=params.batch_size,
                                   drop_last=False)

        self.dataloader = DataLoader(dataset=dataset,
                                     batch_sampler=sampler,
                                     collate_fn=dataset.batch_sequences)

        self.temperature = params.temperature
        assert self.temperature > 0.0

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_clm = params.alpha_clm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos

        self.mlm = params.mlm
        if self.mlm:
            logger.info("Using MLM loss for LM step.")
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            assert params.word_mask + params.word_keep + params.word_rand == 1.0
            self.pred_probs = torch.FloatTensor(
                [params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(
                f"cuda:{params.local_rank}"
            ) if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(
                f"cuda:{params.local_rank}"
            ) if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info("Using CLM loss for LM step.")

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.0:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.0:
            self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        if self.alpha_mse > 0.0:
            self.mse_loss_fct = nn.MSELoss(reduction="sum")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

        logger.info("--- Initializing model optimizer")
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = (
            int(self.num_steps_epoch / params.gradient_accumulation_steps *
                params.n_epoch) + 1)

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in student.named_parameters()
                    if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay":
                params.weight_decay,
            },
            {
                "params": [
                    p for n, p in student.named_parameters()
                    if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay":
                0.0,
            },
        ]
        logger.info(
            "------ Number of trainable parameters (student): %i" % sum([
                p.numel() for p in self.student.parameters() if p.requires_grad
            ]))
        logger.info("------ Number of parameters (student): %i" %
                    sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps *
                                 params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            logger.info(
                f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student,
                self.optimizer,
                opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel

                logger.info(
                    "Using apex.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel

                logger.info(
                    "Using nn.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    find_unused_parameters=True,
                )

        self.is_master = params.is_master
        if self.is_master:
            logger.info("--- Initializing Tensorboard")
            self.tensorboard = SummaryWriter(
                log_dir=os.path.join(self.dump_path, "log", "train"))
            self.tensorboard.add_text(tag="config/training",
                                      text_string=str(self.params),
                                      global_step=0)
            self.tensorboard.add_text(tag="config/student",
                                      text_string=str(self.student_config),
                                      global_step=0)

    def prepare_batch_mlm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1),
                                 dtype=torch.long,
                                 device=lengths.device) < lengths[:, None]

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(),
                                    n_tgt,
                                    replacement=False)
        pred_mask = torch.zeros(
            bs * max_seq_len, dtype=torch.bool, device=token_ids.device
        )  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0

        # mask a number of words that is a multiple of 8 (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1 - n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(
            self.params.special_tok_ids["mask_token"])
        probs = torch.multinomial(self.pred_probs,
                                  len(_token_ids_real),
                                  replacement=True)
        _token_ids = (_token_ids_mask * (probs == 0).long() + _token_ids_real *
                      (probs == 1).long() + _token_ids_rand *
                      (probs == 2).long())
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[
            ~pred_mask] = -100  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, mlm_labels
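
    # A toy sketch of the same BERT-style corruption policy applied above: for
    # each selected position, replace with [MASK], keep, or randomize, drawn
    # with probabilities (word_mask, word_keep, word_rand). Hypothetical helper,
    # not part of the original class; `mask_id` and `vocab_size` stand in for
    # the values used in prepare_batch_mlm.
    @staticmethod
    def _corrupt_masked_tokens_sketch(token_ids, pred_mask, mask_id, vocab_size,
                                      probs=(0.8, 0.1, 0.1)):
        real = token_ids[pred_mask]
        choice = torch.multinomial(torch.tensor(probs), len(real),
                                   replacement=True)
        corrupted = torch.where(choice == 0,
                                torch.full_like(real, mask_id), real)
        corrupted = torch.where(choice == 2,
                                torch.randint_like(real, vocab_size), corrupted)
        return token_ids.masked_scatter(pred_mask, corrupted)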

    def prepare_batch_clm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after batch rounding.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1),
                                 dtype=torch.long,
                                 device=lengths.device) < lengths[:, None]
        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
        clm_labels[
            ~attn_mask] = -100  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, clm_labels

    def round_batch(self, x: torch.tensor, lengths: torch.tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            x:  `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # make the number of sentences a multiple of 8
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # make the sequence length a multiple of 8
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.mlm:
                pad_id = self.params.special_tok_ids["pad_token"]
            else:
                pad_id = self.params.special_tok_ids["unk_token"]
            padding_tensor = torch.zeros(bs2,
                                         pad,
                                         dtype=torch.long,
                                         device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths
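
    # Worked example: with fp16 and a batch of 27 sentences whose longest kept
    # sequence is 45 tokens, round_batch keeps 24 sentences (8 * (27 // 8)) and
    # pads the sequence length to 48 (45 + 3), so both dimensions end up as
    # multiples of 8.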

    def train(self):
        """
        The real training loop.
        """
        if self.is_master:
            logger.info("Starting training")
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(
                    f"--- Starting epoch {self.epoch}/{self.params.n_epoch - 1}"
                )
            if self.multi_gpu:
                torch.distributed.barrier()

            iter_bar = tqdm(self.dataloader,
                            desc="-Iter",
                            disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                if self.params.n_gpu > 0:
                    batch = tuple(
                        t.to(f"cuda:{self.params.local_rank}") for t in batch)

                if self.mlm:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(
                        batch=batch)
                else:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_clm(
                        batch=batch)
                self.step(input_ids=token_ids,
                          attention_mask=attn_mask,
                          lm_labels=lm_labels)

                iter_bar.update()
                iter_bar.set_postfix({
                    "Last_loss":
                    f"{self.last_loss:.2f}",
                    "Avg_cum_loss":
                    f"{self.total_loss_epoch / self.n_iter:.2f}"
                })
            iter_bar.close()

            if self.is_master:
                logger.info(
                    f"--- Ending epoch {self.epoch}/{self.params.n_epoch - 1}")
            self.end_epoch()

        if self.is_master:
            logger.info("Save very last checkpoint as `pytorch_model.bin`.")
            self.save_checkpoint(checkpoint_name="pytorch_model.bin")
            logger.info("Training is finished")

    def step(self, input_ids: torch.tensor, attention_mask: torch.tensor,
             lm_labels: torch.tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
        """
        if self.mlm:
            s_logits, s_hidden_states = self.student(
                input_ids=input_ids,
                attention_mask=attention_mask)  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, t_hidden_states = self.teacher(
                    input_ids=input_ids, attention_mask=attention_mask
                )  # (bs, seq_length, voc_size)
        else:
            s_logits, _, s_hidden_states = self.student(
                input_ids=input_ids,
                attention_mask=None)  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, _, t_hidden_states = self.teacher(
                    input_ids=input_ids,
                    attention_mask=None)  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            mask = (lm_labels > -1).unsqueeze(-1).expand_as(
                s_logits)  # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(
                s_logits)  # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(
            s_logits,
            mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(
            -1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(
            t_logits,
            mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(
            -1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        # For an explanation of the (self.temperature) ** 2 factor, see https://zhuanlan.zhihu.com/p/102038521
        loss_ce = (self.ce_loss_fct(
            F.log_softmax(s_logits_slct / self.temperature, dim=-1),
            F.softmax(t_logits_slct / self.temperature, dim=-1),
        ) * (self.temperature)**2)
        loss = self.alpha_ce * loss_ce

        if self.alpha_mlm > 0.0:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)),
                                        lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_clm > 0.0:
            shift_logits = s_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_clm = self.lm_loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1))
            loss += self.alpha_clm * loss_clm

        if self.alpha_mse > 0.0:
            loss_mse = self.mse_loss_fct(
                s_logits_slct, t_logits_slct) / s_logits_slct.size(
                    0)  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse
        if self.alpha_cos > 0.0:
            s_hidden_states = s_hidden_states[-1]  # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]  # (bs, seq_length, dim)
            mask = attention_mask.unsqueeze(-1).expand_as(
                s_hidden_states)  # (bs, seq_length, dim)
            assert s_hidden_states.size() == t_hidden_states.size()
            dim = s_hidden_states.size(-1)

            s_hidden_states_slct = torch.masked_select(
                s_hidden_states, mask)  # (bs * seq_length * dim)
            s_hidden_states_slct = s_hidden_states_slct.view(
                -1, dim)  # (bs * seq_length, dim)
            t_hidden_states_slct = torch.masked_select(
                t_hidden_states, mask)  # (bs * seq_length * dim)
            t_hidden_states_slct = t_hidden_states_slct.view(
                -1, dim)  # (bs * seq_length, dim)

            target = s_hidden_states_slct.new(
                s_hidden_states_slct.size(0)).fill_(1)  # (bs * seq_length,)
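            # A target of all ones tells the cosine loss (presumably
            # nn.CosineEmbeddingLoss) to pull each student hidden state toward
            # cosine similarity 1 with its teacher counterpart.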
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct,
                                            t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.0:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.0:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.0:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.0:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.
        """
        # Check for NaN
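        # (NaN is the only value that compares unequal to itself, so `loss != loss`
        # flags exactly the NaN entries; torch.isnan(loss) is the modern spelling.)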
        if (loss != loss).data.any():
            logger.error("NaN detected")
            exit()

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
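        # Gradients from successive micro-batches accumulate in .grad; clip and
        # apply them (and advance the LR schedule) only once every
        # `gradient_accumulation_steps` iterations.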
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer),
                    self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(),
                                               self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag="parameter_mean/" + param_name,
                                        scalar_value=param.data.mean(),
                                        global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="parameter_std/" + param_name,
                                        scalar_value=param.data.std(),
                                        global_step=self.n_total_iter)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name,
                                        scalar_value=param.grad.data.mean(),
                                        global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name,
                                        scalar_value=param.grad.data.std(),
                                        global_step=self.n_total_iter)

        self.tensorboard.add_scalar(
            tag="losses/cum_avg_loss_epoch",
            scalar_value=self.total_loss_epoch / self.n_iter,
            global_step=self.n_total_iter,
        )
        self.tensorboard.add_scalar(tag="losses/loss",
                                    scalar_value=self.last_loss,
                                    global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss_ce",
                                    scalar_value=self.last_loss_ce,
                                    global_step=self.n_total_iter)
        if self.alpha_mlm > 0.0:
            self.tensorboard.add_scalar(tag="losses/loss_mlm",
                                        scalar_value=self.last_loss_mlm,
                                        global_step=self.n_total_iter)
        if self.alpha_clm > 0.0:
            self.tensorboard.add_scalar(tag="losses/loss_clm",
                                        scalar_value=self.last_loss_clm,
                                        global_step=self.n_total_iter)
        if self.alpha_mse > 0.0:
            self.tensorboard.add_scalar(tag="losses/loss_mse",
                                        scalar_value=self.last_loss_mse,
                                        global_step=self.n_total_iter)
        if self.alpha_cos > 0.0:
            self.tensorboard.add_scalar(tag="losses/loss_cos",
                                        scalar_value=self.last_loss_cos,
                                        global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="learning_rate/lr",
                                    scalar_value=self.scheduler.get_last_lr()[0],
                                    global_step=self.n_total_iter)

        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
            global_step=self.n_total_iter,
        )
        self.tensorboard.add_scalar(tag="global/speed",
                                    scalar_value=time.time() - self.last_log,
                                    global_step=self.n_total_iter)

    def end_epoch(self):
        """
        Called at the end of an epoch (one full pass over the dataset).
        Do some tensorboard logging and checkpoint saving.
        """
        logger.info(
            f"{self.n_sequences_epoch} sequences have been trained during this epoch."
        )

        if self.is_master:
            self.save_checkpoint(
                checkpoint_name=f"model_epoch_{self.epoch}.pth")
            self.tensorboard.add_scalar(tag="epoch/loss",
                                        scalar_value=self.total_loss_epoch /
                                        self.n_iter,
                                        global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
        """
        Save the current state. Only by the master process.
        """
        if not self.is_master:
            return
        mdl_to_save = self.student.module if hasattr(
            self.student, "module") else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
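
As a standalone illustration of the core objective in Distiller.step above, here is a minimal sketch of the temperature-scaled distillation loss. It assumes ce_loss_fct is nn.KLDivLoss(reduction='batchmean'), the usual choice for this recipe (its definition is not shown in this excerpt); shapes and names are otherwise illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F

def distillation_ce(s_logits: torch.Tensor, t_logits: torch.Tensor,
                    temperature: float = 2.0) -> torch.Tensor:
    """KL divergence between temperature-softened student and teacher
    distributions, scaled by T**2 as in Distiller.step."""
    ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
    return ce_loss_fct(
        F.log_softmax(s_logits / temperature, dim=-1),
        F.softmax(t_logits / temperature, dim=-1),
    ) * temperature ** 2

# Toy usage: logits already masked and reshaped to (bs * seq_length, voc_size).
s = torch.randn(8, 100)
t = torch.randn(8, 100)
print(distillation_ce(s, t, temperature=2.0))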
Ejemplo n.º 15
0
def main():
    def evaluate_engine(dataset: WakeWordDataset,
                        prefix: str,
                        save: bool = False,
                        positive_set: bool = False,
                        write_errors: bool = True,
                        mixer: DatasetMixer = None):
        std_transform.eval()

        engine = InferenceEngine(model,
                                 zmuv_transform,
                                 negative_label=num_labels - 1)
        model.eval()
        conf_matrix = ConfusionMatrix()
        pbar = tqdm(dataset, desc=prefix)
        if write_errors:
            with (ws.path / 'errors.tsv').open('a') as f:
                print(prefix, file=f)
        for idx, ex in enumerate(pbar):
            if mixer is not None:
                ex, = mixer([ex])
            audio_data = ex.audio_data.to(device)
            engine.reset()
            seq_present = False
            curr_time = 0
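            # Stream the utterance through the engine window by window; the wake
            # word counts as detected if it is present at any point in the clip.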
            for window in stride(
                    audio_data,
                    SETTINGS.training.max_window_size_seconds * 1000,
                    SETTINGS.training.eval_stride_size_seconds * 1000,
                    SETTINGS.audio.sample_rate):
                if window.size(-1) < 1000:
                    break
                pred = engine.infer(window.squeeze(0), curr_time=curr_time)
                engine.append_label(pred, curr_time=curr_time)
                seq_present = seq_present or engine.sequence_present(
                    curr_time=curr_time)
                curr_time += SETTINGS.training.eval_stride_size_seconds * 1000
            if seq_present != positive_set and write_errors:
                with (ws.path / 'errors.tsv').open('a') as f:
                    f.write(
                        f'{ex.metadata.transcription.transcription}\t{int(seq_present)}\t{int(positive_set)}\t{ex.metadata.path}\n'
                    )
            conf_matrix.increment(seq_present, positive_set)
            pbar.set_postfix(dict(mcc=f'{conf_matrix.mcc}',
                                  c=f'{conf_matrix}'))

        logging.info(f'{conf_matrix}')
        if save and not args.eval:
            writer.add_scalar(f'{prefix}/Metric/tp', conf_matrix.tp, epoch_idx)
            ws.increment_model(model, conf_matrix.tp)

    def do_evaluate():
        evaluate_engine(ww_dev_pos_ds, 'Dev positive', positive_set=True)
        evaluate_engine(ww_dev_pos_ds,
                        'Dev noisy positive',
                        positive_set=True,
                        mixer=dev_mixer)
        evaluate_engine(ww_dev_neg_ds, 'Dev negative', positive_set=False)
        evaluate_engine(ww_dev_neg_ds,
                        'Dev noisy negative',
                        positive_set=False,
                        mixer=dev_mixer)
        evaluate_engine(ww_test_pos_ds, 'Test positive', positive_set=True)
        evaluate_engine(ww_test_pos_ds,
                        'Test noisy positive',
                        positive_set=True,
                        mixer=test_mixer)
        evaluate_engine(ww_test_neg_ds, 'Test negative', positive_set=False)
        evaluate_engine(ww_test_neg_ds,
                        'Test noisy negative',
                        positive_set=False,
                        mixer=test_mixer)

    apb = ArgumentParserBuilder()
    apb.add_options(
        opt('--model', type=str, choices=model_names(), default='las'),
        opt('--workspace',
            type=str,
            default=str(Path('workspaces') / 'default')),
        opt('--load-weights', action='store_true'),
        opt('--load-last', action='store_true'),
        opt('--vocab', type=str, nargs='+', default=[' hey', 'fire fox']),
        opt('--eval', action='store_true'))
    args = apb.parser.parse_args()

    num_labels = len(args.vocab) + 1

    ws = Workspace(Path(args.workspace), delete_existing=not args.eval)
    writer = ws.summary_writer
    set_seed(SETTINGS.training.seed)
    ww = SETTINGS.training.wake_word
    logging.info(f'Using {ww}')
    loader = WakeWordDatasetLoader()
    ds_kwargs = dict(sr=SETTINGS.audio.sample_rate,
                     mono=SETTINGS.audio.use_mono,
                     words=args.vocab)
    ww_train_ds, ww_dev_ds, ww_test_ds = loader.load_splits(
        SETTINGS.dataset.dataset_path, **ds_kwargs)
    print_stats('Wake word dataset', ww_train_ds, ww_dev_ds, ww_test_ds)

    sr = SETTINGS.audio.sample_rate
    wind_sz = int(SETTINGS.training.eval_window_size_seconds * sr)
    stri_sz = int(SETTINGS.training.eval_stride_size_seconds * sr)

    inference_wakeword = InferenceEngineSettings().make_wakeword(args.vocab)
    ww_dev_all_pos_ds = ww_dev_ds.filter(
        lambda x: x.compute_frame_labels(args.vocab), clone=True)
    ww_dev_pos_ds = ww_dev_ds.filter(
        lambda x: x.compute_frame_labels([inference_wakeword]), clone=True)
    ww_dev_neg_ds = ww_dev_ds.filter(
        lambda x: not x.compute_frame_labels([inference_wakeword]), clone=True)
    ww_test_pos_ds = ww_test_ds.filter(
        lambda x: x.compute_frame_labels([inference_wakeword]), clone=True)
    ww_test_neg_ds = ww_test_ds.filter(
        lambda x: not x.compute_frame_labels([inference_wakeword]), clone=True)

    device = torch.device(SETTINGS.training.device)
    std_transform = StandardAudioTransform().to(device).eval()
    zmuv_transform = ZmuvTransform().to(device)
    batchifier = WakeWordBatchifier(
        num_labels - 1,
        window_size_ms=int(SETTINGS.training.max_window_size_seconds * 1000))
    train_comp = (NoiseTransform().train(), batchifier)

    if SETTINGS.training.use_noise_dataset:
        noise_ds = RecursiveNoiseDatasetLoader().load(
            SETTINGS.raw_dataset.noise_dataset_path,
            sr=SETTINGS.audio.sample_rate,
            mono=SETTINGS.audio.use_mono)
        logging.info(f'Loaded {len(noise_ds.metadata_list)} noise files.')
        noise_ds_train, noise_ds_dev = noise_ds.split(Sha256Splitter(80))
        noise_ds_dev, noise_ds_test = noise_ds_dev.split(Sha256Splitter(50))
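        # The noise files appear to be split ~80% train, with the remainder
        # divided 50/50 into dev and test (~10% each) by deterministic hashing.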
        train_comp = (DatasetMixer(noise_ds_train).train(), ) + train_comp
        dev_mixer = DatasetMixer(noise_ds_dev, seed=0, do_replace=False)
        test_mixer = DatasetMixer(noise_ds_test, seed=0, do_replace=False)
    train_comp = compose(*train_comp)

    prep_dl = StandardAudioDataLoaderBuilder(ww_train_ds,
                                             collate_fn=batchify).build(1)
    prep_dl.shuffle = True
    train_dl = StandardAudioDataLoaderBuilder(ww_train_ds,
                                              collate_fn=train_comp).build(
                                                  SETTINGS.training.batch_size)

    model = find_model(args.model)().to(device)
    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = AdamW(params,
                      SETTINGS.training.learning_rate,
                      weight_decay=SETTINGS.training.weight_decay)
    logging.info(f'{sum(p.numel() for p in params)} parameters')
    criterion = nn.CrossEntropyLoss()

    if (ws.path / 'zmuv.pt.bin').exists():
        zmuv_transform.load_state_dict(torch.load(str(ws.path /
                                                      'zmuv.pt.bin')))
    else:
        for idx, batch in enumerate(tqdm(prep_dl, desc='Constructing ZMUV')):
            batch.to(device)
            zmuv_transform.update(std_transform(batch.audio_data))
            if idx == 2000:  # TODO: quick debugging, remove later
                break
        logging.info(
            dict(zmuv_mean=zmuv_transform.mean, zmuv_std=zmuv_transform.std))
    torch.save(zmuv_transform.state_dict(), str(ws.path / 'zmuv.pt.bin'))

    if args.load_weights:
        ws.load_model(model, best=not args.load_last)
    if args.eval:
        ws.load_model(model, best=not args.load_last)
        do_evaluate()
        return

    ws.write_args(args)
    ws.write_setting(SETTINGS)
    writer.add_scalar('Meta/Parameters', sum(p.numel() for p in params))
    for epoch_idx in trange(SETTINGS.training.num_epochs,
                            position=0,
                            leave=True):
        model.train()
        std_transform.train()
        pbar = tqdm(train_dl,
                    total=len(train_dl),
                    position=1,
                    desc='Training',
                    leave=True)
        for batch in pbar:
            batch.to(device)
            scores = model(zmuv_transform(std_transform(batch.audio_data)),
                           std_transform.compute_lengths(batch.lengths))
            optimizer.zero_grad()
            model.zero_grad()
            loss = criterion(scores, batch.labels)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(dict(loss=f'{loss.item():.3}'))
            writer.add_scalar('Training/Loss', loss.item(), epoch_idx)

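        # Manual exponential LR decay: shrink every parameter group's learning
        # rate by a constant factor once per epoch.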
        for group in optimizer.param_groups:
            group['lr'] *= SETTINGS.training.lr_decay
        evaluate_engine(ww_dev_pos_ds,
                        'Dev positive',
                        positive_set=True,
                        save=True,
                        write_errors=False)
    do_evaluate()
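
The evaluation loop above relies on a stride helper that is not shown in this excerpt. Below is a minimal sketch of what such a helper could look like, assuming window and stride sizes are given in milliseconds (matching the call sites); the real implementation may differ.

import torch

def stride(audio: torch.Tensor, window_ms: float, stride_ms: float,
           sample_rate: int):
    """Yield fixed-size windows over the last dimension of `audio`."""
    window = int(window_ms / 1000 * sample_rate)
    step = int(stride_ms / 1000 * sample_rate)
    for start in range(0, audio.size(-1), step):
        yield audio[..., start:start + window]

# Toy usage: 2 seconds of 16 kHz audio, 500 ms windows, 250 ms stride.
audio = torch.randn(1, 32000)
for chunk in stride(audio, 500, 250, 16000):
    pass  # run inference on each chunk, as evaluate_engine does above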
Ejemplo n.º 16
0
class CienaTransformerModel(Model):
    def __init__(
            self, save_path, log_path, d_features, d_meta, max_length, d_classifier, n_classes, threshold=None,
            optimizer=None, **kwargs):
        '''
            Arguments:
                save_path -- model file path
                log_path -- log file path
                d_features -- how many PMs
                d_meta -- how many facility types
                max_length -- input sequence length
                d_classifier -- classifier hidden unit
                n_classes -- output dim
                threshold -- if not None, n_classes should be 1 (regression).
                **kwargs -- n_layers, n_head, dropout, use_bottleneck, d_bottleneck
        '''

        super().__init__(save_path, log_path)
        self.d_output = n_classes
        self.threshold = threshold
        self.max_length = max_length

        # ----------------------------- Model ------------------------------ #

        self.model = Encoder(TimeFacilityEncoding, d_features=d_features, max_seq_length=max_length,
                                       d_meta=d_meta, **kwargs)

        # --------------------------- Classifier --------------------------- #
        self.classifier = LinearClassifier(d_features * max_length, d_classifier, n_classes)

        # ------------------------------ CUDA ------------------------------ #
        # If GPU available, move the graph to GPU(s)
        self.CUDA_AVAILABLE = self.check_cuda()
        if self.CUDA_AVAILABLE:
            device_ids = list(range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids)
            self.classifier = nn.DataParallel(self.classifier, device_ids)
            self.model.to('cuda')
            self.classifier.to('cuda')
            assert (next(self.model.parameters()).is_cuda)
            assert (next(self.classifier.parameters()).is_cuda)

        else:
            print('CUDA not found or not enabled, use CPU instead')

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(self.classifier.parameters())
        if optimizer is None:
            self.optimizer = AdamW(self.parameters, lr=0.001, betas=(0.9, 0.999), weight_decay=0.001)
        else:
            self.optimizer = optimizer

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000, evaluate_every_nstep=100, print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # --------------------- logging and tensorboard -------------------- #
        self.count_parameters()
        self.set_logger()
        self.set_summary_writer()
        # ---------------------------- END INIT ---------------------------- #

    def checkpoint(self, step):
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'classifier_state_dict': self.classifier.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'global_step': step}
        return checkpoint

    def train_epoch(self, train_dataloader, eval_dataloader, device, smoothing, earlystop):
        ''' Epoch operation in training phase'''

        if device == 'cuda':
            assert self.CUDA_AVAILABLE
        # Set model and classifier training mode
        self.model.train()
        self.classifier.train()

        total_loss = 0
        batch_counter = 0

        # update param per batch
        for batch in tqdm(
                train_dataloader, mininterval=1,
                desc='  - (Training)   ', leave=False):  # training data should be an iterable

            # get data from dataloader

            input_feature_sequence, padding, time_facility, y = map(lambda x: x.to(device), batch)

            batch_size = len(y)

            non_pad_mask, slf_attn_mask = get_attn_mask(padding)

            # forward
            self.optimizer.zero_grad()
            logits, attn = self.model(input_feature_sequence, time_facility, non_pad_mask, slf_attn_mask)
            logits = logits.view(batch_size, -1)
            logits = self.classifier(logits)

            # Decide whether this is a regression (single output dim) or classification problem
            if self.d_output == 1:
                pred = logits.sigmoid()
                loss = mse_loss(pred, y)

            else:
                pred = logits
                loss = cross_entropy_loss(pred, y, smoothing=smoothing)

            # calculate gradients
            loss.backward()

            # update parameters
            self.optimizer.step()

            # get metrics for logging
            acc = accuracy(pred, y, threshold=self.threshold)
            precision, recall, precision_avg, recall_avg = precision_recall(pred, y, self.d_output,
                                                                            threshold=self.threshold)
            total_loss += loss.item()
            batch_counter += 1

            # training control
            state_dict = self.controller(batch_counter)

            if state_dict['step_to_print']:
                self.train_logger.info(
                    '[TRAINING]   - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f' % (
                        state_dict['step'], loss, acc, precision[1], recall[1]))
                self.summary_writer.add_scalar('loss/train', loss, state_dict['step'])
                self.summary_writer.add_scalar('acc/train', acc, state_dict['step'])
                self.summary_writer.add_scalar('precision/train', precision[1], state_dict['step'])
                self.summary_writer.add_scalar('recall/train', recall[1], state_dict['step'])

            if state_dict['step_to_evaluate']:
                stop = self.val_epoch(eval_dataloader, device, state_dict['step'])
                state_dict['step_to_stop'] = stop

                if earlystop and stop:
                    break

            if self.controller.current_step == self.controller.max_step:
                state_dict['step_to_stop'] = True
                break

        return state_dict

    def val_epoch(self, dataloader, device, step=0, plot=False):
        ''' Epoch operation in evaluation phase '''
        if device == 'cuda':
            assert self.CUDA_AVAILABLE

        # Set model and classifier training mode
        self.model.eval()
        self.classifier.eval()

        # use evaluator to calculate the average performance
        evaluator = Evaluator()

        pred_list = []
        real_list = []

        with torch.no_grad():

            for batch in tqdm(
                    dataloader, mininterval=5,
                    desc='  - (Evaluation)   ', leave=False):  # evaluation data should be an iterable

                input_feature_sequence, padding, time_facility, y = map(lambda x: x.to(device), batch)

                batch_size = len(y)

                non_pad_mask, slf_attn_mask = get_attn_mask(padding)

                # get logits
                logits, attn = self.model(input_feature_sequence, time_facility, non_pad_mask, slf_attn_mask)
                logits = logits.view(batch_size, -1)
                logits = self.classifier(logits)

                if self.d_output == 1:
                    pred = logits.sigmoid()
                    loss = mse_loss(pred, y)

                else:
                    pred = logits
                    loss = cross_entropy_loss(pred, y, smoothing=False)

                acc = accuracy(pred, y, threshold=self.threshold)
                precision, recall, _, _ = precision_recall(pred, y, self.d_output, threshold=self.threshold)

                # feed the metrics in the evaluator
                evaluator(loss.item(), acc.item(), precision[1].item(), recall[1].item())

                # If plotting, append the results to the pred/real lists for drawing a ROC or PR curve:
            #     if plot:
            #         pred_list += pred.tolist()
            #         real_list += y.tolist()
            #
            # if plot:
            #     area, precisions, recalls, thresholds = pr(pred_list, real_list)
            #     plot_pr_curve(recalls, precisions, auc=area)

            # get evaluation results from the evaluator
            loss_avg, acc_avg, pre_avg, rec_avg = evaluator.avg_results()

            self.eval_logger.info(
                '[EVALUATION] - step: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f' % (
                    step, loss_avg, acc_avg, pre_avg, rec_avg))
            self.summary_writer.add_scalar('loss/eval', loss_avg, step)
            self.summary_writer.add_scalar('acc/eval', acc_avg, step)
            self.summary_writer.add_scalar('precision/eval', pre_avg, step)
            self.summary_writer.add_scalar('recall/eval', rec_avg, step)

            state_dict = self.early_stopping(loss_avg)

            if state_dict['save']:
                checkpoint = self.checkpoint(step)
                self.save_model(checkpoint, self.save_path + '-step-%d_loss-%.5f' % (step, loss_avg))

            return state_dict['break']

    def train(self, max_epoch, train_dataloader, eval_dataloader, device,
              smoothing=False, earlystop=False, save_mode='best'):
        assert save_mode in ['all', 'best']

        # train for max_epoch epochs
        for epoch_i in range(max_epoch):
            print('[ Epoch', epoch_i, ']')
            # set current epoch
            self.controller.set_epoch(epoch_i + 1)
            # train for one epoch
            state_dict = self.train_epoch(train_dataloader, eval_dataloader, device, smoothing, earlystop)

            # if state_dict['step_to_stop']:
            #     break

        checkpoint = self.checkpoint(state_dict['step'])

        self.save_model(checkpoint, self.save_path + '-step-%d' % state_dict['step'])

        self.train_logger.info(
            '[INFO]: Finished training after %d epoch(s) and %d batches, %d training steps in total.' % (
                state_dict['epoch'] - 1, state_dict['batch'], state_dict['step']))

    def get_predictions(self, data_loader, device, max_batches=None, activation=None):

        pred_list = []
        real_list = []

        self.model.eval()
        self.classifier.eval()

        batch_counter = 0

        with torch.no_grad():
            for batch in tqdm(
                    data_loader,
                    desc='  - (Testing)   ', leave=False):

                input_feature_sequence, padding, time_facility, y = map(lambda x: x.to(device), batch)

                non_pad_mask, slf_attn_mask = get_attn_mask(padding)

                # get logits
                logits, attn = self.model(input_feature_sequence, time_facility, non_pad_mask, slf_attn_mask)
                logits = logits.view(logits.shape[0], -1)
                logits = self.classifier(logits)

                # Whether to apply activation function
                if activation is not None:
                    pred = activation(logits)
                else:
                    pred = logits.softmax(dim=-1)
                pred_list += pred.tolist()
                real_list += y.tolist()

                if max_batches is not None:
                    batch_counter += 1
                    if batch_counter >= max_batches:
                        break

        return pred_list, real_list
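
train_epoch and val_epoch above call an external cross_entropy_loss helper with a smoothing flag whose definition is not shown. A common way to implement such label smoothing is sketched below; the eps value and exact formulation are assumptions, not necessarily what the original helper does.

import torch
import torch.nn.functional as F

def cross_entropy_loss(logits: torch.Tensor, target: torch.Tensor,
                       smoothing: bool = False, eps: float = 0.1) -> torch.Tensor:
    """Cross entropy with optional uniform label smoothing (a common variant)."""
    if not smoothing:
        return F.cross_entropy(logits, target)
    n_classes = logits.size(-1)
    log_probs = F.log_softmax(logits, dim=-1)
    # Mix the one-hot target with a uniform distribution: (1 - eps) on the
    # true class and eps spread evenly over all classes.
    one_hot = F.one_hot(target, n_classes).float()
    smooth_target = one_hot * (1 - eps) + eps / n_classes
    return -(smooth_target * log_probs).sum(dim=-1).mean()

# Toy usage.
logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 0])
print(cross_entropy_loss(logits, target, smoothing=True))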