Example #1
def main():
    args = parser.parse_args()
    save_folder = args.save_folder

    loss_results, cer_results, wer_results = torch.Tensor(
        args.epochs), torch.Tensor(args.epochs), torch.Tensor(args.epochs)
    if args.visdom:
        from visdom import Visdom
        viz = Visdom()

        opts = [
            dict(title='Loss', ylabel='Loss', xlabel='Epoch'),
            dict(title='WER', ylabel='WER', xlabel='Epoch'),
            dict(title='CER', ylabel='CER', xlabel='Epoch')
        ]

        viz_windows = [None, None, None]
        epochs = torch.arange(1, args.epochs + 1)
    if args.tensorboard:
        from logger import TensorBoardLogger
        try:
            os.makedirs(args.log_dir)
        except OSError as e:
            if e.errno == errno.EEXIST:
                print('Directory already exists.')
                for file in os.listdir(args.log_dir):
                    file_path = os.path.join(args.log_dir, file)
                    try:
                        if os.path.isfile(file_path):
                            os.unlink(file_path)
                    except Exception as e:
                        raise
            else:
                raise
        logger = TensorBoardLogger(args.log_dir)

    try:
        os.makedirs(save_folder)
    except OSError as e:
        if e.errno == errno.EEXIST:
            print('Directory already exists.')
        else:
            raise
    criterion = CTCLoss()

    with open(args.labels_path) as label_file:
        labels = str(''.join(json.load(label_file)))
    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride,
                      window=args.window,
                      noise_dir=args.noise_dir,
                      noise_prob=args.noise_prob,
                      noise_levels=(args.noise_min, args.noise_max))

    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=args.train_manifest,
                                       labels=labels,
                                       normalize=True,
                                       augment=args.augment)
    test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                      manifest_filepath=args.val_manifest,
                                      labels=labels,
                                      normalize=True,
                                      augment=False)
    train_loader = AudioDataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers)
    test_loader = AudioDataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    rnn_type = args.rnn_type.lower()
    assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
    model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                       nb_layers=args.hidden_layers,
                       labels=labels,
                       rnn_type=supported_rnns[rnn_type],
                       audio_conf=audio_conf,
                       bidirectional=True)
    parameters = model.parameters()
    optimizer = torch.optim.SGD(parameters,
                                lr=args.lr,
                                momentum=args.momentum,
                                nesterov=True)
    decoder = GreedyDecoder(labels)

    if args.continue_from:
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from)
        model.load_state_dict(package['state_dict'])
        optimizer.load_state_dict(package['optim_dict'])
        start_epoch = int(package.get(
            'epoch', 1)) - 1  # Python index start at 0 for training
        start_iter = package.get('iteration', None)
        if start_iter is None:
            start_epoch += 1  # Assume that we saved a model after an epoch finished, so start at the next epoch.
            start_iter = 0
        else:
            start_iter += 1
        avg_loss = int(package.get('avg_loss', 0))
        loss_results, cer_results, wer_results = package[
            'loss_results'], package['cer_results'], package['wer_results']
        if args.visdom and \
                        package['loss_results'] is not None and start_epoch > 0:  # Add previous scores to visdom graph
            x_axis = epochs[0:start_epoch]
            y_axis = [
                loss_results[0:start_epoch], wer_results[0:start_epoch],
                cer_results[0:start_epoch]
            ]
            for x in range(len(viz_windows)):
                viz_windows[x] = viz.line(
                    X=x_axis,
                    Y=y_axis[x],
                    opts=opts[x],
                )
        if args.tensorboard and \
                        package['loss_results'] is not None and start_epoch > 0:  # Previous scores to tensorboard logs
            for i in range(start_epoch):
                info = {
                    'Avg Train Loss': loss_results[i],
                    'Avg WER': wer_results[i],
                    'Avg CER': cer_results[i]
                }
                for tag, val in info.items():
                    logger.scalar_summary(tag, val, i + 1)
        if not args.no_bucketing:
            print("Using bucketing sampler for the following epochs")
            train_dataset = SpectrogramDatasetWithLength(
                audio_conf=audio_conf,
                manifest_filepath=args.train_manifest,
                labels=labels,
                normalize=True,
                augment=args.augment)
            sampler = BucketingSampler(train_dataset)
            train_loader.sampler = sampler
    else:
        avg_loss = 0
        start_epoch = 0
        start_iter = 0
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    print(model)
    print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    for epoch in range(start_epoch, args.epochs):
        model.train()
        end = time.time()
        for i, (data) in enumerate(train_loader, start=start_iter):
            if i == len(train_loader):
                break
            inputs, targets, input_percentages, target_sizes = data
            # measure data loading time
            data_time.update(time.time() - end)
            inputs = Variable(inputs, requires_grad=False)
            target_sizes = Variable(target_sizes, requires_grad=False)
            targets = Variable(targets, requires_grad=False)

            if args.cuda:
                inputs = inputs.cuda()

            out = model(inputs)
            out = out.transpose(0, 1)  # TxNxH

            seq_length = out.size(0)
            sizes = Variable(input_percentages.mul_(int(seq_length)).int(),
                             requires_grad=False)

            loss = criterion(out, targets, sizes, target_sizes)
            loss = loss / inputs.size(0)  # average the loss by minibatch

            loss_sum = loss.data.sum()
            inf = float("inf")
            if loss_sum == inf or loss_sum == -inf:
                print("WARNING: received an inf loss, setting loss value to 0")
                loss_value = 0
            else:
                loss_value = loss.data[0]

            avg_loss += loss_value
            losses.update(loss_value, inputs.size(0))

            # compute gradient
            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm(model.parameters(), args.max_norm)
            # SGD step
            optimizer.step()

            if args.cuda:
                torch.cuda.synchronize()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if not args.silent:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          (epoch + 1), (i + 1),
                          len(train_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses))
            if args.checkpoint_per_batch > 0 and i > 0 and (
                    i + 1) % args.checkpoint_per_batch == 0:
                file_path = '%s/deepspeech_checkpoint_epoch_%d_iter_%d.pth.tar' % (
                    save_folder, epoch + 1, i + 1)
                print("Saving checkpoint model to %s" % file_path)
                torch.save(
                    DeepSpeech.serialize(model,
                                         optimizer=optimizer,
                                         epoch=epoch,
                                         iteration=i,
                                         loss_results=loss_results,
                                         wer_results=wer_results,
                                         cer_results=cer_results,
                                         avg_loss=avg_loss), file_path)
            del loss
            del out
        avg_loss /= len(train_loader)

        print('Training Summary Epoch: [{0}]\t'
              'Average Loss {loss:.3f}\t'.format(epoch + 1, loss=avg_loss))

        start_iter = 0  # Reset start iteration for next epoch
        total_cer, total_wer = 0, 0
        model.eval()
        for i, (data) in enumerate(test_loader):  # test
            inputs, targets, input_percentages, target_sizes = data

            inputs = Variable(inputs, volatile=True)

            # unflatten targets
            split_targets = []
            offset = 0
            for size in target_sizes:
                split_targets.append(targets[offset:offset + size])
                offset += size

            if args.cuda:
                inputs = inputs.cuda()

            out = model(inputs)
            out = out.transpose(0, 1)  # TxNxH
            seq_length = out.size(0)
            sizes = input_percentages.mul_(int(seq_length)).int()

            decoded_output = decoder.decode(out.data, sizes)
            target_strings = decoder.process_strings(
                decoder.convert_to_strings(split_targets))
            wer, cer = 0, 0
            for x in range(len(target_strings)):
                wer += decoder.wer(decoded_output[x],
                                   target_strings[x]) / float(
                                       len(target_strings[x].split()))
                cer += decoder.cer(decoded_output[x],
                                   target_strings[x]) / float(
                                       len(target_strings[x]))
            total_cer += cer
            total_wer += wer

            if args.cuda:
                torch.cuda.synchronize()
            del out
        wer = total_wer / len(test_loader.dataset)
        cer = total_cer / len(test_loader.dataset)
        wer *= 100
        cer *= 100
        loss_results[epoch] = avg_loss
        wer_results[epoch] = wer
        cer_results[epoch] = cer
        print('Validation Summary Epoch: [{0}]\t'
              'Average WER {wer:.3f}\t'
              'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer, cer=cer))

        if args.visdom:
            # epoch += 1
            x_axis = epochs[0:epoch + 1]
            y_axis = [
                loss_results[0:epoch + 1], wer_results[0:epoch + 1],
                cer_results[0:epoch + 1]
            ]
            for x in range(len(viz_windows)):
                if viz_windows[x] is None:
                    viz_windows[x] = viz.line(
                        X=x_axis,
                        Y=y_axis[x],
                        opts=opts[x],
                    )
                else:
                    viz.line(
                        X=x_axis,
                        Y=y_axis[x],
                        win=viz_windows[x],
                        update='replace',
                    )
        if args.tensorboard:
            info = {'Avg Train Loss': avg_loss, 'Avg WER': wer, 'Avg CER': cer}
            for tag, val in info.items():
                logger.scalar_summary(tag, val, epoch + 1)
            if args.log_params:
                for tag, value in model.named_parameters():
                    tag = tag.replace('.', '/')
                    logger.histo_summary(tag, to_np(value), epoch + 1)
                    logger.histo_summary(tag + '/grad', to_np(value.grad),
                                         epoch + 1)
        if args.checkpoint:
            file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch + 1)
            torch.save(
                DeepSpeech.serialize(model,
                                     optimizer=optimizer,
                                     epoch=epoch,
                                     loss_results=loss_results,
                                     wer_results=wer_results,
                                     cer_results=cer_results), file_path)
        # anneal lr
        optim_state = optimizer.state_dict()
        optim_state['param_groups'][0][
            'lr'] = optim_state['param_groups'][0]['lr'] / args.learning_anneal
        optimizer.load_state_dict(optim_state)
        print('Learning rate annealed to: {lr:.6f}'.format(
            lr=optim_state['param_groups'][0]['lr']))

        avg_loss = 0
        if not args.no_bucketing and epoch == 0:
            print("Switching to bucketing sampler for following epochs")
            train_dataset = SpectrogramDatasetWithLength(
                audio_conf=audio_conf,
                manifest_filepath=args.train_manifest,
                labels=labels,
                normalize=True,
                augment=args.augment)
            sampler = BucketingSampler(train_dataset)
            train_loader.sampler = sampler

    torch.save(DeepSpeech.serialize(model, optimizer=optimizer),
               args.final_model_path)
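
Example #1 anneals the learning rate by exporting the optimizer state dict, editing it and loading it back. A minimal sketch of the same idea that edits param_groups directly; the tiny model, learning rate and annealing factor below are illustrative, not taken from the example:

import torch

# Sketch: divide every parameter group's learning rate by an annealing
# factor at the end of an epoch. Works with any torch.optim optimizer.
def anneal_learning_rate(optimizer, learning_anneal):
    for group in optimizer.param_groups:
        group['lr'] = group['lr'] / learning_anneal
    return optimizer.param_groups[0]['lr']

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4, momentum=0.9)
new_lr = anneal_learning_rate(optimizer, learning_anneal=1.1)
print('Learning rate annealed to: {lr:.6f}'.format(lr=new_lr))
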
Example #2

            torch.cuda.set_device(int(args.gpu_rank))
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
        main_proc = args.rank == 0  # Only the first proc should save models
    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)  # Ensure save folder exists

    loss_results, cer_results, wer_results = torch.Tensor(
        args.epochs), torch.Tensor(args.epochs), torch.Tensor(args.epochs)
    best_wer = None
    if main_proc and args.visdom:
        visdom_logger = VisdomLogger(args.id, args.epochs)
    if main_proc and args.tensorboard:
        tensorboard_logger = TensorBoardLogger(args.id, args.log_dir,
                                               args.log_params)

    avg_loss, start_epoch, start_iter, optim_state = 0, 0, 0, None
    if args.continue_from:  # Starting from previous model
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from,
                             map_location=lambda storage, loc: storage)
        model = DeepSpeech.load_model_package(package)
        labels = model.labels
        audio_conf = model.audio_conf
        if not args.finetune:  # Don't want to restart training
            optim_state = package['optim_dict']
            start_epoch = int(package.get(
                'epoch', 1)) - 1  # Index start at 0 for training
            start_iter = package.get('iteration', None)
            if start_iter is None:
Example #3
 def __init__(self,
              model,
              dataloader,
              loss_func,
              optimizer,
              scheduler,
              n_epochs,
              acc_func=None,
              train_name='default',
              resume_dict=None,
              ckp_dir=None,
              resume_ep='latest',
              logger=None,
              tb_log_dir=None,
              log_step_interval=100,
              comment=None,
              val_dataloader=None):
     '''
     If resume_dict is not None then load from resume_dict and ignore files stored in ckp_dir
     '''
     self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S%f')[:-3]
     self.model = model
     if torch.cuda.device_count() > 1:
         self.model = nn.DataParallel(self.model)
     self.dataloader = dataloader
     self.val_dataloader = val_dataloader
     self.loss_func = loss_func
     self.optimizer = optimizer
     self.scheduler = scheduler
     self.n_epochs = n_epochs
     self.acc_func = acc_func
     self.logger = logger
     self.ckp_dir = ckp_dir
     self.tb_log_dir = tb_log_dir
     self.log_step_interval = log_step_interval
     self.start_epoch = 0
     self.epoch_loss = None
     self.epoch_acc = None
     self.epoch_num = None
     self.step_loss = None
     self.step_acc = None
     self.step_num = None
     self.val_loss = None
     self.val_nn_acc = None
     self.steps_per_epoch = len(self.dataloader)
     if self.ckp_dir is None:
         self.ckp_dir = '/media/zli/Seagate Backup Plus Drive/trained_models/pytorch-py3/checkpoints/{}/{}'.format(
             train_name, self.timestamp)
     if not os.path.isdir(self.ckp_dir):
         os.makedirs(self.ckp_dir)
     # find model
     ckp_dir_file_names = os.listdir(self.ckp_dir)
     if model is None:
         if 'jit_model.pth' in ckp_dir_file_names:
             self.model = torch.jit.load(
                 os.path.join(self.ckp_dir, 'jit_model.pth'))
         else:
             raise ValueError('No model found')
     # find resume ep
     if resume_ep == 'latest':
         epoch_file_list = glob(os.path.join(self.ckp_dir, 'epoch_*.pth'))
         if len(epoch_file_list):
             prog = re.compile("epoch_([0-9]+).pth")
             epoch_number_list = [
                 int(prog.findall(i)[0]) for i in epoch_file_list
             ]
             resume_ep = max(epoch_number_list)
             print(f'Latest epoch {resume_ep} found.')
         else:
             resume_ep = None
     else:
         assert resume_ep.isnumeric(), ValueError(
             "resume_ep must be numeric or 'latest'.")
         assert os.path.isfile(os.path.join(self.ckp_dir, f"epoch_{resume_ep}.pth")), \
                FileNotFoundError(f"epoch {resume_ep} not found in {self.ckp_dir}")
         resume_ep = int(resume_ep)
     if resume_dict is not None:
         self.load_epoch(resume_dict)
     elif resume_ep is not None:
         self.load_epoch(
             os.path.join(self.ckp_dir, f"epoch_{resume_ep}.pth"))
     # Logger
     if self.logger is None:
         if self.tb_log_dir is None:
             self.tb_log_dir = 'tb_logs/{}'.format(train_name)
         self.tb_log_dir = os.path.join(self.tb_log_dir, self.timestamp)
         self.logger = TensorBoardLogger(self.tb_log_dir)
     self.logger.log_init(self)
     self.record_params = [
         'loss_func', 'step_loss', 'step_acc', 'val_loss', 'val_nn_acc'
     ]
     self.epoch_finish_hook = []
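
Example #3 resolves resume_ep='latest' by globbing epoch_*.pth files and pulling the epoch number out with a regex. A self-contained sketch of that lookup, assuming the same epoch_<N>.pth naming convention and a placeholder checkpoint directory:

import os
import re
from glob import glob

def find_latest_epoch(ckp_dir):
    # Return the highest epoch number saved as epoch_<N>.pth, or None.
    prog = re.compile(r"epoch_([0-9]+)\.pth")
    epochs = []
    for path in glob(os.path.join(ckp_dir, 'epoch_*.pth')):
        match = prog.search(os.path.basename(path))
        if match:
            epochs.append(int(match.group(1)))
    return max(epochs) if epochs else None

print('latest epoch:', find_latest_epoch('./checkpoints'))
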
Example #4
        device_id = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(device_id)
        print(f"Setting CUDA Device to {device_id}")

        dist.init_process_group(backend=args.dist_backend)
        main_proc = device_id == 0  # Main process handles saving of models and reporting

    checkpoint_handler = CheckpointHandler(save_folder=args.save_folder,
                                           best_val_model_name=args.best_val_model_name,
                                           checkpoint_per_iteration=args.checkpoint_per_iteration,
                                           save_n_recent_models=args.save_n_recent_models)

    if main_proc and args.visdom:
        visdom_logger = VisdomLogger(args.id, args.epochs)
    if main_proc and args.tensorboard:
        tensorboard_logger = TensorBoardLogger(args.id, args.log_dir, args.log_params)

    if args.load_auto_checkpoint:
        latest_checkpoint = checkpoint_handler.find_latest_checkpoint()
        if latest_checkpoint:
            args.continue_from = latest_checkpoint

    if args.continue_from:  # Starting from previous model
        state = TrainingState.load_state(state_path=args.continue_from)
        model = state.model
        if args.finetune:
            state.init_finetune_states(args.epochs)

        if main_proc and args.visdom:  # Add previous scores to visdom graph
            visdom_logger.load_previous_values(state.epoch, state.results)
        if main_proc and args.tensorboard:  # Previous scores to tensorboard logs
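
Example #4 reads LOCAL_RANK from the environment to pick the GPU and join the process group. A minimal sketch of that setup, assuming an NCCL backend and that the script is launched by torchrun, which provides LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT:

import os
import torch
import torch.distributed as dist

# Bind this process to its local GPU and join the default process group.
device_id = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(device_id)
dist.init_process_group(backend="nccl")

main_proc = dist.get_rank() == 0  # only the main process saves and reports
if main_proc:
    print(f"initialized {dist.get_world_size()} processes")
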
Example #5
    random.seed(args.seed)

    device = torch.device("cuda" if args.cuda else "cpu")
    main_proc = True
    device = torch.device("cuda" if args.cuda else "cpu")
    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)  # Ensure save folder exists

    loss_results, cer_results, wer_results = torch.Tensor(
        args.epochs), torch.Tensor(args.epochs), torch.Tensor(args.epochs)
    best_wer = None
    if main_proc and args.visdom:
        visdom_logger = VisdomLinePlotter(env_name='m_trainer')
        visdom_logger.clear()
    if main_proc and args.tensorboard:
        tensorboard_logger = TensorBoardLogger(args.id, args.log_dir,
                                               args.log_params)

    avg_loss, start_epoch, start_iter, optim_state = 0, 0, 0, None

    if args.continue_from:  # Starting from previous model
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from,
                             map_location=lambda storage, loc: storage)

        if not args.finetune:  # Don't want to restart training
            optim_state = package['optim_dict']
            start_epoch = int(package.get(
                'epoch', 1)) - 1  # Index start at 0 for training
            start_iter = package.get('iteration', None)
            if start_iter is None:
                start_epoch += 1  # We saved model after epoch finished, start at the next epoch.
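
Example #5 loads the checkpoint with a map_location lambda so tensors saved on a GPU deserialize onto the CPU, then rebuilds the epoch counter. A short round-trip sketch of that pattern; the file name and tiny model are placeholders, and the key names mirror the package layout used in these snippets:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
torch.save({'epoch': 3,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict()}, 'checkpoint.pth.tar')

# map_location keeps everything on the CPU, wherever it was saved from.
package = torch.load('checkpoint.pth.tar',
                     map_location=lambda storage, loc: storage)
start_epoch = int(package.get('epoch', 1)) - 1  # training epochs are 0-indexed
model.load_state_dict(package['state_dict'])
optimizer.load_state_dict(package['optim_dict'])
print('resuming from epoch', start_epoch)
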
Example #6
class Trainer:
    def __init__(self,
                 model,
                 dataloader,
                 loss_func,
                 optimizer,
                 scheduler,
                 n_epochs,
                 acc_func=None,
                 train_name='default',
                 resume_dict=None,
                 ckp_dir=None,
                 resume_ep='latest',
                 logger=None,
                 tb_log_dir=None,
                 log_step_interval=100,
                 comment=None,
                 val_dataloader=None):
        '''
        If resume_dict is not None then load from resume_dict and ignore files stored in ckp_dir
        '''
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S%f')[:-3]
        self.model = model
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
        self.dataloader = dataloader
        self.val_dataloader = val_dataloader
        self.loss_func = loss_func
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.n_epochs = n_epochs
        self.acc_func = acc_func
        self.logger = logger
        self.ckp_dir = ckp_dir
        self.tb_log_dir = tb_log_dir
        self.log_step_interval = log_step_interval
        self.start_epoch = 0
        self.epoch_loss = None
        self.epoch_acc = None
        self.epoch_num = None
        self.step_loss = None
        self.step_acc = None
        self.step_num = None
        self.val_loss = None
        self.val_nn_acc = None
        self.steps_per_epoch = len(self.dataloader)
        if self.ckp_dir is None:
            self.ckp_dir = '/media/zli/Seagate Backup Plus Drive/trained_models/pytorch-py3/checkpoints/{}/{}'.format(
                train_name, self.timestamp)
        if not os.path.isdir(self.ckp_dir):
            os.makedirs(self.ckp_dir)
        # find model
        ckp_dir_file_names = os.listdir(self.ckp_dir)
        if model is None:
            if 'jit_model.pth' in ckp_dir_file_names:
                self.model = torch.jit.load(
                    os.path.join(self.ckp_dir, 'jit_model.pth'))
            else:
                raise ValueError('No model found')
        # find resume ep
        if resume_ep == 'latest':
            epoch_file_list = glob(os.path.join(self.ckp_dir, 'epoch_*.pth'))
            if len(epoch_file_list):
                prog = re.compile("epoch_([0-9]+).pth")
                epoch_number_list = [
                    int(prog.findall(i)[0]) for i in epoch_file_list
                ]
                resume_ep = max(epoch_number_list)
                print(f'Latest epoch {resume_ep} found.')
            else:
                resume_ep = None
        else:
            assert resume_ep.isnumeric(), ValueError(
                "resume_ep must be numeric or 'latest'.")
            assert os.path.isfile(os.path.join(self.ckp_dir, f"epoch_{resume_ep}.pth")), \
                   FileNotFoundError(f"epoch {resume_ep} not found in {self.ckp_dir}")
            resume_ep = int(resume_ep)
        if resume_dict is not None:
            self.load_epoch(resume_dict)
        elif resume_ep is not None:
            self.load_epoch(
                os.path.join(self.ckp_dir, f"epoch_{resume_ep}.pth"))
        # Logger
        if self.logger is None:
            if self.tb_log_dir is None:
                self.tb_log_dir = 'tb_logs/{}'.format(train_name)
            self.tb_log_dir = os.path.join(self.tb_log_dir, self.timestamp)
            self.logger = TensorBoardLogger(self.tb_log_dir)
        self.logger.log_init(self)
        self.record_params = [
            'loss_func', 'step_loss', 'step_acc', 'val_loss', 'val_nn_acc'
        ]
        self.epoch_finish_hook = []

    def train(self):
        for epoch in range(self.start_epoch, self.n_epochs):
            self.epoch_num = epoch
            current_lr = [
                group['lr'] for group in self.optimizer.param_groups
            ][0]
            print('---- start epoch: {}/{}\tlearning rate:{:.2E} ----'.format(
                self.epoch_num, self.n_epochs, current_lr))
            self.epoch_loss, self.epoch_acc = self.train_epoch()
            self.val_loss, self.val_nn_acc = self.validation_epoch()
            for fn in self.epoch_finish_hook:
                fn(self)
            self.scheduler.step()
            print(
                'end epoch: {}/{}\ttrain loss: {:.2f}\tvalidation loss: {}\tvalidation nn accuracy: {}\n'
                .format(self.epoch_num, self.n_epochs, self.epoch_loss,
                        self.val_loss, self.val_nn_acc))
            self.save_epoch()

    def validation_epoch(self):
        val_step_loss = 0
        vector_list = []
        label_list = []

        self.model.eval()
        tbar = tqdm(enumerate(self.val_dataloader),
                    total=len(self.val_dataloader))
        for batch_idx, (data, target) in tbar:
            self.step_num = batch_idx
            self.optimizer.zero_grad()
            data = tuple(data)
            if type(data) is tuple:
                data = tuple(d.cuda() for d in data)
                model_output = [self.model(d) for d in data]
            elif type(data) is torch.Tensor:
                model_output = self.model(data.cuda())
            else:
                raise TypeError(f'Unknown type of input data: {type(data)}')
            loss = self.loss_func(model_output, target)
            val_step_loss += loss.item()

            for (o, t) in zip(model_output, target):
                o = torch.nn.functional.normalize(o, p=2, dim=1)
                c = o.cpu().squeeze(0).detach().numpy()
                vector_list.append(c)
                label_list.append(t)
            #print(val_step_loss)
        val_loss = val_step_loss / len(self.val_dataloader)
        labels = {}
        for l in label_list:
            for key, val in l.items():
                if key not in labels:
                    labels[key] = []
                labels[key].extend(val)
        #vectors_norm = [torch.nn.functional.normalize(d, p=2, dim=1) for d in vector_list]
        vectors = vector_list

        # compute nearest neighbor accuracy
        total_nums = len(vectors)
        total_correct = 0

        for i in range(total_nums):
            base_vec = vectors[i]
            base_label = labels['plot'][i]
            base_sensor = labels['sensor'][i]
            dist_rec = np.zeros(total_nums)
            for j in range(total_nums):
                pair_sensor = labels['sensor'][j]
                loop_vec = vectors[j]
                if base_sensor == pair_sensor:
                    dist_rec[j] = float('inf')
                else:
                    #dist_rec[j] = np.linalg.norm(loop_vec-base_vec)
                    # dot product
                    dist_rec[j] = (1 - np.matmul(loop_vec, base_vec))

            min_index = dist_rec.argsort()[:1]

            check_flag = False
            for ind in range(len(min_index)):
                val = min_index[ind]
                pair_label = labels['plot'][val]
                if not check_flag:
                    if (base_label
                            == pair_label) and (base_sensor != pair_sensor):
                        total_correct += 1
                        check_flag = True

        val_nn_acc = float(total_correct) / total_nums
        print('total correct: {}\n'.format(total_correct))

        return val_loss, val_nn_acc

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        total_acc = 0
        tbar = tqdm(enumerate(self.dataloader), total=len(self.dataloader))
        for batch_idx, (data, target) in tbar:
            self.step_num = batch_idx
            self.optimizer.zero_grad()
            data = tuple(data)
            if type(data) is tuple:
                data = tuple(d.cuda() for d in data)
                model_output = [self.model(d) for d in data]
            elif type(data) is torch.Tensor:
                model_output = self.model(data.cuda())
            else:
                raise TypeError(f'Unknown type of input data: {type(data)}')
            loss = self.loss_func(model_output, target)
            self.step_loss = loss.item()
            total_loss += self.step_loss
            loss.backward()
            self.optimizer.step()
            self.logger.step(self)
            if self.acc_func is None:
                tbar.set_description('loss: {:.2f}'.format(self.step_loss))
            else:
                acc = self.acc_func(model_output, target)
                self.step_acc = acc
                total_acc += self.step_acc
                tbar.set_description('loss: {:.2f}, acc: {:.2f}'.format(
                    self.step_loss, acc))
        epoch_loss = total_loss / len(self.dataloader)
        epoch_acc = None
        if self.acc_func is not None:
            epoch_acc = total_acc / len(self.dataloader)
        return epoch_loss, epoch_acc

    def loss_plot(self):
        tbar = tqdm(enumerate(self.dataloader), total=len(self.dataloader))
        date_list = []
        loss_list = []
        ind = 0
        max_num = 3000
        for batch_idx, (data, target) in tbar:
            ind += 1
            if ind > max_num:
                break
            data = tuple(data)
            if type(data) is tuple:
                data = tuple(d.cuda() for d in data)
                model_output = [self.model(d) for d in data]
            elif type(data) is torch.Tensor:
                model_output = self.model(data.cuda())
            else:
                raise TypeError(f'Unknown type of input data: {type(data)}')
            loss = self.loss_func(model_output, target)
            loss_list.append(loss.item())
            date_list.append(target[0]['scan_date'].item())

        # draw plot
        total_nums = len(loss_list)

        min_date = min(date_list)
        max_date = max(date_list)

        hist = np.zeros(max_date - min_date + 1)
        hist_count = np.zeros(max_date - min_date + 1)

        for i in range(total_nums):
            ind_date = date_list[i] - min_date
            loss = loss_list[i]
            hist[ind_date] = hist[ind_date] + loss
            hist_count[ind_date] = hist_count[ind_date] + 1

        np_out = hist / hist_count
        #np_out[ np_out==0 ] = np.nan
        x = np.linspace(0, max_date - min_date, max_date - min_date + 1)
        y = np_out
        plt.plot(x, y, marker=".", markersize=40)
        plt.ylim(0, 1)
        plt.show()

        return

    def add_epoch_hook(self, func):
        self.epoch_finish_hook.append(func)
        return len(self.epoch_finish_hook) - 1

    def remove_epoch_hook(self, i):
        self.epoch_finish_hook.pop(i)

    def save_jit_model(self):
        scripted_model = torch.jit.script(self.model)
        torch.jit.save(scripted_model,
                       os.path.join(self.ckp_dir, 'jit_model.pth'))

    def save_epoch(self):
        save_path = os.path.join(self.ckp_dir,
                                 'epoch_{}.pth'.format(self.epoch_num))
        torch.save(
            {
                'epoch': self.epoch_num,
                'model_state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict()
            }, save_path)

    def load_epoch(self, resume_dict):
        if type(resume_dict) is str:
            resume_dict = torch.load(resume_dict)
        self.start_epoch = resume_dict['epoch'] + 1
        self.model.load_state_dict(resume_dict['model_state_dict'])
        self.optimizer.load_state_dict(resume_dict['optimizer'])
        self.scheduler.load_state_dict(resume_dict['scheduler'])
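
The validation_epoch method in Example #6 scores embeddings with a nearest-neighbour check: vectors are L2-normalized, 1 - dot(a, b) acts as a cosine distance, candidates from the same sensor are excluded, and a query counts as correct when its nearest neighbour shares the plot label. A compact sketch of that metric on random, purely illustrative data:

import numpy as np

rng = np.random.default_rng(0)
vectors = rng.normal(size=(8, 16))
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)  # L2-normalize rows
plots = rng.integers(0, 3, size=8)    # class ("plot") label per embedding
sensors = rng.integers(0, 2, size=8)  # sensor id per embedding

correct = 0
for i in range(len(vectors)):
    dist = 1.0 - vectors @ vectors[i]          # cosine-style distance
    dist[sensors == sensors[i]] = np.inf       # never match the same sensor
    j = int(dist.argmin())
    correct += int(plots[j] == plots[i])
print('nn accuracy:', correct / len(vectors))
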
Example #7
def test_dataloader():
    args = parser.parse_args()

    # Set seeds for determinism
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if args.cuda else "cpu")
    if args.mixed_precision and not args.cuda:
        raise ValueError(
            'If using mixed precision training, CUDA must be enabled!')
    args.distributed = args.world_size > 1
    main_proc = True
    device = torch.device("cuda" if args.cuda else "cpu")
    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)  # Ensure save folder exists

    loss_results, cer_results, wer_results = torch.Tensor(
        args.epochs), torch.Tensor(args.epochs), torch.Tensor(args.epochs)
    best_wer = None
    if main_proc and args.visdom:
        visdom_logger = VisdomLogger(args.id, args.epochs)
    if main_proc and args.tensorboard:
        tensorboard_logger = TensorBoardLogger(args.id, args.log_dir,
                                               args.log_params)

    avg_loss, start_epoch, start_iter, optim_state = 0, 0, 0, None

    if args.continue_from:  # Starting from previous model
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from,
                             map_location=lambda storage, loc: storage)

        if not args.finetune:  # Don't want to restart training
            optim_state = package['optim_dict']
            start_epoch = int(package.get(
                'epoch', 1)) - 1  # Index start at 0 for training
            start_iter = package.get('iteration', None)
            if start_iter is None:
                start_epoch += 1  # We saved model after epoch finished, start at the next epoch.
                start_iter = 0
            else:
                start_iter += 1
            avg_loss = int(package.get('avg_loss', 0))
            loss_results, cer_results, wer_results = package['loss_results'], package['cer_results'], \
                                                     package['wer_results']
            if main_proc and args.visdom:  # Add previous scores to visdom graph
                visdom_logger.load_previous_values(start_epoch, package)
            if main_proc and args.tensorboard:  # Previous scores to tensorboard logs
                tensorboard_logger.load_previous_values(start_epoch, package)

        print("Loading label from %s" % args.labels_path)
        with open(args.labels_path) as label_file:
            labels = str(''.join(json.load(label_file)))

        audio_conf = dict(sample_rate=args.sample_rate,
                          window_size=args.window_size,
                          window_stride=args.window_stride,
                          window=args.window,
                          noise_dir=args.noise_dir,
                          noise_prob=args.noise_prob,
                          noise_levels=(args.noise_min, args.noise_max))
    else:
        print("must load model!!!")
        exit()

    # decoder = GreedyDecoder(labels)
    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=args.train_manifest,
                                       labels=labels,
                                       normalize=True,
                                       augment=args.augment)

    train_sampler = BucketingSampler(train_dataset, batch_size=args.batch_size)
    train_loader = AudioDataLoader(train_dataset,
                                   num_workers=args.num_workers,
                                   batch_sampler=train_sampler)

    for i, (data) in enumerate(train_loader, start=start_iter):
        # Get the initial inputs
        inputs, targets, input_percentages, target_sizes = data
        input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
        inputs = inputs.to(device)
        size = inputs.size()
        print(size)
        # Initialize the model
        model = M_Noise_Deepspeech(package, size)
        for para in model.deepspeech_net.parameters():
            para.requires_grad = False
        model = model.to(device)

        # Get the initial outputs
        out_star = model.deepspeech_net(inputs, input_sizes)[0]
        out_star = out_star.transpose(0, 1)  # TxNxH
        float_out_star = out_star.float()
        break

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.SGD(parameters,
                                lr=args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=1e-5)
    print(model)
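
Example #7 freezes the wrapped DeepSpeech network by setting requires_grad=False on its parameters and then builds the optimizer only over the parameters that still require gradients. A generic sketch of that pattern; the two-layer model and the names backbone/head are placeholders:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))
backbone, head = model[0], model[2]

for p in backbone.parameters():
    p.requires_grad = False  # exclude the frozen part from gradient updates

trainable = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.SGD(trainable, lr=3e-4, momentum=0.9, nesterov=True)
print(sum(p.numel() for p in model.parameters() if p.requires_grad),
      'trainable parameters')
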
Example #8
        # TODO - what about target net update cnt
        target_net.load_state_dict(model_dict['target_net_state_dict'])
        policy_net.load_state_dict(model_dict['policy_net_state_dict'])
        opt.load_state_dict(model_dict['optimizer'])
        print("loaded model state_dicts")
        # TODO cant load buffer yet
        if args.buffer_loadpath == '':
            args.buffer_loadpath = args.model_loadpath.replace(
                '.pkl', '_train_buffer.pkl')
            print("auto loading buffer from:%s" % args.buffer_loadpath)
            rbuffer.load(args.buffer_loadpath)
    info['args'] = args
    write_info_file(info, model_base_filepath, total_steps)
    random_state = np.random.RandomState(info["SEED"])

    board_logger = TensorBoardLogger(model_base_filedir)
    last_target_update = 0
    print("Starting training")
    all_rewards = []

    epsilon_by_frame = lambda frame_idx: info['EPSILON_MIN'] + (
        info['EPSILON_MAX'] - info['EPSILON_MIN']) * math.exp(
            -1. * frame_idx / info['EPSILON_DECAY'])
    for epoch_num in range(epoch_start, info['N_EPOCHS']):
        ep_reward, total_steps, etime = run_training_episode(
            epoch_num, total_steps)
        all_rewards.append(ep_reward)
        overall_time += etime
        last_mean = np.mean(all_rewards[-100:])
        board_logger.scalar_summary("avg reward last 100 episodes", epoch_num,
                                    last_mean)
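
Example #8 builds an epsilon-greedy schedule as a lambda over the frame index, decaying exponentially from EPSILON_MAX towards EPSILON_MIN with time constant EPSILON_DECAY. The same schedule written out as a plain function, with illustrative constants:

import math

EPSILON_MIN, EPSILON_MAX, EPSILON_DECAY = 0.01, 1.0, 30000

def epsilon_by_frame(frame_idx):
    # Exponential decay towards EPSILON_MIN as frame_idx grows.
    return EPSILON_MIN + (EPSILON_MAX - EPSILON_MIN) * math.exp(
        -1.0 * frame_idx / EPSILON_DECAY)

for frame in (0, 10_000, 100_000):
    print(frame, round(epsilon_by_frame(frame), 4))
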
Example #9
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
        main_proc = args.rank == 0  # Only the first proc should save models
    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)  # Ensure save folder exists

    loss_results, cer_results, wer_results = torch.Tensor(
        args.epochs), torch.Tensor(args.epochs), torch.Tensor(args.epochs)

    best_wer = None
    if main_proc and args.visdom:
        visdom_logger = VisdomLogger(args.id, args.epochs)
    if main_proc and args.tensorboard:
        tensorboard_logger = TensorBoardLogger(args.id, args.log_dir,
                                               args.log_params)
    writer = SummaryWriter(args.log_dir)

    avg_loss, start_epoch, start_iter, optim_state = 0, 0, 0, None
    if args.continue_from:  # Starting from previous model
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from,
                             map_location=lambda storage, loc: storage)
        model = DeepSpeech.load_model_package(package)
        labels = model.labels

        audio_conf = model.audio_conf
        if not args.finetune:  # Don't want to restart training
            optim_state = package['optim_dict']
            start_epoch = int(package.get(
                'epoch', 1)) - 1  # Index start at 0 for training
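
Example #9 instantiates both a project-specific TensorBoardLogger and a plain SummaryWriter. A minimal sketch of logging the same per-epoch scalars with torch.utils.tensorboard alone; the log directory and metric values are placeholders:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/deepspeech_example')
for epoch, (loss, wer, cer) in enumerate([(52.1, 41.3, 17.9),
                                          (44.7, 35.2, 14.6)], start=1):
    writer.add_scalar('Avg Train Loss', loss, epoch)
    writer.add_scalar('Avg WER', wer, epoch)
    writer.add_scalar('Avg CER', cer, epoch)
writer.close()
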
Example #10
    args.distributed = args.world_size > 1
    main_proc = True
    device = torch.device("cuda" if args.cuda else "cpu")
    if args.distributed:
        if args.gpu_rank:
            torch.cuda.set_device(int(args.gpu_rank))
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        main_proc = args.rank == 0  # Only the first proc should save models
    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)  # Ensure save folder exists

    loss_results, val_loss_results = torch.Tensor(args.epochs), torch.Tensor(args.epochs)
    best_val_loss = None
    if main_proc and args.tensorboard:
        tensorboard_logger = TensorBoardLogger(args.id, args.log_dir, args.log_params)

    avg_loss, start_epoch, start_iter, optim_state = 0, 0, 0, None
    if args.continue_from:  # Starting from previous model
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from, map_location=lambda storage, loc: storage)
        model = DeepSpeech.load_model_package(package)
        labels = model.labels
        audio_conf = model.audio_conf
        if not args.finetune:  # Don't want to restart training
            optim_state = package['optim_dict']
            start_epoch = int(package.get('epoch', 1)) - 1  # Index start at 0 for training
            start_iter = package.get('iteration', None)
            if start_iter is None:
                start_epoch += 1  # We saved model after epoch finished, start at the next epoch.
                start_iter = 0
Example #11
            torch.cuda.set_device(int(args.gpu_rank))
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        print('Initiated process group')
        main_proc = args.rank == 0  # Only the first proc should save models
    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)  # Ensure save folder exists

    # loss_results, acc_results, std_results = torch.Tensor(args.epochs), torch.Tensor(args.epochs), torch.Tensor(
    #     args.epochs)
    # best_wer = None
    if main_proc and args.visdom:
        visdom_logger = VisdomLogger(args.id, args.epochs)
    if main_proc and args.tensorboard:
        # import pdb; pdb.set_trace()
        tensorboard_logger = TensorBoardLogger(args.id, args.log_dir, args.log_params, comment=sufix)

    avg_loss, start_epoch, start_iter, optim_state, amp_state = 0, 0, 0, None, None
    if args.continue_from:  # Starting from previous model (outside the cross-validation loop)
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from, map_location=lambda storage, loc: storage)
        model = DeepSpeech.load_model_package(package)
        labels = model.labels
        audio_conf = model.audio_conf
        if not args.finetune:  # Don't want to restart training
            # Not finetuning, so the optimization parameters (note: not the weights, which are the state_dict and were already loaded) are not reset to their initial state. The model is retrained, but starting from the previously loaded weights. TODO: when adapting this for my own setup, check the batch-normalization details; the running average cannot be reused, for example, so verify whether it lives in the state dict or in the optimizer state.
            optim_state = package['optim_dict']
            # amp_state = package['amp']  # what is it?
            start_epoch = int(package.get('epoch', 1)) - 1  # Index start at 0 for training
            start_iter = package.get('iteration', None)
            if start_iter is None:
Example #12
    def train(self, **kwargs):
        """
        Run optimization to train the model.

        Parameters
        ----------


        """
        world_size = kwargs.pop('world_size', 1)
        gpu_rank = kwargs.pop('gpu_rank', 0)
        rank = kwargs.pop('rank', 0)
        dist_backend = kwargs.pop('dist_backend', 'nccl')
        dist_url = kwargs.pop('dist_url', None)

        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '1234'

        main_proc = True
        self.distributed = world_size > 1

        if self.distributed:
            if self.gpu_rank:
                torch.cuda.set_device(int(gpu_rank))
            dist.init_process_group(backend=dist_backend,
                                    init_method=dist_url,
                                    world_size=world_size,
                                    rank=rank)
            print('Initiated process group')
            main_proc = rank == 0  # Only the first proc should save models

        if main_proc and self.tensorboard:
            tensorboard_logger = TensorBoardLogger(self.id,
                                                   self.log_dir,
                                                   self.log_params,
                                                   comment=self.sufix)

        if self.distributed:
            train_sampler = DistributedBucketingSampler(
                self.data_train,
                batch_size=self.batch_size,
                num_replicas=world_size,
                rank=rank)
        else:
            if self.sampler_type == 'bucketing':
                train_sampler = BucketingSampler(self.data_train,
                                                 batch_size=self.batch_size,
                                                 shuffle=True)
            if self.sampler_type == 'random':
                train_sampler = RandomBucketingSampler(
                    self.data_train, batch_size=self.batch_size)

        print("Shuffling batches for the following epochs..")
        train_sampler.shuffle(self.start_epoch)

        train_loader = AudioDataLoader(self.data_train,
                                       num_workers=self.num_workers,
                                       batch_sampler=train_sampler)
        val_loader = AudioDataLoader(self.data_val,
                                     batch_size=self.batch_size_val,
                                     num_workers=self.num_workers,
                                     shuffle=True)

        if self.tensorboard and self.generate_graph:  # TO DO get some audios also
            with torch.no_grad():
                inputs, targets, input_percentages, target_sizes = next(
                    iter(train_loader))
                input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
                tensorboard_logger.add_image(inputs,
                                             input_sizes,
                                             targets,
                                             network=self.model)

        self.model = self.model.to(self.device)
        parameters = self.model.parameters()

        if self.update_rule == 'adam':
            optimizer = torch.optim.Adam(parameters,
                                         lr=self.lr,
                                         weight_decay=self.reg)
        if self.update_rule == 'sgd':
            optimizer = torch.optim.SGD(parameters,
                                        lr=self.lr,
                                        weight_decay=self.reg)

        self.model, self.optimizer = amp.initialize(
            self.model,
            optimizer,
            opt_level=self.opt_level,
            keep_batchnorm_fp32=self.keep_batchnorm_fp32,
            loss_scale=self.loss_scale)

        if self.optim_state is not None:
            self.optimizer.load_state_dict(self.optim_state)

        if self.amp_state is not None:
            amp.load_state_dict(self.amp_state)

        if self.distributed:
            self.model = DistributedDataParallel(self.model)

        print(self.model)

        if self.criterion_type == 'cross_entropy_loss':
            self.criterion = torch.nn.CrossEntropyLoss()

        #  Useless for now because I don't save.
        accuracies_train_iters = []
        losses_iters = []

        avg_loss = 0
        batch_time = AverageMeter()
        epoch_time = AverageMeter()
        losses = AverageMeter()

        start_training = time.time()
        for epoch in range(self.start_epoch, self.num_epochs):
            print("Start epoch..")

            # Put model in train mode
            self.model.train()

            y_true_train_epoch = np.array([])
            y_pred_train_epoch = np.array([])

            start_epoch = time.time()
            for i, (data) in enumerate(train_loader, start=0):
                start_batch = time.time()

                print('Start batch..')

                if i == len(train_sampler):  # why is this needed?
                    break

                inputs, targets, input_percentages, _ = data

                input_sizes = input_percentages.mul_(int(inputs.size(3))).int()

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                output, loss_value = self._step(inputs, input_sizes, targets)

                print('Step finished.')

                avg_loss += loss_value

                with torch.no_grad():
                    y_pred = self.decoder.decode(output.detach()).cpu().numpy()

                    # import pdb; pdb.set_trace()

                    y_true_train_epoch = np.concatenate(
                        (y_true_train_epoch, targets.cpu().numpy()
                         ))  # maybe I should do it with tensors?
                    y_pred_train_epoch = np.concatenate(
                        (y_pred_train_epoch, y_pred))

                inputs_size = inputs.size(0)
                del output, inputs, input_percentages

                if self.intra_epoch_sanity_check:
                    with torch.no_grad():
                        acc, _ = self.check_accuracy(targets.cpu().numpy(),
                                                     y_pred=y_pred)
                        accuracies_train_iters.append(acc)
                        losses_iters.append(loss_value)

                        cm = confusion_matrix(targets.cpu().numpy(),
                                              y_pred,
                                              labels=self.labels)
                        print('[it %i/%i] Confusion matrix train step:' %
                              ((i + 1, len(train_sampler))))
                        print(pd.DataFrame(cm))

                        if self.tensorboard:
                            tensorboard_logger.update(
                                len(train_loader) * epoch + i + 1, {
                                    'Loss/through_iterations': loss_value,
                                    'Accuracy/train_through_iterations': acc
                                })

                del targets

                batch_time.update(time.time() - start_batch)

            epoch_time.update(time.time() - start_epoch)
            losses.update(loss_value, inputs_size)

            # Write elapsed time (and loss) to terminal
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Epoch {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      (epoch + 1), (i + 1),
                      len(train_sampler),
                      batch_time=batch_time,
                      data_time=epoch_time,
                      loss=losses))

            # Loss log
            avg_loss /= len(train_sampler)
            self.loss_epochs.append(avg_loss)

            # Accuracy train log
            acc_train, _ = self.check_accuracy(y_true_train_epoch,
                                               y_pred=y_pred_train_epoch)
            self.accuracy_train_epochs.append(acc_train)

            # Accuracy val log
            with torch.no_grad():
                y_pred_val = np.array([])
                targets_val = np.array([])
                for data in val_loader:
                    inputs, targets, input_percentages, _ = data
                    input_sizes = input_percentages.mul_(int(
                        inputs.size(3))).int()
                    _, y_pred_val_batch = self.check_accuracy(
                        targets.cpu().numpy(),
                        inputs=inputs,
                        input_sizes=input_sizes)
                    y_pred_val = np.concatenate((y_pred_val, y_pred_val_batch))
                    targets_val = np.concatenate(
                        (targets_val, targets.cpu().numpy()
                         ))  # TO DO: think of a smarter way to do this later
                    del inputs, targets, input_percentages

            # import pdb; pdb.set_trace()
            acc_val, y_pred_val = self.check_accuracy(targets_val,
                                                      y_pred=y_pred_val)
            self.accuracy_val_epochs.append(acc_val)
            cm = confusion_matrix(targets_val, y_pred_val, labels=self.labels)
            print('Confusion matrix validation:')
            print(pd.DataFrame(cm))

            # Write epoch stuff to tensorboard
            if self.tensorboard:
                tensorboard_logger.update(
                    epoch + 1, {'Loss/through_epochs': avg_loss},
                    parameters=self.model.named_parameters)

                tensorboard_logger.update(epoch + 1, {
                    'train': acc_train,
                    'validation': acc_val
                },
                                          together=True,
                                          name='Accuracy/through_epochs')

            # Keep track of the best model
            if acc_val > self.best_acc_val:
                self.best_acc_val = acc_val
                self.best_params = {}
                for k, v in self.model.named_parameters(
                ):  # TO DO: actually copy model and save later? idk..
                    self.best_params[k] = v.clone()

            # Anneal learning rate. TO DO: find better way to this this specific to every parameter as cs231n does.
            for g in self.optimizer.param_groups:
                g['lr'] = g['lr'] / self.learning_anneal
            print('Learning rate annealed to: {lr:.6f}'.format(lr=g['lr']))

            # Shuffle batches order
            print("Shuffling batches...")
            train_sampler.shuffle(epoch)

            # Rechoose batches elements
            if self.sampler_type == 'random':
                train_sampler.recompute_bins()

        end_training = time.time()

        if self.tensorboard:
            tensorboard_logger.close()

        print('Elapsed time in training: %.02f ' %
              ((end_training - start_training) / 60.0))
class Trainer:
    def __init__(self,
                 model,
                 dataloader,
                 loss_func,
                 optimizer,
                 scheduler,
                 n_epochs,
                 acc_func=None,
                 train_name='default',
                 resume_dict=None,
                 ckp_dir=None,
                 resume_ep='latest',
                 logger=None,
                 tb_log_dir=None,
                 log_step_interval=100,
                 comment=None):
        '''
        If resume_dict is not None, load state from resume_dict and ignore any
        checkpoint files stored in ckp_dir.
        '''
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S%f')[:-3]
        self.model = model
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
        self.dataloader = dataloader
        self.loss_func = loss_func
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.n_epochs = n_epochs
        self.acc_func = acc_func
        self.logger = logger
        self.ckp_dir = ckp_dir
        self.tb_log_dir = tb_log_dir
        self.log_step_interval = log_step_interval
        self.start_epoch = 0
        self.epoch_loss = None
        self.epoch_acc = None
        self.epoch_num = None
        self.step_loss = None
        self.step_acc = None
        self.step_num = None
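        # The epoch_*/step_* attributes above are populated during training so
        # that the logger (log_init/step) and the epoch-finish hooks can read
        # the current state from the Trainer instance.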
        self.steps_per_epoch = len(self.dataloader)
        if self.ckp_dir is None:
            self.ckp_dir = '/media/zli/Seagate Backup Plus Drive/trained_models/pytorch-py3/checkpoints/{}/{}'.format(
                train_name, self.timestamp)
        if not os.path.isdir(self.ckp_dir):
            os.makedirs(self.ckp_dir)
        # find model
        ckp_dir_file_names = os.listdir(self.ckp_dir)
        if model is None:
            if 'jit_model.pth' in ckp_dir_file_names:
                self.model = torch.jit.load(
                    os.path.join(self.ckp_dir, 'jit_model.pth'))
            else:
                raise ValueError('No model found')
        # find resume ep
        if resume_ep == 'latest':
            epoch_file_list = glob(os.path.join(self.ckp_dir, 'epoch_*.pth'))
            if len(epoch_file_list):
                prog = re.compile(r"epoch_([0-9]+)\.pth")
                epoch_number_list = [
                    int(prog.findall(i)[0]) for i in epoch_file_list
                ]
                resume_ep = max(epoch_number_list)
                print(f'Latest epoch {resume_ep} found.')
            else:
                resume_ep = None
        else:
            assert resume_ep.isnumeric(), ValueError(
                "resume_ep must be numeric or 'latest'.")
            assert os.path.isfile(os.path.join(self.ckp_dir, f"epoch_{resume_ep}.pth")), \
                   FileNotFoundError(f"epoch {resume_ep} not found in {self.ckp_dir}")
            resume_ep = int(resume_ep)
        if resume_dict is not None:
            self.load_epoch(resume_dict)
        elif resume_ep is not None:
            self.load_epoch(
                os.path.join(self.ckp_dir, f"epoch_{resume_ep}.pth"))
        # Logger
        if self.logger is None:
            if self.tb_log_dir is None:
                self.tb_log_dir = 'tb_logs/{}'.format(train_name)
            self.tb_log_dir = os.path.join(self.tb_log_dir, self.timestamp)
            self.logger = TensorBoardLogger(self.tb_log_dir)
        self.logger.log_init(self)
        self.record_params = ['loss_func', 'step_loss', 'step_acc']
        self.epoch_finish_hook = []
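        # Callables appended to epoch_finish_hook (via add_epoch_hook below)
        # are invoked at the end of every epoch with the Trainer instance as
        # their only argument, so they can inspect epoch_loss, epoch_acc, etc.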

    def train(self):
        for epoch in range(self.start_epoch, self.n_epochs):
            self.epoch_num = epoch
            current_lr = [
                group['lr'] for group in self.optimizer.param_groups
            ][0]
            print('---- start epoch: {}/{}\tlearning rate:{:.2E} ----'.format(
                self.epoch_num, self.n_epochs, current_lr))
            self.epoch_loss, self.epoch_acc = self.train_epoch()
            for fn in self.epoch_finish_hook:
                fn(self)
            # tb_writer.add_scalar('epoch/lr', scheduler.get_lr()[0], epoch)
            # tb_writer.add_scalar('epoch/train_loss',epoch_loss, epoch)
            # if epoch_acc is not None:
            #     tb_writer.add_scalar('epoch/train_acc', epoch_acc, epoch)
            # tb_writer.flush()
            self.scheduler.step()
            print(
                'end epoch: {}/{}\ttrain loss: {:.2f}\ttrain acc: {}\n'.format(
                    self.epoch_num, self.n_epochs, self.epoch_loss,
                    self.epoch_acc))
            self.save_epoch()

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        total_acc = 0
        tbar = tqdm(enumerate(self.dataloader), total=len(self.dataloader))
        for batch_idx, (data, target) in tbar:
            self.step_num = batch_idx
            self.optimizer.zero_grad()
            if type(data) is tuple:
                data = tuple(d.cuda() for d in data)
                model_output = [self.model(d) for d in data]
            elif type(data) is torch.Tensor:
                model_output = self.model(data.cuda())
            else:
                raise TypeError(f'Unknown type of input data: {type(data)}')
            # move targets onto the GPU to match the model output (the inputs
            # were sent to the GPU above)
            loss = self.loss_func(model_output, target.cuda())
            self.step_loss = loss.item()
            total_loss += self.step_loss
            loss.backward()
            self.optimizer.step()
            self.logger.step(self)
            if self.acc_func is None:
                tbar.set_description('loss: {:.2f}'.format(self.step_loss))
            else:
                acc = self.acc_func(model_output, target)
                self.step_acc = acc
                total_acc += self.step_acc
                tbar.set_description('loss: {:.2f}, acc: {:.2f}'.format(
                    self.step_loss, acc))
        epoch_loss = total_loss / len(self.dataloader)
        epoch_acc = None
        if self.acc_func is not None:
            epoch_acc = total_acc / len(self.dataloader)
        return epoch_loss, epoch_acc

    def add_epoch_hook(self, func):
        self.epoch_finish_hook.append(func)
        return len(self.epoch_finish_hook) - 1

    def remove_epoch_hook(self, i):
        self.epoch_finish_hook.pop(i)

    def save_jit_model(self):
        scripted_model = torch.jit.script(self.model)
        torch.jit.save(scripted_model,
                       os.path.join(self.ckp_dir, 'jit_model.pth'))

    def save_epoch(self):
        save_path = os.path.join(self.ckp_dir,
                                 'epoch_{}.pth'.format(self.epoch_num))
        torch.save(
            {
                'epoch': self.epoch_num,
                'model_state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict()
            }, save_path)
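        # Note: if the model was wrapped in nn.DataParallel in __init__, the
        # keys of model_state_dict carry a 'module.' prefix, so load_epoch
        # should be applied to a model wrapped the same way.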

    def load_epoch(self, resume_dict):
        if type(resume_dict) is str:
            resume_dict = torch.load(resume_dict)
        self.start_epoch = resume_dict['epoch'] + 1
        self.model.load_state_dict(resume_dict['model_state_dict'])
        self.optimizer.load_state_dict(resume_dict['optimizer'])
        self.scheduler.load_state_dict(resume_dict['scheduler'])
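

# --- Usage sketch (added for illustration; not part of the original example) ---
# A minimal, hedged example of driving the Trainer above. It assumes a CUDA
# device (train_epoch() moves the inputs to the GPU, so the model must already
# live there) and that the names used by the class (torch, nn, os, glob, re,
# tqdm, datetime, TensorBoardLogger) are already in scope. Every concrete name
# below (the tiny model, the random dataset, './checkpoints_demo',
# './tb_logs_demo') is hypothetical and only shows the expected call sequence.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    net = nn.Sequential(nn.Linear(10, 2)).cuda()          # stand-in model on the GPU
    dataset = TensorDataset(torch.randn(64, 10),           # random features
                            torch.randint(0, 2, (64,)))    # random integer labels
    loader = DataLoader(dataset, batch_size=8)
    opt = torch.optim.SGD(net.parameters(), lr=0.01)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.9)

    trainer = Trainer(model=net,
                      dataloader=loader,
                      loss_func=nn.CrossEntropyLoss(),
                      optimizer=opt,
                      scheduler=sched,
                      n_epochs=2,
                      ckp_dir='./checkpoints_demo',        # override the hard-coded default path
                      tb_log_dir='./tb_logs_demo')
    # Hooks receive the Trainer itself, so they can read epoch_num, epoch_loss, ...
    trainer.add_epoch_hook(lambda t: print('finished epoch', t.epoch_num))
    trainer.train()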