Ejemplo n.º 1
0
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH,
                                        imgW=params.imgW,
                                        keep_ratio=params.keep_ratio))
    train_iter = iter(train_loader)
    for i in range(5000):
        print(train_iter.next()[0].shape)
    # read test set
    # images will be resize to 32*160
    test_dataset = dataset.lmdbDataset(root=opt.valroot, rgb=params.rgb)

    nclass = len(params.alphabet) + 1
    nc = 1

    converter = utils.strLabelConverter(params.alphabet)
    criterion = CTCLoss(size_average=False, length_average=False)

    # cnn and rnn
    image = torch.FloatTensor(params.batchSize, 3, params.imgH, params.imgH)
    text = torch.IntTensor(params.batchSize * 5)
    length = torch.IntTensor(params.batchSize)

    crnn = crnn.CRNN(params.imgH, nc, nclass, params.nh)
    crnn = torch.nn.DataParallel(crnn)
    if opt.cuda:
        crnn.cuda()
        image = image.cuda()
        criterion = criterion.cuda()

    crnn.apply(weights_init)
    if params.crnn != '':
Ejemplo n.º 2
0
def main():
    """Train a CTC speech model (CNN_LSTM_CTC or CTC_RNN) on TIMIT spectrum features.

    Builds train/dev datasets and loaders, trains with Adam + CTC loss,
    tracks CER on both sets each epoch, decays the learning rate when the
    dev score plateaus (rolling back to the best weights on regression),
    plots curves to visdom, and finally saves the best model package.
    """
    import copy  # local import: needed for real (deep) weight snapshots below

    global logger
    log_file = './log/train.log'
    logger = init_logger(log_file)

    from visdom import Visdom
    viz = Visdom()
    # One visdom window per tracked curve: loss, train CER, dev CER.
    opts = [
        dict(title="Timit Spectrum_CNN" + " Loss",
             ylabel='Loss',
             xlabel='Epoch'),
        dict(title="Timit Spectrum_CNN" + " CER on Train",
             ylabel='CER',
             xlabel='Epoch'),
        dict(title='Timit Spectrum_CNN' + ' CER on DEV',
             ylabel='DEV CER',
             xlabel='Epoch')
    ]
    viz_window = [None, None, None]

    # Hyper-parameters.
    init_lr = 0.001
    num_epoches = 30
    least_train_epoch = 5   # warm-up epochs before lr adjustment kicks in
    end_adjust_acc = 0.5    # margin that counts as a "real" improvement
    decay = 0.5             # multiplicative lr decay factor
    count = 0               # number of completed epochs
    learning_rate = init_lr
    batch_size = 8
    weight_decay = 0.005
    model_type = 'CNN_LSTM_CTC'

    params = {
        'num_epoches': num_epoches,
        'least_train_epoch': least_train_epoch,
        'end_adjust_acc': end_adjust_acc,
        'decay': decay,
        'learning_rate': init_lr,
        'weight_decay': weight_decay,
        'batch_size': batch_size
    }

    acc_best = -100
    adjust_rate_flag = False
    stop_train = False

    train_dataset = myDataset(data_set='train',
                              feature_type="spectrum",
                              out_type='phone',
                              n_feats=201)
    dev_dataset = myDataset(data_set="dev",
                            feature_type="spectrum",
                            out_type='phone',
                            n_feats=201)

    decoder = Decoder(dev_dataset.int2phone, space_idx=-1, blank_index=0)

    rnn_input_size = train_dataset.n_feats

    # The CNN front-end variant needs the CNN-aware loader; the plain RNN
    # model uses the default one.
    if model_type == 'CNN_LSTM_CTC':
        model = CNN_LSTM_CTC(rnn_input_size=rnn_input_size,
                             rnn_hidden_size=256,
                             rnn_layers=4,
                             rnn_type=nn.LSTM,
                             bidirectional=True,
                             batch_norm=True,
                             num_class=48,
                             drop_out=0)
        train_loader = myCNNDataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=4,
                                       pin_memory=False)
        dev_loader = myCNNDataLoader(dev_dataset,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=4,
                                     pin_memory=False)
    else:
        model = CTC_RNN(rnn_input_size=rnn_input_size,
                        rnn_hidden_size=256,
                        rnn_layers=4,
                        rnn_type=nn.LSTM,
                        bidirectional=True,
                        batch_norm=True,
                        num_class=48,
                        drop_out=0)
        train_loader = myDataLoader(train_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=4,
                                    pin_memory=False)
        dev_loader = myDataLoader(dev_dataset,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=4,
                                  pin_memory=False)

    if USE_CUDA:
        model = model.cuda()

    loss_fn = CTCLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=init_lr,
                                 weight_decay=weight_decay)

    adjust_time = 0
    start_time = time.time()
    loss_results = []
    training_cer_results = []
    dev_cer_results = []

    # Snapshot the initial state so the rollback branch below can never hit
    # an undefined name when the very first evaluated epoch fails to improve.
    # deepcopy is required: state_dict() returns references to the live
    # tensors, so restoring an uncopied snapshot would be a no-op.
    model_state = copy.deepcopy(model.state_dict())
    op_state = copy.deepcopy(optimizer.state_dict())

    while not stop_train:
        if count >= num_epoches:
            break
        count += 1

        if adjust_rate_flag:
            learning_rate *= decay
            for param in optimizer.param_groups:
                param['lr'] *= decay

        print("Start training epoch: %d, learning_rate: %.5f" %
              (count, learning_rate))
        logger.info("Start training epoch: %d, learning_rate: %.5f" %
                    (count, learning_rate))

        loss = train(model, train_loader, loss_fn, optimizer, print_every=20)
        loss_results.append(loss)
        # CER on the training set itself (diagnostic only).
        cer = dev(model, train_loader, decoder)
        print("cer on training set is %.4f" % cer)
        logger.info("cer on training set is %.4f" % cer)
        training_cer_results.append(cer)
        acc = dev(model, dev_loader, decoder)
        dev_cer_results.append(acc)

        # Give up entirely after 8 learning-rate adjustments.
        if adjust_time == 8:
            stop_train = True

        # After the warm-up epochs, start adjusting the learning rate.
        if count >= least_train_epoch:
            if acc > (acc_best + end_adjust_acc):
                # Clear improvement: keep this state, stay at the current lr.
                model_state = copy.deepcopy(model.state_dict())
                op_state = copy.deepcopy(optimizer.state_dict())
                adjust_rate_flag = False
                acc_best = acc
            elif (acc > acc_best):
                # Marginal improvement: keep the state but schedule a decay.
                model_state = copy.deepcopy(model.state_dict())
                op_state = copy.deepcopy(optimizer.state_dict())
                adjust_rate_flag = True
                adjust_time += 1
                acc_best = acc
            elif (acc <= acc_best):
                # Regression: schedule a decay and roll back to the best state.
                adjust_rate_flag = True
                adjust_time += 1
                model.load_state_dict(model_state)
                optimizer.load_state_dict(op_state)

        time_used = (time.time() - start_time) / 60
        print("epoch %d done, cv acc is: %.4f, time_used: %.4f minutes" %
              (count, acc, time_used))
        logger.info("epoch %d done, cv acc is: %.4f, time_used: %.4f minutes" %
                    (count, acc, time_used))
        x_axis = range(count)
        y_axis = [
            loss_results[0:count], training_cer_results[0:count],
            dev_cer_results[0:count]
        ]
        # Create each visdom window lazily on first use, then replace its data.
        for x in range(len(viz_window)):
            if viz_window[x] is None:
                viz_window[x] = viz.line(
                    X=np.array(x_axis),
                    Y=np.array(y_axis[x]),
                    opts=opts[x],
                )
            else:
                viz.line(
                    X=np.array(x_axis),
                    Y=np.array(y_axis[x]),
                    win=viz_window[x],
                    update='replace',
                )

    print("End training, best cv acc is: %.4f" % acc_best)
    logger.info("End training, best cv acc is: %.4f" % acc_best)
    best_path = './log/best_model' + '_cv' + str(acc_best) + '.pkl'
    params['epoch'] = count
    params['feature_type'] = train_dataset.feature_type
    params['n_feats'] = train_dataset.n_feats
    params['out_type'] = train_dataset.out_type
    torch.save(
        CTC_RNN.save_package(model,
                             optimizer=optimizer,
                             epoch=params,
                             loss_results=loss_results,
                             training_cer_results=training_cer_results,
                             dev_cer_results=dev_cer_results), best_path)
Ejemplo n.º 3
0
def main(args=None):
    """Train a RetinaNet detection + handwritten-text-recognition (HTR) model.

    Parses CLI options, builds the CSV datasets, trains the detection and/or
    recognition branches, evaluates mAP and CER each epoch, early-stops on the
    configured criterion (`cer` or `map`), and checkpoints the best model.

    Args:
        args: optional argv list forwarded to argparse (None = sys.argv).
    """

    parser = argparse.ArgumentParser(
        description='Simple training script for training a RetinaNet network.')

    parser.add_argument('--dataset',
                        help='Dataset type, must be one of csv or coco.',
                        default="csv")
    parser.add_argument('--coco_path', help='Path to COCO directory')
    parser.add_argument(
        '--csv_train',
        help='Path to file containing training annotations (see readme)')
    parser.add_argument('--csv_classes',
                        help='Path to file containing class list (see readme)',
                        default="binary_class.csv")
    parser.add_argument(
        '--csv_val',
        help=
        'Path to file containing validation annotations (optional, see readme)'
    )

    parser.add_argument(
        '--depth',
        help='Resnet depth, must be one of 18, 34, 50, 101, 152',
        type=int,
        default=18)
    parser.add_argument('--epochs',
                        help='Number of epochs',
                        type=int,
                        default=500)
    parser.add_argument('--epochs_only_det',
                        help='Number of epochs to train detection part',
                        type=int,
                        default=1)
    parser.add_argument('--max_epochs_no_improvement',
                        help='Max epochs without improvement',
                        type=int,
                        default=100)
    parser.add_argument('--pretrained_model',
                        help='Path of .pt file with pretrained model',
                        default='esposallescsv_retinanet_0.pt')
    parser.add_argument('--model_out',
                        help='Path of .pt file with trained model to save',
                        default='trained')

    parser.add_argument('--score_threshold',
                        help='Score above which boxes are kept',
                        type=float,
                        default=0.5)
    parser.add_argument('--nms_threshold',
                        help='Score above which boxes are kept',
                        type=float,
                        default=0.2)
    parser.add_argument('--max_boxes',
                        help='Max boxes to be fed to recognition',
                        default=95)
    parser.add_argument('--seg_level',
                        help='[line, word], to choose anchor aspect ratio',
                        default='word')
    parser.add_argument(
        '--early_stop_crit',
        help='Early stop criterion, detection (map) or transcription (cer)',
        default='cer')
    parser.add_argument('--max_iters_epoch',
                        help='Max steps per epoch (for debugging)',
                        default=1000000)
    parser.add_argument('--train_htr',
                        help='Train recognition or not',
                        default='True')
    parser.add_argument('--train_det',
                        help='Train detection or not',
                        default='True')
    parser.add_argument(
        '--htr_gt_box',
        help='Train recognition branch with box gt (for debugging)',
        default='False')

    # Note: `parser` is rebound to the parsed Namespace from here on.
    parser = parser.parse_args(args)

    if parser.dataset == 'csv':

        if parser.csv_train is None:
            raise ValueError('Must provide --csv_train')

        dataset_name = parser.csv_train.split("/")[-2]

        dataset_train = CSVDataset(train_file=parser.csv_train,
                                   class_list=parser.csv_classes,
                                   transform=transforms.Compose(
                                       [Normalizer(),
                                        Augmenter(),
                                        Resizer()]))

        if parser.csv_val is None:
            dataset_val = None
            print('No validation annotations provided.')
        else:
            dataset_val = CSVDataset(train_file=parser.csv_val,
                                     class_list=parser.csv_classes,
                                     transform=transforms.Compose(
                                         [Normalizer(),
                                          Resizer()]))

    else:
        raise ValueError(
            'Dataset type not understood (must be csv or coco), exiting.')

    # Files for training log: record every non-None option at the top of the
    # experiment's CER file.  (with-block fixes the previously leaked handle.)
    experiment_id = str(time.time()).split('.')[0]
    with open(experiment_id + '_valid_CER.txt', 'w') as valid_cer_f:
        for arg in vars(parser):
            if getattr(parser, arg) is not None:
                valid_cer_f.write(
                    str(arg) + ' ' + str(getattr(parser, arg)) + '\n')

    sampler = AspectRatioBasedSampler(dataset_train,
                                      batch_size=1,
                                      drop_last=False)
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=3,
                                  collate_fn=collater,
                                  batch_sampler=sampler)

    if dataset_val is not None:
        sampler_val = AspectRatioBasedSampler(dataset_val,
                                              batch_size=1,
                                              drop_last=False)
        dataloader_val = DataLoader(dataset_val,
                                    num_workers=0,
                                    collate_fn=collater,
                                    batch_sampler=sampler_val)

    if not os.path.exists('trained_models'):
        os.mkdir('trained_models')

    # Create the model

    train_htr = parser.train_htr == 'True'
    htr_gt_box = parser.htr_gt_box == 'True'
    torch.backends.cudnn.benchmark = False

    alphabet = dataset_train.alphabet
    if os.path.exists(parser.pretrained_model):
        retinanet = torch.load(parser.pretrained_model)
    else:
        if parser.depth == 18:
            retinanet = model.resnet18(num_classes=dataset_train.num_classes(),
                                       pretrained=True,
                                       max_boxes=int(parser.max_boxes),
                                       score_threshold=float(
                                           parser.score_threshold),
                                       seg_level=parser.seg_level,
                                       alphabet=alphabet,
                                       train_htr=train_htr,
                                       htr_gt_box=htr_gt_box)

        elif parser.depth == 34:

            retinanet = model.resnet34(num_classes=dataset_train.num_classes(),
                                       pretrained=True,
                                       max_boxes=int(parser.max_boxes),
                                       score_threshold=float(
                                           parser.score_threshold),
                                       seg_level=parser.seg_level,
                                       alphabet=alphabet,
                                       train_htr=train_htr,
                                       htr_gt_box=htr_gt_box)

        elif parser.depth == 50:
            retinanet = model.resnet50(num_classes=dataset_train.num_classes(),
                                       pretrained=True)
        elif parser.depth == 101:
            retinanet = model.resnet101(
                num_classes=dataset_train.num_classes(), pretrained=True)
        elif parser.depth == 152:
            retinanet = model.resnet152(
                num_classes=dataset_train.num_classes(), pretrained=True)
        else:
            raise ValueError(
                'Unsupported model depth, must be one of 18, 34, 50, 101, 152')

    use_gpu = True
    train_htr = parser.train_htr == 'True'
    train_det = parser.train_det == 'True'
    retinanet.htr_gt_box = parser.htr_gt_box == 'True'

    retinanet.train_htr = train_htr
    retinanet.epochs_only_det = parser.epochs_only_det

    if use_gpu:
        retinanet = retinanet.cuda()

    retinanet = torch.nn.DataParallel(retinanet).cuda()

    retinanet.training = True

    optimizer = optim.Adam(retinanet.parameters(), lr=1e-4)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=50,
                                                     verbose=True)

    loss_hist = collections.deque(maxlen=500)
    ctc = CTCLoss()
    retinanet.train()
    retinanet.module.freeze_bn()

    best_cer = 1000
    best_map = 0
    # Fix: mAP is read below even on epochs where evaluation is skipped.
    mAP = 0.0
    epochs_no_improvement = 0
    verbose_each = 1
    optimize_each = 1
    print(('Num training images: {}'.format(len(dataset_train))))

    for epoch_num in range(parser.epochs):
        cers = []

        retinanet.training = True

        retinanet.train()
        retinanet.module.freeze_bn()

        epoch_loss = []

        for iter_num, data in enumerate(dataloader_train):
            if iter_num > int(parser.max_iters_epoch): break
            try:
                if iter_num % optimize_each == 0:
                    optimizer.zero_grad()
                (classification_loss, regression_loss, ctc_loss,
                 ner_loss) = retinanet([
                     data['img'].cuda().float(), data['annot'], ctc, epoch_num
                 ])

                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()
                # Combine losses according to which branches are being trained.
                if train_det:

                    if train_htr:
                        loss = ctc_loss + classification_loss + regression_loss + ner_loss

                    else:
                        loss = classification_loss + regression_loss

                elif train_htr:
                    loss = ctc_loss

                else:
                    continue
                if bool(loss == 0):
                    continue
                loss.backward()
                torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)
                if iter_num % verbose_each == 0:
                    print((
                        'Epoch: {} | Step: {} |Classification loss: {:1.5f} | Regression loss: {:1.5f} | CTC loss: {:1.5f} | NER loss: {:1.5f} | Running loss: {:1.5f} | Total loss: {:1.5f}\r'
                        .format(epoch_num, iter_num,
                                float(classification_loss),
                                float(regression_loss), float(ctc_loss),
                                float(ner_loss), np.mean(loss_hist),
                                float(loss), "\r")))
                torch.cuda.empty_cache()

                optimizer.step()

                loss_hist.append(float(loss))

                epoch_loss.append(float(loss))

            except Exception as e:
                # Best-effort training step: log the failure and move on.
                print(e)
                continue
        if parser.dataset == 'csv' and parser.csv_val is not None and train_det:

            print('Evaluating dataset')

            mAP = csv_eval.evaluate(dataset_val,
                                    retinanet,
                                    score_threshold=parser.score_threshold)
            mAP = float(mAP[0][0])

        retinanet.eval()
        retinanet.training = False
        retinanet.score_threshold = float(parser.score_threshold)
        # NOTE(review): this CER loop assumes a validation set was supplied;
        # with --csv_val unset, `dataloader_val` is undefined here — confirm
        # callers always pass --csv_val.
        for idx, data in enumerate(dataloader_val):
            if idx > int(parser.max_iters_epoch): break
            print("Eval CER on validation set:", idx, "/", len(dataset_val),
                  "\r")
            image_name = dataset_val.image_names[idx].split('/')[-1].split(
                '.')[-2]

            # Ground-truth transcript lives next to the image as a .txt file.
            text_gt = dataset_val.image_names[idx].split('.')[0] + '.txt'
            # Fix: close the ground-truth file instead of leaking the handle.
            with open(text_gt, 'r') as f:
                text_gt_lines = f.readlines()[0]
            transcript_pred = get_transcript(image_name, data, retinanet,
                                             float(parser.score_threshold),
                                             float(parser.nms_threshold),
                                             dataset_val, alphabet)
            cers.append(
                float(editdistance.eval(transcript_pred, text_gt_lines)) /
                len(text_gt_lines))

        t = str(time.time()).split('.')[0]

        # Append this epoch's mean CER to the experiment log.
        with open(experiment_id + '_valid_CER.txt', 'a') as valid_cer_f:
            valid_cer_f.write(
                str(epoch_num) + " " + str(np.mean(cers)) + " " + t + '\n')
        print("GT", text_gt_lines)
        print("PREDS SAMPLE:", transcript_pred)

        if parser.early_stop_crit == 'cer':

            if float(np.mean(cers)) < float(best_cer):
                best_cer = np.mean(cers)
                epochs_no_improvement = 0
                torch.save(
                    retinanet.module, 'trained_models/' + parser.model_out +
                    '{}_retinanet.pt'.format(parser.dataset))
            else:
                epochs_no_improvement += 1
        elif parser.early_stop_crit == 'map':
            if mAP > best_map:
                best_map = mAP
                epochs_no_improvement = 0
                torch.save(
                    retinanet.module, 'trained_models/' + parser.model_out +
                    '{}_retinanet.pt'.format(parser.dataset))

            else:
                epochs_no_improvement += 1
        if train_det:
            print(epoch_num, "mAP: ", mAP, " best mAP", best_map)
        if train_htr:
            print("VALID CER:", np.mean(cers), "best CER", best_cer)
        print("Epochs no improvement:", epochs_no_improvement)
        if epochs_no_improvement > 3:
            for param_group in optimizer.param_groups:
                # NOTE(review): 10e-5 == 1e-4, the same value as the initial
                # lr, so this never fires at first; possibly 1e-5 was intended.
                if param_group['lr'] > 10e-5:
                    param_group['lr'] *= 0.1

        if epochs_no_improvement >= parser.max_epochs_no_improvement:
            print("TRAINING FINISHED AT EPOCH", epoch_num, ".")
            sys.exit()

        scheduler.step(np.mean(epoch_loss))
        torch.cuda.empty_cache()

    retinanet.eval()
Ejemplo n.º 4
0
def main():
    """Train CNN_LSTM_CTC from an INI config and save the best checkpoint.

    Reads data/model/training settings from the config given on the command
    line, trains with Adam + CTC loss, tracks CER on train/dev, decays the
    learning rate when the dev score plateaus (rolling back to the best
    weights), plots curves to visdom, and writes the best model path back
    into the config file.
    """
    args = parser.parse_args()
    cf = ConfigParser.ConfigParser()
    try:
        cf.read(args.conf)
    except Exception:
        print("conf file not exists")

    logger = init_logger(os.path.join(args.log_dir, 'train_cnn_lstm_ctc.log'))
    dataset = cf.get('Data', 'dataset')
    data_dir = cf.get('Data', 'data_dir')
    feature_type = cf.get('Data', 'feature_type')
    out_type = cf.get('Data', 'out_type')
    n_feats = cf.getint('Data', 'n_feats')
    batch_size = cf.getint("Training", 'batch_size')

    # Data loaders
    train_dataset = myDataset(data_dir,
                              data_set='train',
                              feature_type=feature_type,
                              out_type=out_type,
                              n_feats=n_feats)
    train_loader = myCNNDataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=4,
                                   pin_memory=False)
    dev_dataset = myDataset(data_dir,
                            data_set="test",
                            feature_type=feature_type,
                            out_type=out_type,
                            n_feats=n_feats)
    dev_loader = myCNNDataLoader(dev_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=4,
                                 pin_memory=False)

    # Greedy decoder for the dev set
    decoder = GreedyDecoder(dev_dataset.int2phone, space_idx=-1, blank_index=0)

    # Define the model from the [Model] section of the config.
    rnn_input_size = cf.getint('Model', 'rnn_input_size')
    rnn_hidden_size = cf.getint('Model', 'rnn_hidden_size')
    rnn_layers = cf.getint('Model', 'rnn_layers')
    rnn_type = RNN[cf.get('Model', 'rnn_type')]
    bidirectional = cf.getboolean('Model', 'bidirectional')
    batch_norm = cf.getboolean('Model', 'batch_norm')
    num_class = cf.getint('Model', 'num_class')
    # Bug fix: this previously read the 'num_class' option, silently using
    # the class count as the dropout probability.
    drop_out = cf.getfloat('Model', 'drop_out')
    model = CNN_LSTM_CTC(rnn_input_size=rnn_input_size,
                         rnn_hidden_size=rnn_hidden_size,
                         rnn_layers=rnn_layers,
                         rnn_type=rnn_type,
                         bidirectional=bidirectional,
                         batch_norm=batch_norm,
                         num_class=num_class,
                         drop_out=drop_out)
    print(model.name)

    # Training settings
    init_lr = cf.getfloat('Training', 'init_lr')
    num_epoches = cf.getint('Training', 'num_epoches')
    end_adjust_acc = cf.getfloat('Training', 'end_adjust_acc')
    decay = cf.getfloat("Training", 'lr_decay')
    weight_decay = cf.getfloat("Training", 'weight_decay')
    try:
        seed = cf.getint('Training', 'seed')
    except Exception:
        # Fall back to CUDA's initial seed when the config omits one.
        seed = torch.cuda.initial_seed()
    params = {
        'num_epoches': num_epoches,
        'end_adjust_acc': end_adjust_acc,
        'seed': seed,
        'decay': decay,
        'learning_rate': init_lr,
        'weight_decay': weight_decay,
        'batch_size': batch_size,
        'feature_type': feature_type,
        'n_feats': n_feats,
        'out_type': out_type
    }

    if USE_CUDA:
        torch.cuda.manual_seed(seed)
        model = model.cuda()

    print(params)

    loss_fn = CTCLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=init_lr,
                                 weight_decay=weight_decay)

    # Visualization for training: one window per curve.
    from visdom import Visdom
    viz = Visdom(env='863_corpus')
    title = dataset + ' ' + feature_type + str(n_feats) + ' CNN_LSTM_CTC'
    opts = [
        dict(title=title + " Loss", ylabel='Loss', xlabel='Epoch'),
        dict(title=title + " CER on Train", ylabel='CER', xlabel='Epoch'),
        dict(title=title + ' CER on DEV', ylabel='DEV CER', xlabel='Epoch')
    ]
    viz_window = [None, None, None]

    count = 0
    learning_rate = init_lr
    acc_best = -100
    acc_best_true = -100
    adjust_rate_flag = False
    stop_train = False
    adjust_time = 0
    # Fix: initialize the plateau counter; it was previously first touched
    # via `+= 1` inside the loop, risking an UnboundLocalError.
    adjust_rate_count = 0
    start_time = time.time()
    loss_results = []
    training_cer_results = []
    dev_cer_results = []
    # Defensive initial snapshot so the rollback below can never reference
    # an undefined state.
    model_state = copy.deepcopy(model.state_dict())
    op_state = copy.deepcopy(optimizer.state_dict())

    while not stop_train:
        if count >= num_epoches:
            break
        count += 1

        if adjust_rate_flag:
            learning_rate *= decay
            adjust_rate_flag = False
            for param in optimizer.param_groups:
                param['lr'] *= decay

        print("Start training epoch: %d, learning_rate: %.5f" %
              (count, learning_rate))
        logger.info("Start training epoch: %d, learning_rate: %.5f" %
                    (count, learning_rate))

        loss = train(model,
                     train_loader,
                     loss_fn,
                     optimizer,
                     logger,
                     print_every=20)
        loss_results.append(loss)
        # CER on the training set itself (diagnostic only).
        cer = dev(model, train_loader, decoder, logger)
        print("cer on training set is %.4f" % cer)
        logger.info("cer on training set is %.4f" % cer)
        training_cer_results.append(cer)
        acc = dev(model, dev_loader, decoder, logger)
        dev_cer_results.append(acc)

        if acc > (acc_best + end_adjust_acc):
            # Clear improvement: reset the plateau counter and snapshot.
            acc_best = acc
            adjust_rate_count = 0
            model_state = copy.deepcopy(model.state_dict())
            op_state = copy.deepcopy(optimizer.state_dict())
        elif (acc > acc_best - end_adjust_acc):
            # Within the plateau band: count it, but still track a true best.
            adjust_rate_count += 1
            if acc > acc_best and acc > acc_best_true:
                acc_best_true = acc
                model_state = copy.deepcopy(model.state_dict())
                op_state = copy.deepcopy(optimizer.state_dict())
        else:
            adjust_rate_count = 0
        print("adjust_rate_count:" + str(adjust_rate_count))
        print('adjust_time:' + str(adjust_time))
        logger.info("adjust_rate_count:" + str(adjust_rate_count))
        logger.info('adjust_time:' + str(adjust_time))

        # Ten plateau epochs in a row trigger an lr decay plus a rollback to
        # the best recorded state.
        if adjust_rate_count == 10:
            adjust_rate_flag = True
            adjust_time += 1
            adjust_rate_count = 0
            acc_best = acc_best_true
            model.load_state_dict(model_state)
            optimizer.load_state_dict(op_state)

        # Give up entirely after 8 learning-rate adjustments.
        if adjust_time == 8:
            stop_train = True

        time_used = (time.time() - start_time) / 60
        print("epoch %d done, cv acc is: %.4f, time_used: %.4f minutes" %
              (count, acc, time_used))
        logger.info("epoch %d done, cv acc is: %.4f, time_used: %.4f minutes" %
                    (count, acc, time_used))
        x_axis = range(count)
        y_axis = [
            loss_results[0:count], training_cer_results[0:count],
            dev_cer_results[0:count]
        ]
        # Create each visdom window lazily on first use, then replace its data.
        for x in range(len(viz_window)):
            if viz_window[x] is None:
                viz_window[x] = viz.line(
                    X=np.array(x_axis),
                    Y=np.array(y_axis[x]),
                    opts=opts[x],
                )
            else:
                viz.line(
                    X=np.array(x_axis),
                    Y=np.array(y_axis[x]),
                    win=viz_window[x],
                    update='replace',
                )

    print("End training, best cv acc is: %.4f" % acc_best)
    logger.info("End training, best cv acc is: %.4f" % acc_best)
    best_path = os.path.join(args.log_dir,
                             'best_model' + '_cv' + str(acc_best) + '.pkl')
    cf.set('Model', 'model_file', best_path)
    # Fix: close the config file instead of leaking the handle.
    with open(args.conf, 'w') as conf_file:
        cf.write(conf_file)
    params['epoch'] = count
    torch.save(
        CNN_LSTM_CTC.save_package(model,
                                  optimizer=optimizer,
                                  epoch=params,
                                  loss_results=loss_results,
                                  training_cer_results=training_cer_results,
                                  dev_cer_results=dev_cer_results), best_path)
Ejemplo n.º 5
0
import torch
from torch.autograd import Variable
from warpctc_pytorch import CTCLoss

# Minimal warp-ctc demo: a single sample with two time steps over a
# 5-symbol alphabet (index 0 is the CTC blank).
criterion = CTCLoss()

# Activations must arrive as (seqLength, batchSize, alphabet_size).
acts = torch.FloatTensor(
    [[[0.1, 0.6, 0.1, 0.1, 0.1],
      [0.1, 0.1, 0.6, 0.1, 0.1]]]
).transpose(0, 1).contiguous()

targets = Variable(torch.IntTensor([1, 2]))      # flattened label sequence
target_lengths = Variable(torch.IntTensor([2]))  # labels per sample
act_lengths = Variable(torch.IntTensor([2]))     # time steps per sample

# Wrap the activations so autograd computes gradients w.r.t. them.
acts = Variable(acts, requires_grad=True)

cost = criterion(acts, targets, act_lengths, target_lengths)
cost.backward()
print(cost)
Ejemplo n.º 6
0
def main(opts):
  # alphabet = '0123456789.'
  nclass = len(alphabet) + 1
  model_name = 'E2E-CRNN'
  net = OwnModel(attention=True, nclass=nclass)
  print("Using {0}".format(model_name))

  if opts.cuda:
    net.cuda()
  learning_rate = opts.base_lr
  optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)
  optimizer = optim.Adam(net.parameters(), lr=opts.base_lr, betas=(0.5, 0.999))
  step_start = 0

  ### 第一种:只修改conv11的维度 
  # model_dict = net.state_dict()
  # if os.path.exists(opts.model):
  #     print('loading pretrained model from %s' % opts.model)
  #     pretrained_model = OwnModel(attention=True, nclass=12)
  #     pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
  #     pretrained_dict = pretrained_model.state_dict()
  #
  #     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'rnn' not in k and 'conv11' not in k}
  #     model_dict.update(pretrained_dict)
  #     net.load_state_dict(model_dict)

  if os.path.exists(opts.model):
    print('loading model from %s' % args.model)
    step_start, learning_rate = net_utils.load_net(args.model, net, optimizer)

  ## 数据集
  e2edata = E2Edataset(train_list=opts.train_list)
  e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=opts.batch_size, shuffle=True, collate_fn=E2Ecollate, num_workers=4)
  
  # 电表数据集
  # converter = strLabelConverter(alphabet)
  # dataset = ImgDataset(
  #     root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/image',
  #     csv_root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/train_list.txt',
  #     transform=None,
  #     target_transform=converter.encode
  # )
  # ocrdataloader = torch.utils.data.DataLoader(
  #     dataset, batch_size=opts.batch_size, shuffle=True, collate_fn=own_collate
  # )
  
  net.train()

  converter = strLabelConverter(alphabet)
  ctc_loss = CTCLoss()

  for step in range(step_start, opts.max_iters):

    for index, date in enumerate(e2edataloader):
      im_data, gtso, lbso = date
      im_data = im_data.cuda()
       
      try:
    loss= process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training=True)

    net.zero_grad()
    # optimizer.zero_grad()
    loss.backward()
    optimizer.step()
      except:
    import sys, traceback
    traceback.print_exc(file=sys.stdout)
    pass


      if index % disp_interval == 0:
Ejemplo n.º 7
0
def main():
    """Train a DeepSpeech model with CTC loss.

    Builds train/validation spectrogram datasets, constructs the model and
    SGD optimizer, optionally resumes from a checkpoint, then loops over
    epochs: trains with gradient clipping, evaluates WER/CER on the
    validation set, optionally plots metrics to visdom, checkpoints, and
    anneals the learning rate after every epoch.

    Relies on module-level names: ``parser``, ``supported_rnns``,
    ``SpectrogramDataset``, ``AudioDataLoader``, ``DeepSpeech``,
    ``ArgMaxDecoder``, ``AverageMeter``.  Uses the legacy (pre-0.4)
    ``Variable``/``loss.data[0]`` PyTorch API.
    """
    args = parser.parse_args()
    save_folder = args.save_folder

    loss_results, cer_results, wer_results = None, None, None
    if args.visdom:
        # Lazy import so visdom is only required when plotting is enabled.
        from visdom import Visdom
        viz = Visdom()

        opts = [
            dict(title='Loss', ylabel='Loss', xlabel='Epoch'),
            dict(title='WER', ylabel='WER', xlabel='Epoch'),
            dict(title='CER', ylabel='CER', xlabel='Epoch')
        ]

        viz_windows = [None, None, None]
        # Pre-allocate one slot per epoch for each tracked metric.
        loss_results, cer_results, wer_results = torch.Tensor(
            args.epochs), torch.Tensor(args.epochs), torch.Tensor(args.epochs)
        epochs = torch.arange(1, args.epochs + 1)

    try:
        os.makedirs(save_folder)
    except OSError as e:
        # Only tolerate "directory already exists"; re-raise anything else.
        if e.errno == errno.EEXIST:
            print('Directory already exists.')
        else:
            raise
    criterion = CTCLoss()

    # The label alphabet is stored as a JSON list of characters.
    with open(args.labels_path) as label_file:
        labels = str(''.join(json.load(label_file)))

    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride,
                      window=args.window,
                      noise_dir=args.noise_dir,
                      noise_prob=args.noise_prob,
                      noise_levels=(args.noise_min, args.noise_max))

    # Augmentation (noise injection etc.) only on the training split.
    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=args.train_manifest,
                                       labels=labels,
                                       normalize=True,
                                       augment=args.augment)
    test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                      manifest_filepath=args.val_manifest,
                                      labels=labels,
                                      normalize=True,
                                      augment=False)
    train_loader = AudioDataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers)
    test_loader = AudioDataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    rnn_type = args.rnn_type.lower()
    assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
    model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                       nb_layers=args.hidden_layers,
                       labels=labels,
                       rnn_type=supported_rnns[rnn_type],
                       audio_conf=audio_conf,
                       bidirectional=True)
    parameters = model.parameters()
    optimizer = torch.optim.SGD(parameters,
                                lr=args.lr,
                                momentum=args.momentum,
                                nesterov=True)
    decoder = ArgMaxDecoder(labels)
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()
    if args.continue_from:
        # Resume model/optimizer state and training position from checkpoint.
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from)
        model.load_state_dict(package['state_dict'])
        optimizer.load_state_dict(package['optim_dict'])
        start_epoch = int(package.get('epoch', None)
                          or 1) - 1  # Python index start at 0 for training
        start_iter = package.get('iteration', None)
        if start_iter is None:
            start_epoch += 1  # Assume that we saved a model after an epoch finished, so start at the next epoch.
            start_iter = 0
        else:
            start_iter += 1
        # NOTE(review): int() truncates the checkpointed running loss to an
        # integer — confirm this is intentional and not a precision bug.
        avg_loss = int(package.get('avg_loss', 0))
        if args.visdom and \
                        package['loss_results'] is not None and start_epoch > 0:  # Add previous scores to visdom graph
            epoch = start_epoch
            loss_results, cer_results, wer_results = package[
                'loss_results'], package['cer_results'], package['wer_results']
            x_axis = epochs[0:epoch]
            y_axis = [
                loss_results[0:epoch], wer_results[0:epoch],
                cer_results[0:epoch]
            ]
            for x in range(len(viz_windows)):
                viz_windows[x] = viz.line(
                    X=x_axis,
                    Y=y_axis[x],
                    opts=opts[x],
                )
    else:
        avg_loss = 0
        start_epoch = 0
        start_iter = 0

    print(model)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    for epoch in range(start_epoch, args.epochs):
        model.train()
        end = time.time()
        for i, (data) in enumerate(train_loader, start=start_iter):
            if i == len(train_loader):
                break
            inputs, targets, input_percentages, target_sizes = data
            # measure data loading time
            data_time.update(time.time() - end)
            inputs = Variable(inputs)
            target_sizes = Variable(target_sizes)
            targets = Variable(targets)

            if args.cuda:
                inputs = inputs.cuda()

            out = model(inputs)
            out = out.transpose(0, 1)  # TxNxH

            seq_length = out.size(0)
            # Convert each sample's relative length into a time-step count
            # for CTC (input_percentages is mutated in place here).
            sizes = Variable(input_percentages.mul_(int(seq_length)).int())

            loss = criterion(out, targets, sizes, target_sizes)
            loss = loss / inputs.size(0)  # average the loss by minibatch

            # Guard against inf losses (e.g. target longer than input).
            loss_sum = loss.data.sum()
            inf = float("inf")
            if loss_sum == inf or loss_sum == -inf:
                print("WARNING: received an inf loss, setting loss value to 0")
                loss_value = 0
            else:
                loss_value = loss.data[0]  # legacy (pre-0.4) scalar access

            avg_loss += loss_value
            losses.update(loss_value, inputs.size(0))

            # compute gradient
            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm(model.parameters(), args.max_norm)
            # SGD step
            optimizer.step()

            if args.cuda:
                torch.cuda.synchronize()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if not args.silent:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          (epoch + 1), (i + 1),
                          len(train_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses))
            # Optional mid-epoch checkpointing every N batches.
            if args.checkpoint_per_batch > 0 and i > 0 and (
                    i + 1) % args.checkpoint_per_batch == 0:
                file_path = '%s/deepspeech_checkpoint_epoch_%d_iter_%d.pth.tar' % (
                    save_folder, epoch + 1, i + 1)
                print("Saving checkpoint model to %s" % file_path)
                torch.save(
                    DeepSpeech.serialize(model,
                                         optimizer=optimizer,
                                         epoch=epoch,
                                         iteration=i,
                                         loss_results=loss_results,
                                         wer_results=wer_results,
                                         cer_results=cer_results,
                                         avg_loss=avg_loss), file_path)
        avg_loss /= len(train_loader)

        print('Training Summary Epoch: [{0}]\t'
              'Average Loss {loss:.3f}\t'.format(epoch + 1, loss=avg_loss))

        start_iter = 0  # Reset start iteration for next epoch
        total_cer, total_wer = 0, 0
        model.eval()
        for i, (data) in enumerate(test_loader):  # test
            inputs, targets, input_percentages, target_sizes = data

            inputs = Variable(inputs)

            # unflatten targets
            split_targets = []
            offset = 0
            for size in target_sizes:
                split_targets.append(targets[offset:offset + size])
                offset += size

            if args.cuda:
                inputs = inputs.cuda()

            out = model(inputs)
            out = out.transpose(0, 1)  # TxNxH
            seq_length = out.size(0)
            sizes = Variable(input_percentages.mul_(int(seq_length)).int())

            decoded_output = decoder.decode(out.data, sizes)
            target_strings = decoder.process_strings(
                decoder.convert_to_strings(split_targets))
            wer, cer = 0, 0
            # Accumulate per-utterance error rates, each normalized by the
            # reference length (words for WER, characters for CER).
            for x in range(len(target_strings)):
                wer += decoder.wer(decoded_output[x],
                                   target_strings[x]) / float(
                                       len(target_strings[x].split()))
                cer += decoder.cer(decoded_output[x],
                                   target_strings[x]) / float(
                                       len(target_strings[x]))
            total_cer += cer
            total_wer += wer

            if args.cuda:
                torch.cuda.synchronize()

        wer = total_wer / len(test_loader.dataset)
        cer = total_cer / len(test_loader.dataset)
        wer *= 100
        cer *= 100

        print('Validation Summary Epoch: [{0}]\t'
              'Average WER {wer:.3f}\t'
              'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer, cer=cer))

        if args.visdom:
            # Record this epoch's metrics and redraw all three plots.
            loss_results[epoch] = avg_loss
            wer_results[epoch] = wer
            cer_results[epoch] = cer
            epoch += 1
            x_axis = epochs[0:epoch]
            y_axis = [
                loss_results[0:epoch], wer_results[0:epoch],
                cer_results[0:epoch]
            ]
            for x in range(len(viz_windows)):
                if viz_windows[x] is None:
                    viz_windows[x] = viz.line(
                        X=x_axis,
                        Y=y_axis[x],
                        opts=opts[x],
                    )
                else:
                    viz.line(
                        X=x_axis,
                        Y=y_axis[x],
                        win=viz_windows[x],
                        update='replace',
                    )
        if args.checkpoint:
            file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch + 1)
            torch.save(
                DeepSpeech.serialize(model,
                                     optimizer=optimizer,
                                     epoch=epoch,
                                     loss_results=loss_results,
                                     wer_results=wer_results,
                                     cer_results=cer_results), file_path)
        # anneal lr
        optim_state = optimizer.state_dict()
        optim_state['param_groups'][0][
            'lr'] = optim_state['param_groups'][0]['lr'] / args.learning_anneal
        optimizer.load_state_dict(optim_state)
        print('Learning rate annealed to: {lr:.6f}'.format(
            lr=optim_state['param_groups'][0]['lr']))

        avg_loss = 0
    torch.save(DeepSpeech.serialize(model, optimizer=optimizer),
               args.final_model_path)
Ejemplo n.º 8
0
    def forward(self, model, train_x, train_labels_flatten, train_img_seq_lens,
                train_label_lens, batch_size, test_x, test_seq_len, test_mask):
        """Compute VAT-style local distributional smoothness (LDS) losses.

        For the (labeled) train batch: estimates an adversarial perturbation
        direction via ``self.ip`` power-iteration steps on the CTC loss, then
        returns the CTC loss at the sign-perturbed input.  For the (unlabeled)
        test batch: does the same with a KL divergence between predictions on
        clean and perturbed inputs.

        Returns (train_lds, test_lds).

        NOTE(review): ``self.xi`` appears to scale the probe perturbation and
        ``self.eps`` the final adversarial step; ``self.ip`` is the number of
        power iterations — confirm against where these are set in __init__.
        Batch-norm running stats are frozen for the whole computation via
        ``_disable_tracking_bn_stats``.
        """
        with _disable_tracking_bn_stats(model):
            # TRAIN
            # calc adversarial direction
            # prepare random unit tensor
            train_d = torch.rand(train_x.shape).to(
                torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
            train_d = _l2_normalize(train_d)
            for _ in range(self.ip):
                train_d.requires_grad_()
                train_loss_function = CTCLoss()
                # Probe the loss at a small perturbation along train_d.
                train_preds = model.forward(train_x + self.xi * train_d,
                                            train_img_seq_lens)
                train_adv_loss_ctc = train_loss_function(
                    train_preds, train_labels_flatten,
                    Variable(torch.IntTensor(np.array(train_img_seq_lens))),
                    train_label_lens) / batch_size

                # Gradient w.r.t. the perturbation becomes the new direction.
                train_adv_loss_ctc.backward()
                train_d = train_d.grad
                model.zero_grad()

            #TEST
            # Clean predictions serve as the fixed reference distribution.
            with torch.no_grad():
                test_pred = model.vat_forward(test_x, test_seq_len)
                test_pred = test_pred * test_mask
                test_pred = F.softmax(test_pred,
                                      dim=2).view(-1,
                                                  test_pred.size()[-1])
                # prepare random unit tensor
            test_d = torch.rand(test_x.shape).to(
                torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
            test_d = _l2_normalize(test_d)

            with _disable_tracking_bn_stats(model):
                # calc adversarial direction
                for _ in range(self.ip):
                    test_d.requires_grad_()

                    test_pred_hat = model.vat_forward(
                        test_x + self.xi * test_d, test_seq_len)
                    test_pred_hat = test_pred_hat * test_mask
                    test_pred_hat = F.log_softmax(test_pred_hat, dim=2).view(
                        -1,
                        test_pred_hat.size()[-1])

                    # pred_hat = model(x + self.xi * d)
                    # adv_distance = _kl_div(F.log_softmax(pred_hat, dim=1), pred)
                    test_adv_distance = _kl_div(test_pred_hat, test_pred)
                    test_adv_distance.backward()
                    test_d = _l2_normalize(test_d.grad)
                    model.zero_grad()

            #TRAIN
            # calc LDS
            # Full-strength adversarial step: sign of the last gradient.
            train_r_adv = torch.sign(train_d) * self.eps

            train_pred_hat = model.forward(train_x + train_r_adv,
                                           train_img_seq_lens)
            train_lds = train_loss_function(
                train_pred_hat, train_labels_flatten,
                Variable(torch.IntTensor(np.array(train_img_seq_lens))),
                train_label_lens) / batch_size

            #TEST
            # calc LDS
            test_d = torch.sign(test_d)
            test_r_adv = test_d * self.eps

            test_pred_hat = model.vat_forward(test_x + test_r_adv,
                                              test_seq_len)
            test_pred_hat = test_pred_hat * test_mask
            test_pred_hat = F.log_softmax(test_pred_hat,
                                          dim=2).view(-1,
                                                      test_pred_hat.size()[-1])

            #pred_hat = model(x + r_adv)
            #lds = _kl_div(F.log_softmax(pred_hat, dim=1), pred)
            test_lds = _kl_div(test_pred_hat, test_pred)

        return train_lds, test_lds
Ejemplo n.º 9
0
                                           collate_fn=dataset.alignCollate(
                                               imgH=opt.imgH,
                                               imgW=opt.imgW,
                                               keep_ratio=opt.keep_ratio))
print(len(train_loader))
#test_dataset = dataset.listDataset(list_file =opt.vallist, transform=dataset.resizeNormalize((100, 32)))
test_dataset = dataset.listDataset(list_file=opt.vallist)

# Decode the alphabet bytes to text.
# NOTE(review): bytes.decode here implies str1 is a bytes object (or this is
# Python 2 code) — confirm which before porting.
alphabet = str1.decode('utf-8')
nclass = len(alphabet)
nc = 3  # 3 input channels (RGB images)

# Dual-head setup: one converter/criterion pair for the attention branch
# (cross-entropy) and one for the CTC branch.
converterAttention = utils.strLabelConverterForAttention(alphabet, opt.sep)
converterCTC = utils.strLabelConverterForCTC(alphabet, opt.sep)
criterionAttention = torch.nn.CrossEntropyLoss()
criterionCTC = CTCLoss()


# custom weights initialization called on crnn
def weights_init(m):
    """DCGAN-style initializer, applied per-module via ``net.apply``.

    Conv layers get weights ~ N(0, 0.02); batch-norm layers get weights
    ~ N(1, 0.02) and a zeroed bias.  Other modules are left untouched.
    """
    layer = m.__class__.__name__
    if 'Conv' in layer:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in layer:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


#crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
# NOTE(review): this rebinds the module name `crnn` to the model instance;
# the extra positional args (24, 1, True) presumably configure the extended
# CRNN variant — confirm against the CRNN constructor's signature.
crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh, 24, 1, True)
#crnn.apply(weights_init)
Ejemplo n.º 10
0
def main(args=None):
	"""Train a RetinaNet detector (with an auxiliary CTC recognition loss).

	Parses CLI args, builds a COCO or CSV dataset with its loaders, creates
	a ResNet-backboned RetinaNet, then trains with Adam + plateau LR
	scheduling, saving a checkpoint per epoch and a final model.

	Relies on module-level names: ``CocoDataset``, ``CSVDataset``,
	``AspectRatioBasedSampler``, ``collater``, ``model``, ``coco_eval``,
	``CTCLoss`` and the transform classes.
	"""

	parser     = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.')

	parser.add_argument('--dataset', help='Dataset type, must be one of csv or coco.')
	parser.add_argument('--coco_path', help='Path to COCO directory')
	parser.add_argument('--csv_train', help='Path to file containing training annotations (see readme)')
	parser.add_argument('--csv_classes', help='Path to file containing class list (see readme)')
	parser.add_argument('--csv_val', help='Path to file containing validation annotations (optional, see readme)')

	parser.add_argument('--depth', help='Resnet depth, must be one of 18, 34, 50, 101, 152', type=int, default=18)
	parser.add_argument('--epochs', help='Number of epochs', type=int, default=100)

	# NOTE: `parser` is rebound to the parsed namespace from here on.
	parser = parser.parse_args(args)

	# Create the data loaders
	if parser.dataset == 'coco':

		if parser.coco_path is None:
			raise ValueError('Must provide --coco_path when training on COCO,')

		dataset_train = CocoDataset(parser.coco_path, set_name='train2017', transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]))
		dataset_val = CocoDataset(parser.coco_path, set_name='val2017', transform=transforms.Compose([Normalizer(), Resizer()]))

	elif parser.dataset == 'csv':

		if parser.csv_train is None:
			raise ValueError('Must provide --csv_train when training on COCO,')

		if parser.csv_classes is None:
			raise ValueError('Must provide --csv_classes when training on COCO,')


		dataset_train = CSVDataset(train_file=parser.csv_train, class_list=parser.csv_classes, transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]))

		if parser.csv_val is None:
			dataset_val = None
			print('No validation annotations provided.')
		else:
			dataset_val = CSVDataset(train_file=parser.csv_val, class_list=parser.csv_classes, transform=transforms.Compose([Normalizer(), Resizer()]))

	else:
		raise ValueError('Dataset type not understood (must be csv or coco), exiting.')

	# Batch by aspect ratio to minimize padding within a batch.
	sampler = AspectRatioBasedSampler(dataset_train, batch_size=1,drop_last=False)
	dataloader_train = DataLoader(dataset_train, num_workers=3, collate_fn=collater, batch_sampler=sampler)

	if dataset_val is not None:
		sampler_val = AspectRatioBasedSampler(dataset_val, batch_size=1, drop_last=False)
		dataloader_val = DataLoader(dataset_val, num_workers=3, collate_fn=collater, batch_sampler=sampler_val)

	# Create the model
	if parser.depth == 18:
		retinanet = model.resnet18(num_classes=dataset_train.num_classes(), pretrained=True)
	elif parser.depth == 34:
		retinanet = model.resnet34(num_classes=dataset_train.num_classes(), pretrained=True)
	elif parser.depth == 50:
		retinanet = model.resnet50(num_classes=dataset_train.num_classes(), pretrained=True)
	elif parser.depth == 101:
		retinanet = model.resnet101(num_classes=dataset_train.num_classes(), pretrained=True)
	elif parser.depth == 152:
		retinanet = model.resnet152(num_classes=dataset_train.num_classes(), pretrained=True)
	else:
		raise ValueError('Unsupported model depth, must be one of 18, 34, 50, 101, 152')		

	use_gpu = True

	if use_gpu:
		retinanet = retinanet.cuda()
	
	retinanet = torch.nn.DataParallel(retinanet).cuda()
	
	#retinanet = torch.load('../Documents/TRAINED_MODELS/pytorch-retinanet/esposallescsv_retinanet_99.pt')
	#print "LOADED pretrained MODEL\n\n"
	
	retinanet.training = True

	optimizer = optim.Adam(retinanet.parameters(), lr=1e-4)

	# Halve-style LR drop when the epoch loss plateaus for 3 epochs.
	scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

	loss_hist = collections.deque(maxlen=500)
	ctc = CTCLoss()
	retinanet.train()
	# Keep batch-norm layers frozen during training.
	retinanet.module.freeze_bn()

	print('Num training images: {}'.format(len(dataset_train)))
	for epoch_num in range(parser.epochs):

		retinanet.train()
		retinanet.module.freeze_bn()
		
		epoch_loss = []
		
		for iter_num, data in enumerate(dataloader_train):
			try:
				optimizer.zero_grad()

				# The model computes all three losses internally; the CTC
				# criterion is passed in for the recognition head.
				(classification_loss, regression_loss,ctc_loss) = retinanet([data['img'].cuda().float(), data['annot'],ctc])

				classification_loss = classification_loss.mean()
				regression_loss = regression_loss.mean()	
				'''if regression_loss<0.5:		
					loss = classification_loss + regression_loss/2.+ctc_loss/4.
				else:
					loss = classification_loss +regression_loss'''
				loss = classification_loss + regression_loss+ctc_loss
				if bool(loss == 0):
					continue

				loss.backward()

				torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)

				optimizer.step()

				loss_hist.append(float(loss))

				epoch_loss.append(float(loss))

				print('Epoch: {} | Iteration: {} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | CTC loss: {:1.5f} | Running loss: {:1.5f}'.format(epoch_num, iter_num, float(classification_loss), float(regression_loss),float(ctc_loss),np.mean(loss_hist)))
				
				del classification_loss
				del regression_loss
			# NOTE(review): broad except keeps training alive past bad
			# batches but can also hide real bugs — consider narrowing.
			except Exception as e:
				print(e)
				continue

		if parser.dataset == 'coco':

			print('Evaluating dataset')

			coco_eval.evaluate_coco(dataset_val, retinanet)

		'''elif parser.dataset == 'csv' and parser.csv_val is not None:

			print('Evaluating dataset')

			mAP = csv_eval.evaluate(dataset_val, retinanet)'''

		
		scheduler.step(np.mean(epoch_loss))	
		
		dataset_name = parser.csv_train.split("/")[-2]
		
		torch.save(retinanet.module, dataset_name+'{}_retinanet_{}.pt'.format(parser.dataset, epoch_num))

	retinanet.eval()

	# NOTE(review): .format() on a string with no placeholders is a no-op;
	# the file is always named 'model_final.pt'.
	torch.save(retinanet, 'model_final.pt'.format(epoch_num))
Ejemplo n.º 11
0
# custom weights initialization called on crnn
def weights_init(m):
    """Per-module initializer for ``net.apply``: conv weights ~ N(0, 0.02);
    batch-norm weights ~ N(1, 0.02) with the bias zeroed."""
    name = type(m).__name__
    is_conv = name.find('Conv') != -1
    is_batchnorm = name.find('BatchNorm') != -1
    if is_conv:
        m.weight.data.normal_(0.0, 0.02)
    elif is_batchnorm:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
# Build the GRCNN recognizer and apply the DCGAN-style init to every module.
crnn = GRCNN.GRCNN(args.n_class)
crnn.apply(weights_init)



crnn = crnn.to(device)
criterion = CTCLoss().to(device)

#net.cuda()
print('net has load!')
converter = utils.strLabelConverter(args.alphabet)

optimizer=optim.Adam(crnn.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

# Training bookkeeping: best validation accuracy so far and per-epoch curves.
best_acc=-1
totalLoss=[]
avg_test_acc = []
avg_train_acc = []

def get_img(img_path):
    img = cv2.imread(img_path)
    img= cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
Ejemplo n.º 12
0
def main(opt):
    """Train a CRNN on the handwriting dataset with CTC loss.

    Seeds RNGs, builds train/test datasets and loaders, initializes (or
    loads) the CRNN, then runs the epoch loop with periodic display,
    validation and checkpointing.  Uses the legacy ``Variable`` API and the
    Python-2-style ``iterator.next()`` call inside ``val``.

    Relies on module-level names: ``dataset``, ``crnn_model``, ``utils``,
    ``train_batch``, ``optim``, ``cudnn``.
    """
    print(opt)

    if opt.experiment is None:
        opt.experiment = 'expr'

    # NOTE(review): shelling out for mkdir; os.makedirs(..., exist_ok=True)
    # would be the portable equivalent.
    os.system('mkdir {0}'.format(opt.experiment))

    # Why is this?
    opt.manualSeed = random.randint(1, 10000)  # fix seed

    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    cudnn.benchmark = True

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    train_dataset = dataset.hwrDataset(mode="train")
    assert train_dataset
    # if not opt.random_sample:
    #     sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
    # else:
    #     sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batchSize,
                                               shuffle=True,
                                               num_workers=int(opt.workers),
                                               collate_fn=dataset.alignCollate(
                                                   imgH=opt.imgH,
                                                   imgW=opt.imgW,
                                                   keep_ratio=True))
    # test_dataset = dataset.lmdbDataset(
    #     root=opt.valroot, transform=dataset.resizeNormalize((100, 32)))

    test_dataset = dataset.hwrDataset(mode="test",
                                      transform=dataset.resizeNormalize(
                                          (100, 32)))

    nclass = len(opt.alphabet) + 1  # +1 for the CTC blank symbol
    nc = 1  # single-channel (grayscale) input

    criterion = CTCLoss()

    # custom weights initialization called on crnn
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    crnn = crnn_model.CRNN(opt.imgH, nc, nclass, opt.nh)
    crnn.apply(weights_init)
    if opt.crnn != '':
        print('loading pretrained model from %s' % opt.crnn)
        crnn.load_state_dict(torch.load(opt.crnn))
    print(crnn)

    # TODO make this central

    # Pre-allocated buffers refilled in place for every batch.
    image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH)
    text = torch.IntTensor(opt.batchSize * 5)
    length = torch.IntTensor(opt.batchSize)

    if opt.cuda:
        crnn.cuda()
        crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
        image = image.cuda()
        criterion = criterion.cuda()

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    # TODO what is this, read this.
    # loss averager
    loss_avg = utils.averager()

    # Todo default is RMS Prop. I wonder why?
    # setup optimizer

    #Following the paper's recommendation

    # NOTE(review): forcing adadelta here means the Adam branch below is
    # only reachable when opt.adam is set — confirm intended precedence.
    opt.adadelta = True
    if opt.adam:
        optimizer = optim.Adam(crnn.parameters(),
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    elif opt.adadelta:
        optimizer = optim.Adadelta(crnn.parameters(), lr=opt.lr)
    else:
        optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)

    converter = utils.strLabelConverter(opt.alphabet)

    def val(net, dataset, criterion, max_iter=100):
        # Evaluate up to max_iter batches: accumulate loss and count
        # case-insensitive exact-match predictions.
        print('Start val')

        for p in crnn.parameters():
            p.requires_grad = False

        net.eval()
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  shuffle=True,
                                                  batch_size=opt.batchSize,
                                                  num_workers=int(opt.workers))
        val_iter = iter(data_loader)

        n_correct = 0
        loss_avg = utils.averager()

        max_iter = min(max_iter, len(data_loader))
        for i in range(max_iter):
            print("Is 'i' jumping two values? i == " + str(i))
            # NOTE(review): .next() is the Python-2 iterator API; use
            # next(val_iter) on Python 3.
            data = val_iter.next()
            i += 1
            cpu_images, cpu_texts = data
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)

            preds = crnn(image)
            preds_size = Variable(torch.IntTensor([preds.size(0)] *
                                                  batch_size))
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg.add(cost)

            # Greedy decode: argmax over the alphabet at each time step.
            _, preds = preds.max(
                2
            )  # todo where is the output size set to 26? Empirically it is.
            # preds = preds.squeeze(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data,
                                         preds_size.data,
                                         raw=False)  # Todo read this.
            for pred, target in zip(sim_preds, cpu_texts):
                if pred == target.lower():
                    n_correct += 1

        # Show a few raw/decoded/ground-truth triples from the last batch.
        raw_preds = converter.decode(preds.data, preds_size.data,
                                     raw=True)[:opt.n_test_disp]
        for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
            print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

        accuracy = n_correct / float(max_iter * opt.batchSize)
        print('Test loss: %f, accuray: %f' % (loss_avg.val(), accuracy))

    for epoch in range(opt.niter):
        train_iter = iter(train_loader)
        i = 0
        while i < len(train_loader):
            # Re-enable gradients that val() turned off.
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()

            cost = train_batch(crnn, criterion, optimizer, train_iter, opt,
                               converter)
            loss_avg.add(cost)
            i += 1

            if i % opt.displayInterval == 0:
                print('[%d/%d][%d/%d] Loss: %f' %
                      (epoch, opt.niter, i, len(train_loader), loss_avg.val()))
                loss_avg.reset()

            if i % opt.valInterval == 0:
                try:
                    val(crnn, test_dataset, criterion)
                except Exception as e:
                    print(e)

            # do checkpointing
            if i % opt.saveInterval == 0:
                torch.save(
                    crnn.state_dict(),
                    '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
def train(opt):
    """Train a dual-path text recognizer: a CTC "major" path guided by an
    attention "guide" path, each with its own AdamW optimizer and warmup
    scheduler.  Runs until ``opt.num_iter`` iterations, then exits the process.

    NOTE(review): relies on module-level globals defined elsewhere in this
    file — ``device``, ``writer`` (a SummaryWriter), ``fol_ckpt`` (checkpoint
    folder), plus the dataset/model/converter/loss classes.  Confirm they are
    in scope before reusing this function.

    Args:
        opt: argparse-style namespace with all training options
            (data paths, model hyper-parameters, optimizer flags, ...).
    """
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=
        True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '	')
    log.close()
    """ model configuration """
    # Both converters are always built: CTC for the major path, attention
    # for the guide path.
    if opt.baiduCTC:
        CTC_converter = CTCLabelConverterForBaiduWarpctc(opt.character)
    else:
        CTC_converter = CTCLabelConverter(opt.character)
    Attn_converter = AttnLabelConverter(opt.character)
    opt.num_class_ctc = len(CTC_converter.character)
    opt.num_class_attn = len(Attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class_ctc, opt.num_class_attn, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization: kaiming for weights, zeros for biases;
    # batchnorm parameters fall through to the except branch.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    print("Model:")
    print(model)
    """ setup loss """
    if opt.baiduCTC:
        # need to install warpctc. see our guideline.
        if opt.label_smooth:
            criterion_major_path = SmoothCTCLoss(num_classes=opt.num_class_ctc,
                                                 weight=0.05)
        else:
            criterion_major_path = CTCLoss()
    else:
        criterion_major_path = torch.nn.CTCLoss(zero_infinity=True).to(device)
    # Guide path is trained with cross-entropy; index 0 is the [GO] token.
    criterion_guide_path = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
    loss_avg_major_path = Averager()
    loss_avg_guide_path = Averager()

    # Split trainable parameters between the two paths so each optimizer
    # only updates its own sub-network.
    guide_parameters = []
    major_parameters = []
    guide_model_part_names = [
        "Transformation", "FeatureExtraction", "SequenceModeling_Attn",
        "Attention"
    ]
    major_model_part_names = ["SequenceModeling_CTC", "CTC"]
    for name, param in model.named_parameters():
        if param.requires_grad:
            if name.split(".")[1] in guide_model_part_names:
                guide_parameters.append(param)
            elif name.split(".")[1] in major_model_part_names:
                major_parameters.append(param)
    if opt.continue_training:
        # The guide path is frozen when continuing training.
        guide_parameters = []

    # setup optimizer
    if opt.adam:
        # BUGFIX: ``filtered_parameters`` was referenced but never defined,
        # so the --adam path always crashed with a NameError.  Build it as
        # the set of trainable parameters, as upstream does.
        filtered_parameters = [
            p for p in model.parameters() if p.requires_grad
        ]
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
        # NOTE(review): the remainder of this function drives
        # ``optimizer_ctc``/``optimizer_attn`` and their schedulers, so the
        # --adam path is still incompatible downstream — confirm before use.
    else:
        optimizer_ctc = AdamW(major_parameters, lr=opt.lr)
        if not opt.continue_training:
            optimizer_attn = AdamW(guide_parameters, lr=opt.lr)
    scheduler_ctc = get_linear_schedule_with_warmup(
        optimizer_ctc, num_warmup_steps=10000, num_training_steps=opt.num_iter)
    # BUGFIX: ``optimizer_attn`` only exists when not continuing training;
    # creating/stepping its scheduler unconditionally raised a NameError
    # whenever opt.continue_training was set.
    if not opt.continue_training:
        scheduler_attn = get_linear_schedule_with_warmup(
            optimizer_attn,
            num_warmup_steps=10000,
            num_training_steps=opt.num_iter)
    start_iter = 0
    if opt.saved_model != '' and (not opt.continue_training):
        print(f'loading pretrained model from {opt.saved_model}')
        checkpoint = torch.load(opt.saved_model)
        start_iter = checkpoint['start_iter'] + 1
        if not opt.adam:
            optimizer_ctc.load_state_dict(
                checkpoint['optimizer_ctc_state_dict'])
            optimizer_attn.load_state_dict(
                checkpoint['optimizer_attn_state_dict'])
            scheduler_ctc.load_state_dict(
                checkpoint['scheduler_ctc_state_dict'])
            scheduler_attn.load_state_dict(
                checkpoint['scheduler_attn_state_dict'])
            print(scheduler_ctc.get_lr())
            print(scheduler_attn.get_lr())
        if opt.FT:
            model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        else:
            model.load_state_dict(checkpoint['model_state_dict'])
    if opt.continue_training:
        model.load_state_dict(torch.load(opt.saved_model))

    # Re-create the schedulers so their step counter agrees with the resumed
    # iteration (``last_epoch``); this intentionally supersedes the
    # load_state_dict above.
    scheduler_ctc = get_linear_schedule_with_warmup(
        optimizer_ctc,
        num_warmup_steps=10000,
        num_training_steps=opt.num_iter,
        last_epoch=start_iter - 1)
    if not opt.continue_training:
        scheduler_attn = get_linear_schedule_with_warmup(
            optimizer_attn,
            num_warmup_steps=10000,
            num_training_steps=opt.num_iter,
            last_epoch=start_iter - 1)
    """ final options """
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------	'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}	'
        opt_log += '---------------------------------------	'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter - 1
    if opt.continue_training:
        start_iter = 0
    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        iteration += 1
        if iteration < start_iter:
            continue
        image = image_tensors.to(device)
        text_attn, length_attn = Attn_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        text_ctc, length_ctc = CTC_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)
        preds_major, preds_guide = model(image, text_attn[:, :-1])
        preds_size = torch.IntTensor([preds_major.size(1)] * batch_size)
        if opt.baiduCTC:
            preds_major = preds_major.permute(1, 0, 2)  # to use CTCLoss format
            if opt.label_smooth:
                cost_ctc = criterion_major_path(preds_major, text_ctc,
                                                preds_size, length_ctc,
                                                batch_size)
            else:
                cost_ctc = criterion_major_path(
                    preds_major, text_ctc, preds_size, length_ctc) / batch_size
        else:
            preds_major = preds_major.log_softmax(2).permute(1, 0, 2)
            cost_ctc = criterion_major_path(preds_major, text_ctc, preds_size,
                                            length_ctc)
        target = text_attn[:, 1:]  # without [GO] Symbol
        if not opt.continue_training:
            # Guide (attention) step first; retain_graph so the shared
            # features can be back-propagated again by the CTC step below.
            cost_attn = criterion_guide_path(
                preds_guide.view(-1, preds_guide.shape[-1]),
                target.contiguous().view(-1))
            optimizer_attn.zero_grad()
            cost_attn.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(
                guide_parameters,
                opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer_attn.step()
        optimizer_ctc.zero_grad()
        cost_ctc.backward()
        torch.nn.utils.clip_grad_norm_(
            major_parameters,
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer_ctc.step()
        scheduler_ctc.step()
        # BUGFIX: only step the attn scheduler when it exists (see above).
        if not opt.continue_training:
            scheduler_attn.step()
        loss_avg_major_path.add(cost_ctc)
        if not opt.continue_training:
            loss_avg_guide_path.add(cost_attn)
        if (iteration + 1) % 100 == 0:
            writer.add_scalar("Loss/train_ctc", loss_avg_major_path.val(),
                              (iteration + 1) // 100)
            loss_avg_major_path.reset()
            if not opt.continue_training:
                writer.add_scalar("Loss/train_attn", loss_avg_guide_path.val(),
                                  (iteration + 1) // 100)
                loss_avg_guide_path.reset()
        # validation part
        if (iteration + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion_major_path, valid_loader,
                        CTC_converter, opt)
                model.train()
                writer.add_scalar("Loss/valid", valid_loss,
                                  (iteration + 1) // opt.valInterval)
                writer.add_scalar("Metrics/accuracy", current_accuracy,
                                  (iteration + 1) // opt.valInterval)
                writer.add_scalar("Metrics/norm_ED", current_norm_ED,
                                  (iteration + 1) // opt.valInterval)

                # training loss and validation loss
                if not opt.continue_training:
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss ctc: {loss_avg_major_path.val():0.5f}, Train loss attn: {loss_avg_guide_path.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                else:
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss ctc: {loss_avg_major_path.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg_major_path.reset()
                if not opt.continue_training:
                    loss_avg_guide_path.reset()
                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'{fol_ckpt}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'{fol_ckpt}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}	{current_model_log}	{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '	')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}	{head}	{dashed_line}	'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}	{str(pred == gt)}	'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '	')

        # save a resumable checkpoint every 1e+3 iterations
        # (comment previously said 1e+5, which did not match the code).
        if (iteration + 1) % 1e+3 == 0 and (not opt.continue_training):
            torch.save(
                {
                    'model_state_dict': model.state_dict(),
                    'optimizer_attn_state_dict': optimizer_attn.state_dict(),
                    'optimizer_ctc_state_dict': optimizer_ctc.state_dict(),
                    'start_iter': iteration,
                    'scheduler_ctc_state_dict': scheduler_ctc.state_dict(),
                    'scheduler_attn_state_dict': scheduler_attn.state_dict(),
                }, f'{fol_ckpt}/current_model.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
Ejemplo n.º 14
0
def main():
    """Train a DeepSpeech model with CTC loss, with optional visdom /
    tensorboard logging, per-epoch validation (WER/CER), LR annealing and
    checkpointing.

    NOTE(review): written against an old PyTorch API (``Variable``,
    ``volatile=True``, ``loss.data[0]``, ``clip_grad_norm``); those idioms
    are kept intact for compatibility with the rest of this file.  Relies on
    module-level names (``parser``, ``supported_rnns``, dataset/model/decoder
    classes, ``to_np``) defined elsewhere.
    """
    args = parser.parse_args()
    save_folder = args.save_folder

    # One slot per epoch for each tracked metric.
    loss_results, valid_results, cer_results, wer_results, lr_results = torch.Tensor(
        args.epochs), torch.Tensor(args.epochs), torch.Tensor(
            args.epochs), torch.Tensor(args.epochs), torch.Tensor(args.epochs)
    best_cer = None
    if args.visdom:
        from visdom import Visdom
        viz = Visdom()

        opts = [
            dict(title=args.visdom_id + ' Loss', ylabel='Loss',
                 xlabel='Epoch'),
            dict(title=args.visdom_id + ' WER', ylabel='WER', xlabel='Epoch'),
            dict(title=args.visdom_id + ' CER', ylabel='CER', xlabel='Epoch')
        ]

        viz_windows = [None, None, None]
        epochs = torch.arange(1, args.epochs + 1)
    if args.tensorboard:
        from logger import TensorBoardLogger
        try:
            os.makedirs(args.log_dir)
        except OSError as e:
            if e.errno == errno.EEXIST:
                print('Directory already exists.')
                # Clear stale event files so the new run starts clean.
                for file in os.listdir(args.log_dir):
                    file_path = os.path.join(args.log_dir, file)
                    try:
                        if os.path.isfile(file_path):
                            os.unlink(file_path)
                    except Exception:
                        raise
            else:
                raise
        logger = TensorBoardLogger(args.log_dir)

    try:
        os.makedirs(save_folder)
    except OSError as e:
        if e.errno == errno.EEXIST:
            print('Directory already exists.')
        else:
            raise
    criterion = CTCLoss()

    with open(args.labels_path) as label_file:
        labels = json.load(label_file)
    print("labels:", labels)
    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride,
                      window=args.window,
                      noise_dir=args.noise_dir,
                      noise_prob=args.noise_prob,
                      noise_levels=(args.noise_min, args.noise_max))

    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=args.train_manifest,
                                       labels=labels,
                                       normalize=True,
                                       augment=args.augment)
    test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                      manifest_filepath=args.val_manifest,
                                      labels=labels,
                                      normalize=True,
                                      augment=False)
    train_loader = AudioDataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers)
    test_loader = AudioDataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    rnn_type = args.rnn_type.lower()
    assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
    model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                       nb_layers=args.hidden_layers,
                       labels=labels,
                       rnn_type=supported_rnns[rnn_type],
                       audio_conf=audio_conf,
                       bidirectional=True)
    parameters = model.parameters()
    optimizer = torch.optim.SGD(parameters,
                                lr=args.lr,
                                momentum=args.momentum,
                                nesterov=True)
    decoder = GreedyDecoder(labels,
                            space_index=labels.index('<space>'),
                            blank_index=labels.index('_'))

    if args.continue_from:
        print("Loading checkpoint model %s" % args.continue_from)
        package = torch.load(args.continue_from)
        model.load_state_dict(package['state_dict'])
        optimizer.load_state_dict(package['optim_dict'])
        start_epoch = int(package.get(
            'epoch', 1)) - 1  # Python index start at 0 for training
        start_iter = package.get('iteration', None)
        if start_iter is None:
            start_epoch += 1  # Assume that we saved a model after an epoch finished, so start at the next epoch.
            start_iter = 0
        else:
            start_iter += 1
        # BUGFIX: int() truncated the resumed (floating-point) average loss.
        avg_loss = float(package.get('avg_loss', 0))
        loss_results, valid_results, cer_results, wer_results, lr_results = package[
            'loss_results'], package['valid_results'], package[
                'cer_results'], package['wer_results'], package['lr_results']
        if args.visdom and \
                        package['loss_results'] is not None and start_epoch > 0:  # Add previous scores to visdom graph
            x_axis = epochs[0:start_epoch]
            # BUGFIX: this resume branch built 5 series for the 3 visdom
            # windows (Loss/WER/CER) and indexed lr_results with an undefined
            # name ``epoch`` (NameError).  Plot the same three series the
            # per-epoch update at the bottom of the loop uses.
            y_axis = [
                loss_results[0:start_epoch], wer_results[0:start_epoch],
                cer_results[0:start_epoch]
            ]
            for x in range(len(viz_windows)):
                viz_windows[x] = viz.line(
                    X=x_axis,
                    Y=y_axis[x],
                    opts=opts[x],
                )
        if args.tensorboard and \
                        package['loss_results'] is not None and start_epoch > 0:  # Previous scores to tensorboard logs
            for i in range(start_epoch):
                info = {
                    'Train Loss': loss_results[i],
                    'Validation Loss': valid_results[i],
                    'Avg WER': wer_results[i],
                    'Avg CER': cer_results[i],
                    'Learning rate': lr_results[i]
                }
                for tag, val in info.items():
                    logger.scalar_summary(tag, val, i + 1)
        if not args.no_bucketing:
            print("Using bucketing sampler for the following epochs")
            train_dataset = SpectrogramDatasetWithLength(
                audio_conf=audio_conf,
                manifest_filepath=args.train_manifest,
                labels=labels,
                normalize=True,
                augment=args.augment)
            sampler = BucketingSampler(train_dataset)
            train_loader.sampler = sampler
    else:
        avg_loss = 0
        start_epoch = 0
        start_iter = 0
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    print(model)
    print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    print(start_epoch, args.epochs)
    start_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        model.train()
        end = time.time()
        avg_loss = 0
        # NOTE(review): enumerate(start=start_iter) only offsets the index,
        # it does not skip already-consumed batches — resume within an epoch
        # replays data from the beginning; confirm whether that is intended.
        for i, (data) in enumerate(train_loader, start=start_iter):
            if i == len(train_loader):
                break
            inputs, targets, input_percentages, target_sizes = data
            # measure data loading time
            data_time.update(time.time() - end)
            inputs = Variable(inputs, requires_grad=False)
            target_sizes = Variable(target_sizes, requires_grad=False)
            targets = Variable(targets, requires_grad=False)

            if args.cuda:
                inputs = inputs.cuda()

            out = model(inputs)
            out = out.transpose(0, 1)  # TxNxH

            # Per-sample sequence lengths, recovered from the padding ratios.
            seq_length = out.size(0)
            sizes = Variable(input_percentages.mul_(int(seq_length)).int(),
                             requires_grad=False)

            loss = criterion(out, targets, sizes, target_sizes)
            loss = loss / inputs.size(0)  # average the loss by minibatch

            loss_sum = loss.data.sum()
            inf = float("inf")
            if loss_sum == inf or loss_sum == -inf:
                print("WARNING: received an inf loss, setting loss value to 0")
                loss_value = 0
            else:
                loss_value = loss.data[0]

            avg_loss += loss_value
            losses.update(loss_value, inputs.size(0))

            # compute gradient
            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm(model.parameters(), args.max_norm)
            # SGD step
            optimizer.step()

            if args.cuda:
                torch.cuda.synchronize()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if not args.silent:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          (epoch + 1), (i + 1),
                          len(train_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses))
            if args.checkpoint_per_batch > 0 and i > 0 and (
                    i + 1) % args.checkpoint_per_batch == 0:
                file_path = '%s/deepspeech_checkpoint_epoch_%d_iter_%d.pth.tar' % (
                    save_folder, epoch + 1, i + 1)
                print("Saving checkpoint model to %s" % file_path)
                torch.save(
                    DeepSpeech.serialize(model,
                                         optimizer=optimizer,
                                         epoch=epoch,
                                         iteration=i,
                                         loss_results=loss_results,
                                         valid_results=valid_results,
                                         wer_results=wer_results,
                                         cer_results=cer_results,
                                         lr_results=lr_results,
                                         avg_loss=avg_loss), file_path)
            del loss
            del out
        avg_loss /= len(train_loader)

        print('Training Summary Epoch: [{0}]\t'
              'Average Loss {loss:.3f}\t'.format(epoch + 1, loss=avg_loss))

        start_iter = 0  # Reset start iteration for next epoch
        total_cer, total_wer = 0, 0
        model.eval()
        valid_loss = 0
        for i, (data) in enumerate(test_loader):  # test
            inputs, targets, input_percentages, target_sizes = data

            inputs = Variable(inputs, volatile=True)

            # unflatten targets: targets is a flat tensor, target_sizes
            # gives each utterance's label length.
            split_targets = []
            offset = 0
            for size in target_sizes:
                split_targets.append(targets[offset:offset + size])
                offset += size

            if args.cuda:
                inputs = inputs.cuda()

            out = model(inputs)
            out = out.transpose(0, 1)  # TxNxH
            seq_length = out.size(0)
            sizes = input_percentages.mul_(int(seq_length)).int()

            # calcluate validation loss
            target_sizes = Variable(target_sizes, requires_grad=False)
            targets = Variable(targets, requires_grad=False)
            loss = criterion(out, targets, Variable(sizes,
                                                    requires_grad=False),
                             target_sizes)
            loss = loss / inputs.size(0)  # average the loss by minibatch
            loss_sum = loss.data.sum()
            inf = float("inf")
            if loss_sum == inf or loss_sum == -inf:
                print("WARNING: received an inf loss, setting loss value to 0")
                loss_value = 0
            else:
                loss_value = loss.data[0]

            valid_loss += loss_value
            decoded_output = decoder.decode(out.data, sizes)
            target_strings = decoder.process_strings(
                decoder.convert_to_strings(split_targets))
            wer, cer = 0, 0
            for x in range(len(target_strings)):
                wer += decoder.wer(
                    decoded_output[x], target_strings[x]) / float(
                        len(target_strings[x].replace(' ',
                                                      '').split('<space>')))
                cer += decoder.cer(decoded_output[x],
                                   target_strings[x]) / float(
                                       len(target_strings[x].split(' ')))
            total_cer += cer
            total_wer += wer

            if args.cuda:
                torch.cuda.synchronize()
            del out
        valid_loss /= len(test_loader)
        wer = total_wer / len(test_loader.dataset)
        cer = total_cer / len(test_loader.dataset)
        optim_state = optimizer.state_dict()
        wer *= 100
        cer *= 100
        loss_results[epoch] = avg_loss
        valid_results[epoch] = valid_loss
        wer_results[epoch] = wer
        cer_results[epoch] = cer
        lr_results[epoch] = optim_state['param_groups'][0]['lr']
        print('Validation Summary Epoch: [{0}]\t'
              'Validation loss {valid_loss:.3f}\t'
              'Average WER {wer:.3f}\t'
              'Average CER {cer:.3f}\t'
              'Learning rate{learning_rate:.6f}'.format(
                  epoch + 1,
                  valid_loss=valid_loss,
                  wer=wer,
                  cer=cer,
                  learning_rate=optim_state['param_groups'][0]['lr']))

        if args.visdom:
            x_axis = epochs[0:epoch + 1]
            y_axis = [
                loss_results[0:epoch + 1], wer_results[0:epoch + 1],
                cer_results[0:epoch + 1]
            ]
            for x in range(len(viz_windows)):
                if viz_windows[x] is None:
                    viz_windows[x] = viz.line(
                        X=x_axis,
                        Y=y_axis[x],
                        opts=opts[x],
                    )
                else:
                    viz.line(
                        X=x_axis,
                        Y=y_axis[x],
                        win=viz_windows[x],
                        update='replace',
                    )
        if args.tensorboard:
            optim_state = optimizer.state_dict()
            learning_rate = optim_state['param_groups'][0]['lr']
            info = {
                'Train Loss': avg_loss,
                'Validation Loss': valid_loss,
                'Avg WER': wer,
                'Avg CER': cer,
                'Learning rate': learning_rate
            }
            for tag, val in info.items():
                logger.scalar_summary(tag, val, epoch + 1)
            if args.log_params:
                for tag, value in model.named_parameters():
                    tag = tag.replace('.', '/')
                    logger.histo_summary(tag, to_np(value), epoch + 1)
                    logger.histo_summary(tag + '/grad', to_np(value.grad),
                                         epoch + 1)
        if args.checkpoint:
            file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch + 1)
            torch.save(
                DeepSpeech.serialize(model,
                                     optimizer=optimizer,
                                     epoch=epoch,
                                     loss_results=loss_results,
                                     valid_results=valid_results,
                                     wer_results=wer_results,
                                     cer_results=cer_results,
                                     lr_results=lr_results), file_path)
        # anneal lr
        optim_state = optimizer.state_dict()
        optim_state['param_groups'][0][
            'lr'] = optim_state['param_groups'][0]['lr'] / args.learning_anneal
        optimizer.load_state_dict(optim_state)
        print('Learning rate annealed to: {lr:.6f}'.format(
            lr=optim_state['param_groups'][0]['lr']))
        print("training used time: %.2f " % (time.time() - start_time))
        if best_cer is None or best_cer > cer:
            print("Found better validated model, saving to %s" %
                  args.model_path)
            torch.save(
                DeepSpeech.serialize(model,
                                     optimizer=optimizer,
                                     epoch=epoch,
                                     loss_results=loss_results,
                                     valid_results=valid_results,
                                     wer_results=wer_results,
                                     cer_results=cer_results,
                                     lr_results=lr_results), args.model_path)
            best_cer = cer

        avg_loss = 0
        if not args.no_bucketing and epoch == 0:
            print("Switching to bucketing sampler for following epochs")
            train_dataset = SpectrogramDatasetWithLength(
                audio_conf=audio_conf,
                manifest_filepath=args.train_manifest,
                labels=labels,
                normalize=True,
                augment=args.augment)
            sampler = BucketingSampler(train_dataset)
            train_loader.sampler = sampler
Ejemplo n.º 15
0
    def __init__(self,
                 rank,
                 size,
                 master='gpu10',
                 dist=True,
                 ngpus=1,
                 batch_size=32,
                 is_weak_scaling=True,
                 data_dir='./data',
                 dataset='cifar10',
                 dnn='resnet20',
                 lr=0.04,
                 nworkers=1,
                 prefix=None,
                 sparsity=0.95,
                 pretrain=None,
                 num_steps=35,
                 tb_writer=None,
                 amp_handle=None):
        """Build a distributed training worker: network, data, loss and optimizer.

        rank/size identify this worker inside the communication group; dnn may
        be either a model name (str) or an already-constructed model object.
        With is_weak_scaling, the per-worker batch size is multiplied by the
        number of local GPUs. pretrain, when it names an existing file, is
        loaded after the optimizer is created.
        """
        self.size = size
        self.rank = rank
        self.pretrain = pretrain
        self.dataset = dataset
        self.prefix = prefix
        self.num_steps = num_steps
        self.ngpus = ngpus
        self.writer = tb_writer
        self.amp_handle = amp_handle
        # Weak scaling: grow the global batch with the GPU count so each GPU
        # keeps the same per-device batch size.
        if self.ngpus > 0:
            self.batch_size = batch_size * self.ngpus if is_weak_scaling else batch_size
        else:
            self.batch_size = batch_size
        self.num_batches_per_epoch = -1  # filled in by data_prepare()
        if self.dataset == 'cifar10' or self.dataset == 'mnist':
            self.num_classes = 10
        elif self.dataset == 'imagenet':
            self.num_classes = 1000
        elif self.dataset == 'an4':
            self.num_classes = 29
        elif self.dataset == 'ptb':
            # NOTE(review): 10 looks like a placeholder for PTB — a language
            # model normally uses the vocabulary size here; confirm upstream.
            self.num_classes = 10
        self.nworkers = nworkers  # just for easy comparison
        self.data_dir = data_dir
        if type(dnn) != str:
            # Caller handed us a ready-made model object.
            self.net = dnn
            self.dnn = dnn.name
            self.ext = None  # leave for further parameters
        else:
            self.dnn = dnn
            # TODO: Refact these codes!
            # Ordering matters per branch: the LSTM needs the data prepared
            # first (to know vocab_size), lstman4 builds the net from the
            # datapath before preparing data.
            if self.dnn == 'lstm':
                if data_dir is not None:
                    self.data_prepare()
                self.net, self.ext = create_net(self.num_classes,
                                                self.dnn,
                                                vocab_size=self.vocab_size,
                                                batch_size=self.batch_size)
            elif self.dnn == 'lstman4':
                self.net, self.ext = create_net(self.num_classes,
                                                self.dnn,
                                                datapath=self.data_dir)
                if data_dir is not None:
                    self.data_prepare()
            else:
                if data_dir is not None:
                    self.data_prepare()
                self.net, self.ext = create_net(self.num_classes, self.dnn)
        self.lr = lr
        self.base_lr = self.lr  # remembered so schedules can reset from it
        self.is_cuda = self.ngpus > 0

        if self.is_cuda:
            if self.ngpus > 1:
                devices = get_available_gpu_device_ids(ngpus)
                self.net = torch.nn.DataParallel(self.net,
                                                 device_ids=devices).cuda()
            else:
                self.net.cuda()
        self.net.share_memory()
        self.accuracy = 0
        self.loss = 0.0
        self.train_iter = 0
        self.recved_counter = 0
        self.master = master
        self.average_iter = 0
        # CTC loss for speech (an4); plain cross-entropy everywhere else.
        if self.dataset != 'an4':
            if self.is_cuda:
                self.criterion = nn.CrossEntropyLoss().cuda()
            else:
                self.criterion = nn.CrossEntropyLoss()
        else:
            from warpctc_pytorch import CTCLoss
            self.criterion = CTCLoss()
        weight_decay = 1e-4
        self.m = 0.9  # momentum
        nesterov = False
        if self.dataset == 'an4':
            #nesterov = True
            self.lstman4_lr_epoch_tag = 0
            #weight_decay = 0.
        elif self.dataset == 'ptb':
            self.m = 0
            weight_decay = 0
        elif self.dataset == 'imagenet':
            #weight_decay = 5e-4
            self.m = 0.875
            weight_decay = 2 * 3.0517578125e-05

        # Split parameters so that biases and batch-norm weights (all 1-D
        # tensors) are excluded from weight decay.
        decay = []
        no_decay = []
        for name, param in self.net.named_parameters():
            if not param.requires_grad:
                continue
            if len(param.shape) == 1 or 'bn' in name or 'bias' in name:
                no_decay.append(param)
            else:
                decay.append(param)
        parameters = [{
            'params': no_decay,
            'weight_decay': 0.
        }, {
            'params': decay,
            'weight_decay': weight_decay
        }]

        # NOTE(review): each param group above already sets weight_decay,
        # which takes precedence over the default passed here — the kwarg is
        # effectively redundant. Also note momentum/nesterov are computed but
        # not passed to SGD; confirm whether that is intentional.
        self.optimizer = optim.SGD(parameters,
                                   lr=self.lr,
                                   weight_decay=weight_decay)

        self.train_epoch = 0

        if self.pretrain is not None and os.path.isfile(self.pretrain):
            self.load_model_from_file(self.pretrain)

        # Bookkeeping for gradient sparsification / communication stats.
        self.sparsities = []
        self.compression_ratios = []
        self.communication_sizes = []
        self.remainer = {}  # residual gradients kept locally (sic: "remainder")
        self.v = {}
        self.sparsity = sparsity
        self.avg_loss_per_epoch = 0.0
        self.timer = 0.0
        self.forwardtime = 0.0
        self.backwardtime = 0.0
        self.iotime = 0.0
        self.epochs_info = []
        self.distributions = {}
        self.gpu_caches = {}
        self.delays = []
        self.num_of_updates_during_comm = 0
        self.train_acc_top1 = []
        # Mixed precision only when apex imported successfully at module level.
        if apex is not None:
            self.init_fp16()
        logger.info('num_batches_per_epoch: %d' % self.num_batches_per_epoch)
Ejemplo n.º 16
0
def main():
    """Train a TIMIT LSTM-CTC acoustic model driven by an INI config file.

    Reads every hyper-parameter from the config file given on the command
    line, trains until the dev loss plateaus (stepping the learning rate
    down by `decay` each plateau, at most 8 times), plots progress to a
    visdom server, and finally writes the best model (by dev loss) and the
    updated config back to disk.
    """
    args = parser.parse_args()
    cf = ConfigParser.ConfigParser()
    # ConfigParser.read() returns the list of files it managed to parse and
    # does not raise on a missing path, so test the return value (the
    # original bare try/except could never fire here).
    if not cf.read(args.conf):
        print("conf file not exists")
        sys.exit(1)
    USE_CUDA = cf.getboolean('Training', 'use_cuda')
    try:
        # int() accepts arbitrarily large integers on both Python 2 and 3
        # (the original used the Python-2-only long()).
        seed = int(cf.get('Training', 'seed'))
    except (ConfigParser.Error, ValueError):
        # No usable seed stored yet: draw one and persist it so the next
        # run is reproducible.
        seed = torch.cuda.initial_seed()
        # configparser values must be written as strings
        cf.set('Training', 'seed', str(seed))
        with open(args.conf, 'w') as conf_file:
            cf.write(conf_file)

    torch.manual_seed(seed)
    if USE_CUDA:
        torch.cuda.manual_seed(seed)

    log_dir = cf.get('Data', 'log_dir')
    log_file = os.path.join(log_dir, cf.get('Data', 'log_file'))
    logger = init_logger(log_file)

    #Define Model
    rnn_input_size = cf.getint('Model', 'rnn_input_size')
    rnn_hidden_size = cf.getint('Model', 'rnn_hidden_size')
    rnn_layers = cf.getint('Model', 'rnn_layers')
    rnn_type = RNN[cf.get('Model', 'rnn_type')]
    bidirectional = cf.getboolean('Model', 'bidirectional')
    batch_norm = cf.getboolean('Model', 'batch_norm')
    rnn_param = {
        "rnn_input_size": rnn_input_size,
        "rnn_hidden_size": rnn_hidden_size,
        "rnn_layers": rnn_layers,
        "rnn_type": rnn_type,
        "bidirectional": bidirectional,
        "batch_norm": batch_norm
    }
    num_class = cf.getint('Model', 'num_class')
    drop_out = cf.getfloat('Model', 'drop_out')

    model = CTC_Model(rnn_param=rnn_param,
                      num_class=num_class,
                      drop_out=drop_out)
    print("Model Structure:")
    logger.info("Model Structure:")
    for idx, m in enumerate(model.children()):
        print(idx, m)
        logger.info(str(idx) + "->" + str(m))

    data_dir = cf.get('Data', 'data_dir')
    batch_size = cf.getint("Training", 'batch_size')

    #Data Loader
    train_dataset = SpeechDataset(data_dir, data_set='train')
    dev_dataset = SpeechDataset(data_dir, data_set="dev")
    train_loader = SpeechDataLoader(train_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=4,
                                    pin_memory=False)
    dev_loader = SpeechDataLoader(dev_dataset,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=4,
                                  pin_memory=False)

    # ensure the feature dimension equals rnn_input_size
    assert train_dataset.n_feats == rnn_input_size

    #decoder for dev set
    decoder = GreedyDecoder(int2char,
                            space_idx=len(int2char) - 1,
                            blank_index=0)

    #Training
    init_lr = cf.getfloat('Training', 'init_lr')
    num_epoches = cf.getint('Training', 'num_epoches')
    end_adjust_acc = cf.getfloat('Training', 'end_adjust_acc')
    decay = cf.getfloat("Training", 'lr_decay')
    weight_decay = cf.getfloat("Training", 'weight_decay')

    params = {
        'num_epoches': num_epoches,
        'end_adjust_acc': end_adjust_acc,
        'seed': seed,
        'decay': decay,
        'learning_rate': init_lr,
        'weight_decay': weight_decay,
        'batch_size': batch_size,
        'n_feats': train_dataset.n_feats
    }
    print(params)

    if USE_CUDA:
        model = model.cuda()

    loss_fn = CTCLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=init_lr,
                                 weight_decay=weight_decay)

    #visualization for training
    from visdom import Visdom
    viz = Visdom()
    title = 'TIMIT LSTM_CTC Acoustic Model'

    opts = [
        dict(title=title + " Loss", ylabel='Loss', xlabel='Epoch'),
        dict(title=title + " Loss on Dev", ylabel='DEV Loss', xlabel='Epoch'),
        dict(title=title + ' CER on DEV', ylabel='DEV CER', xlabel='Epoch')
    ]
    viz_window = [None, None, None]

    count = 0
    learning_rate = init_lr
    loss_best = 1000
    loss_best_true = 1000
    adjust_rate_flag = False
    stop_train = False
    adjust_time = 0
    acc_best = 0
    # Fix: these were previously only ever assigned inside the epoch loop,
    # so a pathological first epoch (e.g. a huge/NaN dev loss taking the
    # "else" branch) hit UnboundLocalError when restoring the best state.
    adjust_rate_count = 0
    best_model_state = copy.deepcopy(model.state_dict())
    best_op_state = copy.deepcopy(optimizer.state_dict())
    start_time = time.time()
    loss_results = []
    dev_loss_results = []
    dev_cer_results = []

    while not stop_train:
        if count >= num_epoches:
            break
        count += 1

        # Apply a pending learning-rate decay requested by the previous epoch.
        if adjust_rate_flag:
            learning_rate *= decay
            adjust_rate_flag = False
            for param in optimizer.param_groups:
                param['lr'] *= decay

        print("Start training epoch: %d, learning_rate: %.5f" %
              (count, learning_rate))
        logger.info("Start training epoch: %d, learning_rate: %.5f" %
                    (count, learning_rate))

        loss = train(model,
                     train_loader,
                     loss_fn,
                     optimizer,
                     logger,
                     print_every=20,
                     USE_CUDA=USE_CUDA)
        loss_results.append(loss)
        acc, dev_loss = dev(model,
                            dev_loader,
                            loss_fn,
                            decoder,
                            logger,
                            USE_CUDA=USE_CUDA)
        print("loss on dev set is %.4f" % dev_loss)
        logger.info("loss on dev set is %.4f" % dev_loss)
        dev_loss_results.append(dev_loss)
        dev_cer_results.append(acc)

        # Adjust the learning rate by dev loss: a clear improvement resets
        # the plateau counter and snapshots the model; a result inside the
        # +/- end_adjust_acc band counts toward the plateau; anything worse
        # forces an immediate decay (counter jumps straight to 10).
        if dev_loss < (loss_best - end_adjust_acc):
            loss_best = dev_loss
            loss_best_true = dev_loss
            adjust_rate_count = 0
            acc_best = acc
            best_model_state = copy.deepcopy(model.state_dict())
            best_op_state = copy.deepcopy(optimizer.state_dict())
        elif (dev_loss < loss_best + end_adjust_acc):
            adjust_rate_count += 1
            if dev_loss < loss_best and dev_loss < loss_best_true:
                loss_best_true = dev_loss
                acc_best = acc
                best_model_state = copy.deepcopy(model.state_dict())
                best_op_state = copy.deepcopy(optimizer.state_dict())
        else:
            adjust_rate_count = 10

        print("adjust_rate_count: %d" % adjust_rate_count)
        print('adjust_time: %d' % adjust_time)
        logger.info("adjust_rate_count: %d" % adjust_rate_count)
        logger.info('adjust_time: %d' % adjust_time)

        # Plateau confirmed: schedule a decay and roll back to the best state.
        if adjust_rate_count == 10:
            adjust_rate_flag = True
            adjust_time += 1
            adjust_rate_count = 0
            if loss_best > loss_best_true:
                loss_best = loss_best_true
            model.load_state_dict(best_model_state)
            optimizer.load_state_dict(best_op_state)

        # Stop after the learning rate has been decayed 8 times.
        if adjust_time == 8:
            stop_train = True

        time_used = (time.time() - start_time) / 60
        print("epoch %d done, dev acc is: %.4f, time_used: %.4f minutes" %
              (count, acc, time_used))
        logger.info(
            "epoch %d done, dev acc is: %.4f, time_used: %.4f minutes" %
            (count, acc, time_used))

        # Redraw the three visdom curves with all epochs so far.
        x_axis = range(count)
        y_axis = [
            loss_results[0:count], dev_loss_results[0:count],
            dev_cer_results[0:count]
        ]
        for x in range(len(viz_window)):
            if viz_window[x] is None:
                viz_window[x] = viz.line(
                    X=np.array(x_axis),
                    Y=np.array(y_axis[x]),
                    opts=opts[x],
                )
            else:
                viz.line(
                    X=np.array(x_axis),
                    Y=np.array(y_axis[x]),
                    win=viz_window[x],
                    update='replace',
                )

    print("End training, best dev loss is: %.4f, acc is: %.4f" %
          (loss_best_true, acc_best))
    logger.info("End training, best dev loss is: %.4f, acc is: %.4f" %
                (loss_best_true, acc_best))
    # Restore the best snapshot and persist it, recording its path in the
    # config file so evaluation scripts can find it.
    model.load_state_dict(best_model_state)
    optimizer.load_state_dict(best_op_state)
    best_path = os.path.join(log_dir,
                             'best_model' + '_dev' + str(acc_best) + '.pkl')
    cf.set('Model', 'model_file', best_path)
    with open(args.conf, 'w') as conf_file:
        cf.write(conf_file)
    params['epoch'] = count

    torch.save(
        CTC_Model.save_package(model,
                               optimizer=optimizer,
                               epoch=params,
                               loss_results=loss_results,
                               dev_loss_results=dev_loss_results,
                               dev_cer_results=dev_cer_results), best_path)
Ejemplo n.º 17
0
def train(cfg):
    """Train DeepSpeech with n_E edge replicas plus one central model.

    Each training step feeds n_E consecutive batches to n_E edge models
    (round-robined across the local GPUs), steps each edge optimizer, and at
    the end of every epoch aggregates the edge models into the central model,
    which is then evaluated on the validation set. Checkpointing, visdom and
    tensorboard logging follow the upstream deepspeech.pytorch layout.
    """
    # Set seeds for determinism
    torch.manual_seed(cfg.training.seed)
    torch.cuda.manual_seed_all(cfg.training.seed)
    np.random.seed(cfg.training.seed)
    random.seed(cfg.training.seed)

    main_proc = True
    device = torch.device("cpu" if cfg.training.no_cuda else "cuda")

    is_distributed = os.environ.get("LOCAL_RANK")  # If local rank exists, distributed env

    if is_distributed:
        # when using NCCL, on failures, surviving nodes will deadlock on NCCL ops
        # because NCCL uses a spin-lock on the device. Set this env var and
        # to enable a watchdog thread that will destroy stale NCCL communicators
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

        device_id = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(device_id)
        print(f"Setting CUDA Device to {device_id}")

        dist.init_process_group(backend=cfg.training.dist_backend.value)
        main_proc = device_id == 0  # Main process handles saving of models and reporting

    if OmegaConf.get_type(cfg.checkpointing) == FileCheckpointConfig:
        checkpoint_handler = FileCheckpointHandler(cfg=cfg.checkpointing)
    elif OmegaConf.get_type(cfg.checkpointing) == GCSCheckpointConfig:
        checkpoint_handler = GCSCheckpointHandler(cfg=cfg.checkpointing)
    else:
        raise ValueError("Checkpoint Config has not been specified correctly.")

    if main_proc and cfg.visualization.visdom:
        visdom_logger = VisdomLogger(id=cfg.visualization.id,
                                     num_epochs=cfg.training.epochs)
    if main_proc and cfg.visualization.tensorboard:
        tensorboard_logger = TensorBoardLogger(id=cfg.visualization.id,
                                               log_dir=to_absolute_path(cfg.visualization.log_dir),
                                               log_params=cfg.visualization.log_params)

    if cfg.checkpointing.load_auto_checkpoint:
        latest_checkpoint = checkpoint_handler.find_latest_checkpoint()
        if latest_checkpoint:
            cfg.checkpointing.continue_from = latest_checkpoint

    if cfg.checkpointing.continue_from:  # Starting from previous model
        # NOTE(review): this branch restores only the central model —
        # n_E and edge_model_list are never defined here, so the code below
        # (edge placement, the n_E batching loop) would raise NameError when
        # resuming from a checkpoint. Confirm whether resume is supported.
        state = TrainingState.load_state(state_path=to_absolute_path(cfg.checkpointing.continue_from))
        model = state.model
        if cfg.training.finetune:
            state.init_finetune_states(cfg.training.epochs)

        if main_proc and cfg.visualization.visdom:  # Add previous scores to visdom graph
            visdom_logger.load_previous_values(state.epoch, state.results)
        if main_proc and cfg.visualization.tensorboard:  # Previous scores to tensorboard logs
            tensorboard_logger.load_previous_values(state.epoch, state.results)
    else:
        # Initialise new model training

        with open(to_absolute_path(cfg.data.labels_path)) as label_file:
            labels = json.load(label_file)

        # Build n_E identical edge replicas plus the central model below.
        n_E = 4
        edge_model_list = []
        for i in range(n_E):
            if OmegaConf.get_type(cfg.model) is BiDirectionalConfig:
                model = DeepSpeech(rnn_hidden_size=cfg.model.hidden_size,
                                   nb_layers=cfg.model.hidden_layers,
                                   labels=labels,
                                   rnn_type=supported_rnns[cfg.model.rnn_type.value],
                                   audio_conf=cfg.data.spect,
                                   bidirectional=True)
            elif OmegaConf.get_type(cfg.model) is UniDirectionalConfig:
                model = DeepSpeech(rnn_hidden_size=cfg.model.hidden_size,
                                   nb_layers=cfg.model.hidden_layers,
                                   labels=labels,
                                   rnn_type=supported_rnns[cfg.model.rnn_type.value],
                                   audio_conf=cfg.data.spect,
                                   bidirectional=False,
                                   context=cfg.model.lookahead_context)
            edge_model_list.append(model)

        if OmegaConf.get_type(cfg.model) is BiDirectionalConfig:
            model = DeepSpeech(rnn_hidden_size=cfg.model.hidden_size,
                               nb_layers=cfg.model.hidden_layers,
                               labels=labels,
                               rnn_type=supported_rnns[cfg.model.rnn_type.value],
                               audio_conf=cfg.data.spect,
                               bidirectional=True)
        elif OmegaConf.get_type(cfg.model) is UniDirectionalConfig:
            model = DeepSpeech(rnn_hidden_size=cfg.model.hidden_size,
                               nb_layers=cfg.model.hidden_layers,
                               labels=labels,
                               rnn_type=supported_rnns[cfg.model.rnn_type.value],
                               audio_conf=cfg.data.spect,
                               bidirectional=False,
                               context=cfg.model.lookahead_context)
        else:
            raise ValueError("Model Config has not been specified correctly.")

        state = TrainingState(model=model)
        state.init_results_tracking(epochs=cfg.training.epochs)

    # Data setup
    evaluation_decoder = GreedyDecoder(model.labels)  # Decoder used for validation
    train_dataset = SpectrogramDataset(audio_conf=model.audio_conf,
                                       manifest_filepath=to_absolute_path(cfg.data.train_manifest),
                                       labels=model.labels,
                                       normalize=True,
                                       augmentation_conf=cfg.data.augmentation)
    test_dataset = SpectrogramDataset(audio_conf=model.audio_conf,
                                      manifest_filepath=to_absolute_path(cfg.data.val_manifest),
                                      labels=model.labels,
                                      normalize=True)
    if not is_distributed:
        train_sampler = DSRandomSampler(dataset=train_dataset,
                                        batch_size=cfg.data.batch_size,
                                        start_index=state.training_step)
    else:
        train_sampler = DSElasticDistributedSampler(dataset=train_dataset,
                                                    batch_size=cfg.data.batch_size,
                                                    start_index=state.training_step)
    train_loader = AudioDataLoader(dataset=train_dataset,
                                   num_workers=cfg.data.num_workers,
                                   batch_sampler=train_sampler)
    test_loader = AudioDataLoader(dataset=test_dataset,
                                  num_workers=cfg.data.num_workers,
                                  batch_size=cfg.data.batch_size)

    # Round-robin the edge replicas over the visible GPUs.
    for i, edge_model in enumerate(edge_model_list):
        device_num = i % torch.cuda.device_count()
        edge_model.cuda(device_num)
        print(f'model Num:{i}, device Num:{next(edge_model.parameters()).device}')
    model = model.to(device)
    print(f'central model at device:{next(model.parameters()).device}')

    # parameters = model.parameters()
    # One optimizer per edge replica; the central model has no optimizer of
    # its own — it is refreshed via aggregation() at the end of each epoch.
    edge_optimizer_list = []
    for edge_model in edge_model_list:
        parameters = edge_model.parameters()
        if OmegaConf.get_type(cfg.optim) is SGDConfig:
            optimizer = torch.optim.SGD(parameters,
                                        lr=cfg.optim.learning_rate,
                                        momentum=cfg.optim.momentum,
                                        nesterov=True,
                                        weight_decay=cfg.optim.weight_decay)
        elif OmegaConf.get_type(cfg.optim) is AdamConfig:
            optimizer = torch.optim.AdamW(parameters,
                                          lr=cfg.optim.learning_rate,
                                          betas=cfg.optim.betas,
                                          eps=cfg.optim.eps,
                                          weight_decay=cfg.optim.weight_decay)
        else:
            raise ValueError("Optimizer has not been specified correctly.")

        edge_optimizer_list.append(optimizer)

    # Wrap every (edge model, optimizer) pair with apex AMP, then swap the
    # wrapped versions back into the two lists.
    edge_model_list_ = []
    edge_optimizer_list_ = []
    for edge_model, optimizer in zip(edge_model_list, edge_optimizer_list):
        edge_model, optimizer = amp.initialize(edge_model, optimizer,
                                          enabled=not cfg.training.no_cuda,
                                          opt_level=cfg.apex.opt_level,
                                          loss_scale=cfg.apex.loss_scale)

        edge_model_list_.append(edge_model)
        edge_optimizer_list_.append(optimizer)
    edge_model_list = edge_model_list_
    edge_optimizer_list = edge_optimizer_list_
    del edge_model_list_, edge_optimizer_list_

    # NOTE(review): `optimizer` here is whichever edge optimizer the loop
    # above left bound last — only that one receives the restored state.
    if state.optim_state is not None:
        optimizer.load_state_dict(state.optim_state)
    if state.amp_state is not None:
        amp.load_state_dict(state.amp_state)

    # Track states for optimizer/amp
    state.track_optim_state(optimizer)
    if not cfg.training.no_cuda:
        state.track_amp_state(amp)

    if is_distributed:
        model = DistributedDataParallel(model, device_ids=[device_id])
    print(model)
    print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

    criterion = CTCLoss()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    for epoch in range(state.epoch, cfg.training.epochs):
        model.train()
        end = time.time()
        start_epoch_time = time.time()
        state.set_epoch(epoch=epoch)
        train_sampler.set_epoch(epoch=epoch)
        train_sampler.reset_training_step(training_step=state.training_step)

        # Push the central model's weights out to every edge replica.
        distribution(model, edge_model_list)
        # pdb.set_trace()
        # Accumulate n_E consecutive batches, then dispatch one batch to
        # each edge model in a single training step.
        inputs_list = []
        input_sizes_list = []
        targets_list = []
        target_sizes_list = []
        for i, (data) in enumerate(train_loader, start=state.training_step):
            state.set_training_step(training_step=i)
            inputs, targets, input_percentages, target_sizes = data
            input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
            # measure data loading time
            data_time.update(time.time() - end)

            inputs_list.append(inputs)
            input_sizes_list.append(input_sizes)
            targets_list.append(targets)
            target_sizes_list.append(target_sizes)
            if len(inputs_list) < n_E:
                end = time.time()
                continue
            assert len(inputs_list) == n_E
            assert len(input_sizes_list) == n_E
            assert len(targets_list) == n_E
            assert len(target_sizes_list) == n_E
            # print('start training!')

            # Forward pass: one batch per edge model, each on its own GPU.
            loss_list = []
            loss_value_list = []
            for inputs, input_sizes, targets, target_sizes, edge_model in zip(inputs_list, input_sizes_list, targets_list, target_sizes_list, edge_model_list):
                # print(device)
                device = next(edge_model.parameters()).device
                # To utilize default streams on different devices
                with torch.cuda.device(device):
                    # print(torch.cuda.current_stream())
                    inputs = inputs.to(device)
                    # targets = targets.to(device)

                    out, output_sizes = edge_model(inputs, input_sizes)
                    # print('model')
                    out = out.transpose(0, 1)  # TxNxH
                    # print('transpose')

                    # pdb.set_trace()
                    float_out = out.float()  # ensure float32 for loss
                    # print('float')
                    # print(float_out.device, targets.device, output_sizes.device, target_sizes.device)
                    loss = criterion(float_out, targets, output_sizes, target_sizes).to(device)
                    # print('criterion')
                    loss = loss / inputs.size(0)  # average the loss by minibatch
                    # print('loss')
                    loss_value = loss.item()
                    # print('loss_value')

                    loss_list.append(loss)
                    # print('loss_list')
                    loss_value_list.append(loss_value)
                    # print('loss_value_list')

            # Backward pass + step for each edge optimizer; invalid losses
            # (inf/NaN) skip the update and are logged as 0.
            loss_value_list_ = []
            for loss, loss_value, optimizer in zip(loss_list, loss_value_list, edge_optimizer_list):
                device = loss.device
                with torch.cuda.device(device):
                    # Check to ensure valid loss was calculated
                    valid_loss, error = check_loss(loss, loss_value)
                    if valid_loss:
                        optimizer.zero_grad()

                        # compute gradient
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), cfg.optim.max_norm)
                        optimizer.step()
                    else:
                        print(error)
                        print('Skipping grad update')
                        loss_value = 0
                    loss_value_list_.append(loss_value)
            loss_value_list = loss_value_list_
            del loss_value_list_

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            for loss_value, inputs in zip(loss_value_list, inputs_list):
                state.avg_loss += loss_value
                losses.update(loss_value, inputs.size(0))

                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                    (epoch + 1), (i + 1), len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses))

            if main_proc and cfg.checkpointing.checkpoint_per_iteration:
                checkpoint_handler.save_iter_checkpoint_model(epoch=epoch, i=i, state=state)
            # Free the tensors of the last edge replica before the next step.
            del loss, out, float_out

            inputs_list = []
            input_sizes_list = []
            targets_list = []
            target_sizes_list = []

        # Fold the trained edge replicas back into the central model.
        aggregation(edge_model_list, model)
        # pdb.set_trace()
        state.avg_loss /= len(train_dataset)

        epoch_time = time.time() - start_epoch_time
        print('Training Summary Epoch: [{0}]\t'
              'Time taken (s): {epoch_time:.0f}\t'
              'Average Loss {loss:.3f}\t'.format(epoch + 1, epoch_time=epoch_time, loss=state.avg_loss))

        # Validate the aggregated central model.
        device = next(model.parameters()).device
        with torch.no_grad():
            wer, cer, output_data = run_evaluation(test_loader=test_loader,
                                                   device=device,
                                                   model=model,
                                                   decoder=evaluation_decoder,
                                                   target_decoder=evaluation_decoder,
                                                   save_output='yes')

        state.add_results(epoch=epoch,
                          loss_result=state.avg_loss,
                          wer_result=wer,
                          cer_result=cer)

        print('Validation Summary Epoch: [{0}]\t'
              'Average WER {wer:.3f}\t'
              'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer, cer=cer))
        with open('validation_log.txt', 'a') as f:
            f.write('Validation Summary Epoch: [{0}]\t'
                  'Average WER {wer:.3f}\t'
                  'Average CER {cer:.3f}\n'.format(epoch + 1, wer=wer, cer=cer))
        save_pickle(output_data, 'output_data_'+str(epoch)+'.p')
        # NOTE(review): live debugger breakpoint — pauses training after every
        # epoch. Almost certainly a leftover; remove before unattended runs.
        pdb.set_trace()
        if main_proc and cfg.visualization.visdom:
            visdom_logger.update(epoch, state.result_state)
        if main_proc and cfg.visualization.tensorboard:
            tensorboard_logger.update(epoch, state.result_state, model.named_parameters())

        if main_proc and cfg.checkpointing.checkpoint:  # Save epoch checkpoint
            checkpoint_handler.save_checkpoint_model(epoch=epoch, state=state)
        # anneal lr
        for optimizer in edge_optimizer_list:
            for g in optimizer.param_groups:
                g['lr'] = g['lr'] / cfg.optim.learning_anneal
            print('Learning rate annealed to: {lr:.6f}'.format(lr=g['lr']))

        if main_proc and (state.best_wer is None or state.best_wer > wer):
            checkpoint_handler.save_best_model(epoch=epoch, state=state)
            state.set_best_wer(wer)
            state.reset_avg_loss()
        state.reset_training_step()  # Reset training step for next epoch
Ejemplo n.º 18
0

if __name__ == '__main__':

	char_set = open('char_std_5990.txt', 'r', encoding='utf-8').readlines()
	char_set = ''.join([ch.strip('\n') for ch in char_set[1:]] + ['卍'])
	n_class = len(char_set)

	model = crnn.CRNN(img_h, 1, n_class, 256)
	if torch.cuda.is_available and use_gpu:
		model.cuda()

	modelpath = opt.modelpath

	learning_rate = opt.learning_rate
	loss_func = CTCLoss()
	optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=opt.weight_decay)

	if os.path.exists(modelpath):
		print('Load model from "%s" ...' % modelpath)
		model.load_state_dict(torch.load(modelpath))
		print('Done!')
	k = 0
	losstotal = 0.0
	printinterval = opt.printinterval
	valinterval = opt.valinterval
	numinprint = 0
	# train
	for epoch in range(max_epoch):

		for i,(data,label) in enumerate(train_loader):
Ejemplo n.º 19
0
def main():
    """Train DeepSpeech with CTC loss and evaluate WER/CER every epoch.

    Reads CLI arguments from the module-level `parser`, builds train/val
    SpectrogramDatasets from manifest files, and runs SGD training with
    gradient clipping.  Optionally plots loss/WER/CER curves to visdom and
    saves per-epoch checkpoints plus a final model.
    """
    args = parser.parse_args()
    save_folder = args.save_folder

    if args.visdom:
        # Lazy import: visdom is only needed when plotting was requested.
        from visdom import Visdom
        viz = Visdom()

        opts = [
            dict(title='Loss', ylabel='Loss', xlabel='Epoch'),
            dict(title='WER', ylabel='WER', xlabel='Epoch'),
            dict(title='CER', ylabel='CER', xlabel='Epoch')
        ]

        viz_windows = [None, None, None]
        loss_results, cer_results, wer_results = torch.Tensor(
            args.epochs), torch.Tensor(args.epochs), torch.Tensor(args.epochs)
        # NOTE(review): torch.range is deprecated (end-inclusive); the modern
        # equivalent is torch.arange(1, args.epochs + 1).
        epochs = torch.range(1, args.epochs)

    # Create the checkpoint directory; tolerate it already existing.
    try:
        os.makedirs(save_folder)
    except OSError as e:
        if e.errno == errno.EEXIST:
            print('Directory already exists.')
        else:
            raise
    criterion = CTCLoss()

    # The label alphabet is stored as a JSON list of characters.
    with open(args.labels_path) as label_file:
        labels = str(''.join(json.load(label_file)))

    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride,
                      window=args.window)

    train_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                       manifest_filepath=args.train_manifest,
                                       labels=labels,
                                       normalize=True)
    test_dataset = SpectrogramDataset(audio_conf=audio_conf,
                                      manifest_filepath=args.val_manifest,
                                      labels=labels,
                                      normalize=True)
    train_loader = AudioDataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers)
    test_loader = AudioDataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                       nb_layers=args.hidden_layers,
                       num_classes=len(labels))
    decoder = ArgMaxDecoder(labels)
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()
    print(model)
    parameters = model.parameters()
    optimizer = torch.optim.SGD(parameters,
                                lr=args.lr,
                                momentum=args.momentum,
                                nesterov=True)

    # Running meters for timing and loss reporting.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    for epoch in range(args.epochs):
        model.train()
        end = time.time()
        avg_loss = 0
        for i, (data) in enumerate(train_loader):
            inputs, targets, input_percentages, target_sizes = data
            # measure data loading time
            data_time.update(time.time() - end)
            inputs = Variable(inputs)
            target_sizes = Variable(target_sizes)
            targets = Variable(targets)

            if args.cuda:
                inputs = inputs.cuda()

            out = model(inputs)
            out = out.transpose(0, 1)  # TxNxH

            # Convert per-sample length fractions into absolute time-step
            # counts for CTC (mul_ mutates input_percentages in place).
            seq_length = out.size(0)
            sizes = Variable(input_percentages.mul_(int(seq_length)).int())

            loss = criterion(out, targets, sizes, target_sizes)
            loss = loss / inputs.size(0)  # average the loss by minibatch

            # Guard against divergent (inf) CTC losses so a single bad batch
            # doesn't poison the running averages.
            loss_sum = loss.data.sum()
            inf = float("inf")
            if loss_sum == inf or loss_sum == -inf:
                print("WARNING: received an inf loss, setting loss value to 0")
                loss_value = 0
            else:
                # .data[0] is the pre-0.4 PyTorch scalar access (.item() today).
                loss_value = loss.data[0]

            avg_loss += loss_value
            losses.update(loss_value, inputs.size(0))

            # compute gradient
            optimizer.zero_grad()
            loss.backward()

            # Clip gradient norm to stabilize RNN training.
            torch.nn.utils.clip_grad_norm(model.parameters(), args.max_norm)
            # SGD step
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if not args.silent:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          (epoch + 1), (i + 1),
                          len(train_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses))

        avg_loss /= len(train_loader)
        print('Training Summary Epoch: [{0}]\t'
              'Average Loss {loss:.3f}\t'.format(epoch + 1, loss=avg_loss))

        # Validation pass: decode every batch and accumulate WER/CER.
        total_cer, total_wer = 0, 0
        for i, (data) in enumerate(test_loader):  # test
            inputs, targets, input_percentages, target_sizes = data

            inputs = Variable(inputs)

            # unflatten targets: they arrive concatenated, split per sample
            split_targets = []
            offset = 0
            for size in target_sizes:
                split_targets.append(targets[offset:offset + size])
                offset += size

            if args.cuda:
                inputs = inputs.cuda()

            out = model(inputs)
            out = out.transpose(0, 1)  # TxNxH
            seq_length = out.size(0)
            sizes = Variable(input_percentages.mul_(int(seq_length)).int())

            decoded_output = decoder.decode(out.data, sizes)
            target_strings = decoder.process_strings(
                decoder.convert_to_strings(split_targets))
            # Per-sample WER/CER, each normalized by reference length.
            wer, cer = 0, 0
            for x in range(len(target_strings)):
                wer += decoder.wer(decoded_output[x],
                                   target_strings[x]) / float(
                                       len(target_strings[x].split()))
                cer += decoder.cer(decoded_output[x],
                                   target_strings[x]) / float(
                                       len(target_strings[x]))
            total_cer += cer
            total_wer += wer

        wer = total_wer / len(test_loader.dataset)
        cer = total_cer / len(test_loader.dataset)
        wer *= 100
        cer *= 100

        print('Validation Summary Epoch: [{0}]\t'
              'Average WER {wer:.0f}\t'
              'Average CER {cer:.0f}\t'.format(epoch + 1, wer=wer, cer=cer))

        if args.visdom:
            loss_results[epoch] = avg_loss
            wer_results[epoch] = wer
            cer_results[epoch] = cer
            # Bump epoch so the slices below include the current epoch's
            # results (the for-loop rebinds `epoch` on the next iteration).
            epoch += 1
            x_axis = epochs[0:epoch]
            y_axis = [
                loss_results[0:epoch], wer_results[0:epoch],
                cer_results[0:epoch]
            ]
            for x in range(len(viz_windows)):
                if viz_windows[x] is None:
                    viz_windows[x] = viz.line(
                        X=x_axis,
                        Y=y_axis[x],
                        opts=opts[x],
                    )
                else:
                    viz.line(
                        X=x_axis,
                        Y=y_axis[x],
                        win=viz_windows[x],
                        update='replace',
                    )
        if args.epoch_save:
            # NOTE(review): when visdom is enabled, `epoch` was incremented
            # above, so the checkpoint is saved under epoch+1 — preserved here.
            file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch)
            torch.save(checkpoint(model, args, len(labels), epoch), file_path)
    torch.save(checkpoint(model, args, len(labels)), args.final_model_path)
# Ejemplo n.º 20 (Example 20 — score: 0)
def main(data_path, abc, seq_proj, backend, snapshot, input_size, base_lr, step_size, max_iter, batch_size, output_dir, test_epoch, test_init, gpu):
    """Train a CRNN text recognizer with CTC loss over 24 cross-validation folds.

    For each fold a fresh model is trained for 15 epochs; after every epoch
    the fold's held-out split is evaluated and the best-accuracy snapshot
    (plus any snapshot above 0.985 accuracy) is written to ``output_dir``.

    :param data_path: dataset root, or None to fall back to TestDataset.
    :param abc: alphabet string (used only by TestDataset).
    :param seq_proj: 'WxH' string describing the sequence projection size.
    :param backend: CNN backend name, used in snapshot file names.
    :param snapshot: optional snapshot path to warm-start from.
    :param input_size: 'WxH' string describing the network input size.
    :param base_lr: initial Adam learning rate.
    :param step_size: LR scheduler step size (in iterations).
    :param max_iter: unused here; kept for CLI compatibility.
    :param batch_size: mini-batch size.
    :param output_dir: snapshot directory (None disables saving).
    :param test_epoch: unused here; kept for CLI compatibility.
    :param test_init: unused here; kept for CLI compatibility.
    :param gpu: comma-separated GPU ids; '' selects CPU.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # Bug fix: the original `gpu is not ''` compared string *identity*, which
    # is implementation-dependent (and a SyntaxWarning on Python >= 3.8).
    cuda = gpu != ''

    # Parse 'WxH' geometry strings into [W, H] int pairs.
    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose([
        Rotation(),
        Translation(),
        Contrast(),
        Resize(size=(input_size[0], input_size[1]))
    ])
    seq_proj = [int(x) for x in seq_proj.split('x')]

    for fold_idx in range(24):
        train_mode = 'fold{0}_train'.format(fold_idx)
        val_mode = 'fold{0}_test'.format(fold_idx)

        # Real data when a path is given, synthetic test data otherwise.
        if data_path is not None:
            data = TextDataset(data_path=data_path, mode=train_mode, transform=transform)
        else:
            data = TestDataset(transform=transform, abc=abc)

        net = load_model(data.get_abc(), seq_proj, backend, snapshot, cuda)
        optimizer = optim.Adam(net.parameters(), lr=base_lr, weight_decay=0.0001)
        lr_scheduler = StepLR(optimizer, step_size=step_size)
        loss_function = CTCLoss()

        print(fold_idx)

        acc_best = 0
        epoch_count = 0
        for epoch_idx in range(15):
            data_loader = DataLoader(data, batch_size=batch_size, num_workers=10, shuffle=True, collate_fn=text_collate)
            loss_mean = []
            iterator = tqdm(data_loader)
            iter_count = 0
            for sample in iterator:
                # Drop ragged final batches so multi-GPU splits stay even.
                if sample["img"].size(0) % len(gpu.split(',')) != 0:
                    continue
                optimizer.zero_grad()
                imgs = Variable(sample["img"])
                labels = Variable(sample["seq"]).view(-1)
                label_lens = Variable(sample["seq_len"].int())
                if cuda:
                    imgs = imgs.cuda()
                preds = net(imgs).cpu()
                # Every time step of every prediction is fed to CTC.
                pred_lens = Variable(Tensor([preds.size(0)] * batch_size).int())
                loss = loss_function(preds, labels, pred_lens, label_lens) / batch_size
                loss.backward()
                loss_mean.append(loss.data[0])
                status = "{}/{}; lr: {}; loss_mean: {}; loss: {}".format(epoch_count, lr_scheduler.last_iter, lr_scheduler.get_lr(), np.mean(loss_mean), loss.data[0])
                iterator.set_description(status)
                optimizer.step()
                lr_scheduler.step()
                iter_count += 1

            logging.info("Test phase")
            net = net.eval()

            # Evaluate on this fold's held-out split.
            data.set_mode(val_mode)
            acc, avg_ed, error_idx = test(net, data, data.get_abc(), cuda, visualize=False)

            if acc > 0.95:
                error_name = [data.config[data.mode][idx]["name"] for idx in error_idx]
                logging.info('Val: ' + ','.join(error_name))

            net = net.train()
            data.set_mode(train_mode)

            # Keep the best snapshot seen so far for this fold.
            if acc > acc_best:
                if output_dir is not None:
                    torch.save(net.state_dict(), os.path.join(output_dir, train_mode + "_crnn_" + backend + "_" + str(data.get_abc()) + "_best"))
                acc_best = acc

            # Also keep any snapshot clearing a high absolute accuracy bar.
            if acc > 0.985:
                if output_dir is not None:
                    torch.save(net.state_dict(), os.path.join(output_dir, train_mode + "_crnn_" + backend + "_" + str(data.get_abc()) + "_" + str(acc)))
            # Bug fix: the original format string had two placeholders but was
            # given three arguments, so acc_best was printed in the avg_ed slot.
            logging.info("acc: {}\tacc_best: {}; avg_ed: {}\n\n".format(acc, acc_best, avg_ed))

            epoch_count += 1
    def _forward_1st_stage(
        self,
        original_input: np.ndarray,
        original_output: np.ndarray,
        local_batch_size: int,
        local_max_length: int,
        rescale: np.ndarray,
        input_mask: np.ndarray,
        real_lengths: np.ndarray,
    ) -> Tuple["torch.Tensor", "torch.Tensor", np.ndarray, "torch.Tensor",
               "torch.Tensor"]:
        """
        First-stage forward pass of the attack.

        Applies the current (clipped, rescaled) perturbation to the batch,
        pushes the perturbed audio through the victim model, and evaluates the
        batch-averaged CTC loss plus the decoded transcriptions.

        :param original_input: Equal-length samples, shape (nb_samples, seq_length),
                               e.g. `np.array([np.array([0.1, 0.2, 0.1]), np.array([0.3, 0.1, 0.0])])`.
        :param original_output: Target transcriptions, shape (nb_samples); each
                                entry is a string of arbitrary length, e.g.
                                `np.array(['SIXTY ONE', 'HELLO'])`.
        :param local_batch_size: Size of the current batch.
        :param local_max_length: Maximum sequence length in the current batch.
        :param rescale: Current per-sample rescale coefficients.
        :param input_mask: Mask marking the true (non-padded) input positions.
        :param real_lengths: True lengths of the original sequences.
        :return: Tuple (loss, local_delta, decoded_output, masked_adv_input,
                 local_delta_rescale) — the first-stage loss, this batch's raw
                 delta slice, the transcriptions, the perturbed inputs, and the
                 clipped/rescaled delta.
        """
        import torch  # lgtm [py/repeated-import]
        from warpctc_pytorch import CTCLoss

        device = self.estimator.device

        # Slice the shared perturbation down to this batch, clip it to the
        # epsilon ball, then apply the per-sample rescale coefficients.
        local_delta = self.global_optimal_delta[:local_batch_size, :
                                                local_max_length]
        local_delta_rescale = torch.clamp(local_delta, -self.initial_eps,
                                          self.initial_eps).to(device)
        local_delta_rescale *= torch.tensor(rescale).to(device)

        # Perturb the clean audio and zero out each sample's padded tail.
        perturbed = local_delta_rescale + torch.tensor(original_input).to(device)
        masked_adv_input = perturbed * torch.tensor(input_mask).to(device)

        # Map the perturbed audio into the model's input space.
        inputs, targets, input_rates, target_sizes, batch_idx = self.estimator.transform_model_input(
            x=masked_adv_input.to(device),
            y=original_output,
            compute_gradient=False,
            tensor_input=True,
            real_lengths=real_lengths,
        )

        # Turn per-sample rates into absolute input lengths (in-place mul).
        input_sizes = input_rates.mul_(inputs.size()[-1]).int()

        # Forward through the DeepSpeech model.
        outputs, output_sizes = self.estimator.model(
            inputs.to(device), input_sizes.to(device))
        float_outputs = outputs.transpose(0, 1).float()

        # Batch-averaged CTC loss.
        criterion = CTCLoss()
        loss = criterion(float_outputs, targets, output_sizes,
                         target_sizes).to(device)
        loss = loss / inputs.size(0)

        # Decode transcriptions (best hypothesis per sample).
        decoded_output, _ = self.estimator.decoder.decode(
            outputs, output_sizes)
        decoded_output = np.array([entry[0] for entry in decoded_output])

        # Undo the model-input reordering so results match the input order.
        decoded_output[batch_idx] = decoded_output.copy()

        return loss, local_delta, decoded_output, masked_adv_input, local_delta_rescale
# Ejemplo n.º 22 (Example 22 — score: 0)
def main(cfg, cuda=torch.cuda.is_available()):
    """Validate a RawCTCNet basecaller: CTC loss plus beam/argmax decoding.

    Builds a DataLoader over the tensor datasets listed in
    ``cfg['data_paths']``, restores weights from ``cfg['model']`` if given,
    and runs a torchnet-style ``Engine.test`` loop.  Every
    ``cfg['print_every']`` steps the CTC loss and the decoded sequences
    (true / beam / argmax, plus a local alignment) are written to
    ``cfg['logfile']``.

    :param cfg: mapping with keys 'data_paths', 'num_workers', 'batch_size',
                'num_stacks', 'model', 'print_every', 'logfile'.
    :param cuda: place model/data on GPU (default: autodetected when the
                 function is defined).
    """
    ### flush cfg to output log file:
    tqdm.write(str(cfg), file=cfg['logfile'])
    tqdm.write('-' * 80, file=cfg['logfile'])

    ### define dataloader factory:
    def get_iterator():
        """Build one DataLoader over all configured tensor datasets."""
        datasets = cfg['data_paths']
        pin_mem = cuda
        nworkers = cfg['num_workers']

        # Each dataset is four tensor files (signals, signal lengths,
        # sequences, sequence lengths); concatenate extras onto the first.
        ds = SeqTensorDataset(torch.load(datasets[0][0]),
                              torch.load(datasets[0][1]),
                              torch.load(datasets[0][2]),
                              torch.load(datasets[0][3]))
        for dataset in datasets[1:]:
            ds += SeqTensorDataset(torch.load(dataset[0]),
                                   torch.load(dataset[1]),
                                   torch.load(dataset[2]),
                                   torch.load(dataset[3]))

        # Page-lock host memory when a GPU is present for faster transfers.
        return DataLoader(ds,
                          batch_size=cfg['batch_size'],
                          shuffle=True,
                          num_workers=nworkers,
                          collate_fn=sequence_collate_fn,
                          pin_memory=pin_mem)

    ### build RawCTCNet model:
    in_dim = 1
    # Stacked dilated conv blocks: (in_ch, out_ch, dilation, kernel_width).
    layers = [(256, 256, d, 3)
              for d in [1, 2, 4, 8, 16, 32, 64]] * cfg['num_stacks']
    num_labels = 5  # blank + A/G/C/T
    out_dim = 512
    network = RawCTCNet(in_dim,
                        num_labels,
                        layers,
                        out_dim,
                        input_kw=1,
                        input_dil=1,
                        positions=True,
                        softmax=False,
                        causal=False,
                        batch_norm=True)
    print("Constructed network.")
    if cuda:
        print("CUDA detected; placed network on GPU.")
        network.cuda()
    if cfg['model'] is not None:
        print("Loading model file...")
        try:
            network.load_state_dict(torch.load(cfg['model']))
        except Exception:  # narrowed from a bare `except:`
            print(
                "ERR: could not restore model. Check model datatype/dimensions."
            )

    ### build CTC loss function and model evaluation function:
    ctc_loss_fn = CTCLoss()
    print("Constructed CTC loss function.")
    maybe_gpu = lambda tsr, has_cuda: tsr if not has_cuda else tsr.cuda()

    def model_loss(sample):
        """Return (CTC loss, logits) for one batch; inference-only pass."""
        # unpack inputs and wrap in Variables:
        signals_, signal_lengths_, sequences_, sequence_lengths_ = sample
        signals = Variable(maybe_gpu(signals_.permute(0, 2, 1), cuda),
                           volatile=True)  # BxTxD => BxDxT
        signal_lengths = Variable(signal_lengths_, volatile=True)
        sequences = Variable(concat_labels(sequences_, sequence_lengths_),
                             volatile=True)
        sequence_lengths = Variable(sequence_lengths_, volatile=True)
        # compute predicted labels:
        transcriptions = network(signals).permute(2, 0,
                                                  1)  # Permute: BxDxT => TxBxD
        # compute CTC loss and return:
        loss = ctc_loss_fn(transcriptions, sequences.int(),
                           signal_lengths.int(), sequence_lengths.int())
        return loss, transcriptions

    ### build beam search decoder:
    beam_labels = [' ', 'A', 'G', 'C', 'T']
    beam_blank_id = 0
    beam_decoder = CTCBeamDecoder(beam_labels,
                                  beam_width=100,
                                  blank_id=beam_blank_id,
                                  num_processes=cfg['num_workers'])
    print("Constructed CTC beam search decoder.")

    ### build engine, meters, and hooks:
    engine = Engine()

    # Wrap a tqdm meter around the losses:
    def on_start(state):
        network.eval()
        state['iterator'] = tqdm(state['iterator'])

    # (Currently don't do anything w/r/t the sample.)
    def on_sample(state):
        pass

    # occasionally log the loss value and perform beam search decoding:
    def on_forward(state):
        if (state['t'] % cfg['print_every'] == 0):
            # Log the CTC loss.
            # Bug fix: `file=` was previously passed to str.format — where
            # unreferenced keyword args are silently ignored — instead of
            # tqdm.write, so the loss line went to stdout, not the log file.
            tqdm.write("Step {0} | Loss: {1}".format(state['t'],
                                                     state['loss'].data[0]),
                       file=cfg['logfile'])
            # beam search decoding:
            _, logit_lengths_t, seq_t, seq_lengths_t = state['sample']
            scores = mask_padding(state['output'].permute(1, 0, 2),
                                  logit_lengths_t,
                                  fill_logit_idx=0)
            logits = F.softmax(scores, dim=2)
            _nt_dict_ = {0: ' ', 1: 'A', 2: 'G', 3: 'C', 4: 'T'}

            def convert_to_string(toks, voc, num):
                """Map the first `num` token ids to nucleotides ('' on error)."""
                try:
                    nt = ''.join([voc[t] for t in toks[0:num]])
                except Exception:  # narrowed from a bare `except:`
                    nt = ''
                return nt

            try:
                true_nts = labels2strings(seq_t, lookup=_nt_dict_)
                amax_nts = labels2strings(argmax_decode(logits),
                                          lookup=_nt_dict_)
                beam_result, beam_scores, beam_times, beam_lengths = beam_decoder.decode(
                    logits.data)
                pred_nts = [
                    convert_to_string(beam_result[k][0], _nt_dict_,
                                      beam_lengths[k][0])
                    for k in range(len(beam_result))
                ]
                for i in range(min(len(true_nts), len(pred_nts))):
                    tqdm.write("True Seq: {0}".format(true_nts[i]),
                               file=cfg['logfile'])
                    tqdm.write("Beam Seq: {0}".format(pred_nts[i]),
                               file=cfg['logfile'])
                    tqdm.write("Amax Seq: {0}".format(amax_nts[i]),
                               file=cfg['logfile'])
                    tqdm.write(
                        ("- " * 10 + "Local Beam Alignment" + " -" * 10),
                        file=cfg['logfile'])
                    tqdm.write(ssw(true_nts[i], pred_nts[i]),
                               file=cfg['logfile'])
                    tqdm.write("= " * 40, file=cfg['logfile'])
            except Exception:  # narrowed from a bare `except:`
                tqdm.write("(WARN: Could not parse batch; skipping...)",
                           file=cfg['logfile'])

    # (Currently don't do anything at end of epoch.)
    def on_end(state):
        pass

    print("Constructed engine. Running validation loop...")

    ### run validation loop:
    engine.hooks['on_start'] = on_start
    engine.hooks['on_sample'] = on_sample
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_end'] = on_end
    engine.test(model_loss, get_iterator())
# Ejemplo n.º 23 (Example 23 — score: 0)
def main():
    """Train the attention encoder for CTC-based recognition.

    Parses CLI arguments and an INI config, builds train/dev dataloaders with
    optional scaling augmentation, and trains with SGD (separate CNN / LSTM
    optimizers and step-decay schedulers).  Supports resuming via ``--cont``
    (restores model weights and the hyper-state pickle) and saves the best
    dev-accuracy model to ``<output>/best.pth``.
    """
    parser = argparse.ArgumentParser(description="Attn Encoder")
    parser.add_argument("--img", type=str, help="image dir")
    parser.add_argument("--prior", type=str, help="prior dir")
    parser.add_argument("--csv", type=str, help="csv dir")
    parser.add_argument("--conf", type=str, help="config file")
    parser.add_argument("--output", type=str, help="output dir")
    parser.add_argument("--pretrain", type=str, default=None, help="pretrain path")
    parser.add_argument("--cont", action="store_true", help="continue training")
    parser.add_argument("--epoch", type=int, default=1, help="epoch")
    parser.add_argument("--optim_step_size", type=int, default=30, help="lr decay step size")
    parser.add_argument("--optim_gamma", type=float, default=0.1, help="lr decay rate")
    parser.add_argument("--scaling", action="store_true", help="data augmentation (scaling)")
    parser.add_argument("--img_scale", type=float, default=1., nargs="+", help="image scales")
    parser.add_argument("--map_scale", type=int, default=13, nargs="+", help="map scales")
    args = parser.parse_args()

    if not os.path.isdir(args.output):
        os.makedirs(args.output)

    # Output artifacts: best/latest model weights, training log, hyper-state.
    best_path = os.path.join(args.output, "best.pth")
    latest_path = os.path.join(args.output, "latest.pth")
    log = os.path.join(args.output, "log")
    hyper_path = os.path.join(args.output, "hyper.pth")

    config = configparser.ConfigParser()
    config.read(args.conf)
    model_cfg, lang_cfg, img_cfg = config['MODEL'], config['LANG'], config['IMAGE']
    hidden_size, attn_size, n_layers = model_cfg.getint('hidden_size'), model_cfg.getint('attn_size'), model_cfg.getint('n_layers')
    prior_gamma = model_cfg.getfloat('prior_gamma')
    learning_rate = model_cfg.getfloat('learning_rate')
    batch_size = model_cfg.getint('batch_size')
    char_list = lang_cfg['chars']  # e.g. " '&.@acbedgfihkjmlonqpsrutwvyxz"
    # Per-channel normalization stats, e.g. ImageNet's [0.485, 0.456, 0.406] / [0.229, 0.224, 0.225].
    immean, imstd = [float(x) for x in config['IMAGE']['immean'].split(',')], [float(x) for x in config['IMAGE']['imstd'].split(',')]
    upper_len = model_cfg.getint('upper_length')
    clip = model_cfg.getfloat('clip')
    save_interval = model_cfg.getint('interval')
    epochs = args.epoch
    optim_step_size, optim_gamma = args.optim_step_size, args.optim_gamma

    train_csv, dev_csv = os.path.join(args.csv, 'train.csv'), os.path.join(args.csv, 'dev.csv')

    device, cpu = torch.device('cuda'), torch.device('cpu')

    vocab_map, inv_vocab_map, char_list = utils.get_ctc_vocab(char_list)

    # argparse with nargs="+" yields a list when flags are given, but the
    # scalar defaults (1. / 13) when omitted — normalize both cases.
    # Idiom fix: isinstance() instead of exact `type(...) ==` comparisons.
    if isinstance(args.img_scale, list) and isinstance(args.map_scale, list):
        scale_range, hw_range = args.img_scale, [(x, x) for x in args.map_scale]
    elif isinstance(args.img_scale, float) and isinstance(args.map_scale, int):
        scale_range, hw_range = [args.img_scale], [(args.map_scale, args.map_scale)]
    else:
        raise AttributeError('scale: list or float/int')

    # Training transform applies random rescaling only with --scaling;
    # evaluation always keeps the original scale.
    if not args.scaling:
        tsfm_train = transforms.Compose([dataset.ToTensor(device), dataset.Rescale(scale_range, hw_range, origin_scale=True), dataset.Normalize(immean, imstd, device)])
        tsfm_test = transforms.Compose([dataset.ToTensor(device), dataset.Rescale(scale_range, hw_range, origin_scale=True), dataset.Normalize(immean, imstd, device)])
    else:
        tsfm_train = transforms.Compose([dataset.ToTensor(device), dataset.Rescale(scale_range, hw_range), dataset.Normalize(immean, imstd, device)])
        tsfm_test = transforms.Compose([dataset.ToTensor(device), dataset.Rescale(scale_range, hw_range, origin_scale=True), dataset.Normalize(immean, imstd, device)])

    sld_train_data = dataset.SLData(args.img, args.prior, train_csv, vocab_map, transform=tsfm_train, upper_len=upper_len)
    sld_dev_data = dataset.SLData(args.img, args.prior, dev_csv, vocab_map, transform=tsfm_test, upper_len=float('inf'))

    encoder = AttnEncoder(hidden_size=hidden_size, attn_size=attn_size,
                          output_size=len(char_list), n_layers=n_layers,
                          prior_gamma=prior_gamma, pretrain=args.pretrain)
    encoder.to(device)
    if torch.cuda.device_count() > 1:
        print('Using %d GPUs' % (torch.cuda.device_count()))
        encoder = nn.DataParallel(encoder)
    hypers = {'step': 0, 'epoch': 0, 'best_dev_acc': -1, 'perm': np.random.permutation(len(sld_train_data)).tolist()}

    if args.cont:
        # Resume: restore weights and, best-effort, the hyper-state pickle.
        print("Load %s, %s" % (latest_path, hyper_path))
        encoder.load_state_dict(torch.load(latest_path))
        try:
            with open(hyper_path, 'rb') as fo:
                hypers = pickle.load(fo)
        except Exception as err:
            print("Error loading %s: %s" % (hyper_path, err))
            hypers = {'step': 0, 'epoch': 0, 'best_dev_acc': -1, 'perm': np.random.permutation(len(sld_train_data)).tolist()}

    train_loader = tud.DataLoader(sld_train_data, batch_size=batch_size, shuffle=True, collate_fn=dataset.collate_fn_ctc)
    dev_loader = tud.DataLoader(sld_dev_data, batch_size=batch_size, shuffle=False, collate_fn=dataset.collate_fn_ctc)

    # Separate optimizers so the CNN backbone and the LSTM/attention head can
    # be tuned independently; both share the same step-decay schedule.
    print('Optimizer, decay %.5f after %d epochs' % (optim_gamma, optim_step_size))
    cnn_optimizer = optim.SGD(encoder.conv.parameters(), lr=learning_rate)
    lstm_optimizer = optim.SGD(list(encoder.encoder_cell.parameters()) + list(encoder.lt.parameters()), lr=learning_rate)
    cnn_scheduler = optim.lr_scheduler.StepLR(cnn_optimizer, step_size=optim_step_size, gamma=optim_gamma)
    lstm_scheduler = optim.lr_scheduler.StepLR(lstm_optimizer, step_size=optim_step_size, gamma=optim_gamma)

    decoder = Decoder(char_list)
    ctc_loss = CTCLoss()  # normalize over batch

    print('%d training epochs' % (epochs))
    for ep in range(epochs):
        # NOTE(review): schedulers step at the top of the epoch (pre-1.1
        # PyTorch ordering) — preserved to keep the LR trajectory identical.
        cnn_scheduler.step()
        lstm_scheduler.step()
        # Skip epochs already completed by a resumed run.
        if ep < hypers['epoch']:
            continue
        for p in cnn_optimizer.param_groups:
            print('CNN', p['lr'])
        for p in lstm_optimizer.param_groups:
            print('LSTM', p['lr'])
        train(encoder, train_loader, clip, hypers, cnn_optimizer, lstm_optimizer, ctc_loss, decoder, log, latest_path, hyper_path, device, save_interval)

        dl, dacc = evaluate(encoder, dev_loader, ctc_loss, decoder, device)
        pcont = 'Epoch %d, dev loss: %.3f, dev acc (LEV): %.3f' % (ep, dl, dacc)
        print(pcont)
        with open(log, 'a+') as fo:
            fo.write(pcont + "\n")
        # save model and hyperparameter setting
        hypers['epoch'] = ep
        if hypers['best_dev_acc'] < dacc:
            hypers['best_dev_acc'] = dacc
            with open(best_path, 'wb') as fo:
                torch.save(encoder.state_dict(), fo)
        with open(hyper_path, 'wb') as fo:
            pickle.dump(hypers, fo)
    return
# Ejemplo n.º 24 (Example 24 — score: 0)
def train(cfg):
    """Train a DeepSpeech model from a Hydra-style config ``cfg``.

    Supports optional distributed training (driven by the ``LOCAL_RANK``
    env var), apex mixed precision, visdom/tensorboard visualization,
    checkpoint resume/finetune, per-epoch WER/CER evaluation and
    learning-rate annealing.  Only the main process saves checkpoints
    and reports metrics.
    """
    # Set seeds for determinism
    torch.manual_seed(cfg.training.seed)
    torch.cuda.manual_seed_all(cfg.training.seed)
    np.random.seed(cfg.training.seed)
    random.seed(cfg.training.seed)

    main_proc = True
    device = torch.device("cpu" if cfg.training.no_cuda else "cuda")

    is_distributed = os.environ.get(
        "LOCAL_RANK")  # If local rank exists, distributed env

    if is_distributed:
        # when using NCCL, on failures, surviving nodes will deadlock on NCCL ops
        # because NCCL uses a spin-lock on the device. Set this env var
        # to enable a watchdog thread that will destroy stale NCCL communicators
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

        device_id = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(device_id)
        print(f"Setting CUDA Device to {device_id}")

        dist.init_process_group(backend=cfg.training.dist_backend)
        main_proc = device_id == 0  # Main process handles saving of models and reporting

    checkpoint_handler = CheckpointHandler(
        save_folder=to_absolute_path(cfg.checkpointing.save_folder),
        best_val_model_name=cfg.checkpointing.best_val_model_name,
        checkpoint_per_iteration=cfg.checkpointing.checkpoint_per_iteration,
        save_n_recent_models=cfg.checkpointing.save_n_recent_models)

    if main_proc and cfg.visualization.visdom:
        visdom_logger = VisdomLogger(id=cfg.visualization.id,
                                     num_epochs=cfg.training.epochs)
    if main_proc and cfg.visualization.tensorboard:
        tensorboard_logger = TensorBoardLogger(
            id=cfg.visualization.id,
            log_dir=to_absolute_path(cfg.visualization.log_dir),
            log_params=cfg.visualization.log_params)

    if cfg.checkpointing.load_auto_checkpoint:
        # Prefer the most recent checkpoint on disk over a fresh start.
        latest_checkpoint = checkpoint_handler.find_latest_checkpoint()
        if latest_checkpoint:
            cfg.checkpointing.continue_from = latest_checkpoint

    if cfg.checkpointing.continue_from:  # Starting from previous model
        state = TrainingState.load_state(
            state_path=to_absolute_path(cfg.checkpointing.continue_from))
        model = state.model
        if cfg.training.finetune:
            state.init_finetune_states(cfg.training.epochs)

        if main_proc and cfg.visualization.visdom:  # Add previous scores to visdom graph
            visdom_logger.load_previous_values(state.epoch, state.results)
        if main_proc and cfg.visualization.tensorboard:  # Previous scores to tensorboard logs
            tensorboard_logger.load_previous_values(state.epoch, state.results)
    else:
        # Initialise new model training
        with open(to_absolute_path(cfg.data.labels_path)) as label_file:
            labels = json.load(label_file)

        audio_conf = dict(sample_rate=cfg.data.sample_rate,
                          window_size=cfg.data.window_size,
                          window_stride=cfg.data.window_stride,
                          window=cfg.data.window)
        if cfg.augmentation.noise_dir:
            # BUG FIX: dicts do not support '+=' (the original
            # `audio_conf += dict(...)` raised TypeError); merge the
            # noise-injection settings in place instead.
            audio_conf.update(noise_dir=to_absolute_path(
                cfg.augmentation.noise_dir),
                              noise_prob=cfg.augmentation.noise_prob,
                              noise_levels=(cfg.augmentation.noise_min,
                                            cfg.augmentation.noise_max))

        rnn_type = cfg.model.rnn_type.lower()
        assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
        model = DeepSpeech(rnn_hidden_size=cfg.model.hidden_size,
                           nb_layers=cfg.model.hidden_layers,
                           labels=labels,
                           rnn_type=supported_rnns[rnn_type],
                           audio_conf=audio_conf,
                           bidirectional=cfg.model.bidirectional)

        state = TrainingState(model=model)
        state.init_results_tracking(epochs=cfg.training.epochs)

    # Data setup
    evaluation_decoder = GreedyDecoder(
        model.labels)  # Decoder used for validation
    train_dataset = SpectrogramDataset(
        audio_conf=model.audio_conf,
        manifest_filepath=to_absolute_path(cfg.data.train_manifest),
        labels=model.labels,
        normalize=True,
        speed_volume_perturb=cfg.augmentation.speed_volume_perturb,
        spec_augment=cfg.augmentation.spec_augment)
    test_dataset = SpectrogramDataset(audio_conf=model.audio_conf,
                                      manifest_filepath=to_absolute_path(
                                          cfg.data.val_manifest),
                                      labels=model.labels,
                                      normalize=True,
                                      speed_volume_perturb=False,
                                      spec_augment=False)
    # Samplers support resuming mid-epoch via state.training_step.
    if not is_distributed:
        train_sampler = DSRandomSampler(dataset=train_dataset,
                                        batch_size=cfg.data.batch_size,
                                        start_index=state.training_step)
    else:
        train_sampler = DSElasticDistributedSampler(
            dataset=train_dataset,
            batch_size=cfg.data.batch_size,
            start_index=state.training_step)
    train_loader = AudioDataLoader(dataset=train_dataset,
                                   num_workers=cfg.data.num_workers,
                                   batch_sampler=train_sampler)
    test_loader = AudioDataLoader(dataset=test_dataset,
                                  num_workers=cfg.data.num_workers,
                                  batch_size=cfg.data.batch_size)

    model = model.to(device)
    parameters = model.parameters()
    if cfg.optimizer.adam:
        optimizer = torch.optim.AdamW(parameters,
                                      lr=cfg.optimizer.learning_rate,
                                      betas=cfg.optimizer.betas,
                                      eps=cfg.optimizer.eps,
                                      weight_decay=cfg.optimizer.weight_decay)
    else:
        optimizer = torch.optim.SGD(parameters,
                                    lr=cfg.optimizer.learning_rate,
                                    momentum=cfg.optimizer.momentum,
                                    nesterov=True,
                                    weight_decay=cfg.optimizer.weight_decay)

    # Mixed precision: amp must wrap model+optimizer before any state load.
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=cfg.apex.opt_level,
                                      loss_scale=cfg.apex.loss_scale)
    if state.optim_state is not None:
        optimizer.load_state_dict(state.optim_state)
        amp.load_state_dict(state.amp_state)

    # Track states for optimizer/amp
    state.track_optim_state(optimizer)
    state.track_amp_state(amp)

    if is_distributed:
        model = DistributedDataParallel(model, device_ids=[device_id])
    print(model)
    print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

    criterion = CTCLoss()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    for epoch in range(state.epoch, cfg.training.epochs):
        model.train()
        end = time.time()
        start_epoch_time = time.time()
        state.set_epoch(epoch=epoch)
        train_sampler.set_epoch(epoch=epoch)
        train_sampler.reset_training_step(training_step=state.training_step)
        for i, data in enumerate(train_loader, start=state.training_step):
            state.set_training_step(training_step=i)
            inputs, targets, input_percentages, target_sizes = data
            # Percent-of-max-length -> absolute timestep counts.
            input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
            # measure data loading time
            data_time.update(time.time() - end)
            inputs = inputs.to(device)

            out, output_sizes = model(inputs, input_sizes)
            out = out.transpose(0, 1)  # TxNxH

            float_out = out.float()  # ensure float32 for loss
            loss = criterion(float_out, targets, output_sizes,
                             target_sizes).to(device)
            loss = loss / inputs.size(0)  # average the loss by minibatch
            loss_value = loss.item()

            # Check to ensure valid loss was calculated (skip inf/nan)
            valid_loss, error = check_loss(loss, loss_value)
            if valid_loss:
                optimizer.zero_grad()

                # compute gradient through amp's loss scaler
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               cfg.optimizer.max_norm)
                optimizer.step()
            else:
                print(error)
                print('Skipping grad update')
                loss_value = 0

            state.avg_loss += loss_value
            losses.update(loss_value, inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      (epoch + 1), (i + 1),
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))

            if main_proc and cfg.checkpointing.checkpoint_per_iteration:
                checkpoint_handler.save_iter_checkpoint_model(epoch=epoch,
                                                              i=i,
                                                              state=state)
            del loss, out, float_out

        state.avg_loss /= len(train_dataset)

        epoch_time = time.time() - start_epoch_time
        print('Training Summary Epoch: [{0}]\t'
              'Time taken (s): {epoch_time:.0f}\t'
              'Average Loss {loss:.3f}\t'.format(epoch + 1,
                                                 epoch_time=epoch_time,
                                                 loss=state.avg_loss))

        with torch.no_grad():
            wer, cer, output_data = evaluate(test_loader=test_loader,
                                             device=device,
                                             model=model,
                                             decoder=evaluation_decoder,
                                             target_decoder=evaluation_decoder)

        state.add_results(epoch=epoch,
                          loss_result=state.avg_loss,
                          wer_result=wer,
                          cer_result=cer)

        print('Validation Summary Epoch: [{0}]\t'
              'Average WER {wer:.3f}\t'
              'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer, cer=cer))

        if main_proc and cfg.visualization.visdom:
            visdom_logger.update(epoch, state.result_state)
        if main_proc and cfg.visualization.tensorboard:
            tensorboard_logger.update(epoch, state.result_state,
                                      model.named_parameters())

        if main_proc and cfg.checkpointing.checkpoint:  # Save epoch checkpoint
            checkpoint_handler.save_checkpoint_model(epoch=epoch, state=state)
        # anneal lr (the print reports the last param group's rate)
        for g in optimizer.param_groups:
            g['lr'] = g['lr'] / cfg.optimizer.learning_anneal
        print('Learning rate annealed to: {lr:.6f}'.format(lr=g['lr']))

        if main_proc and (state.best_wer is None or state.best_wer > wer):
            checkpoint_handler.save_best_model(epoch=epoch, state=state)
            state.set_best_wer(wer)
        # BUG FIX: reset the loss accumulator every epoch, not only when a
        # new best WER is saved — otherwise the next epoch's running loss
        # starts from the previous epoch's stale average.
        state.reset_avg_loss()
        state.reset_training_step()  # Reset training step for next epoch
def train(opt):
    """Train a scene-text recognition model (CTC or Attention head).

    Builds the balanced training dataset and a validation loader,
    initialises model weights, then runs an iteration-based training
    loop with periodic validation, best-model checkpointing (accuracy
    and normalized edit distance) and snapshots every 100k iterations.
    Terminates the process via ``sys.exit()`` after ``opt.num_iter``
    iterations.
    """
    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    print("찍어보기")
    print(opt, type(opt))
    print("찍어보기 끝")

    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=
        True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization (kaiming for weights, zero for biases)
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm (1-D weights reject kaiming init).
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter that only require gradient decent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.saved_model != '':
        # Resume iteration count encoded in the filename, e.g. iter_30000.pth.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (ValueError, IndexError):
            # Filename does not encode an iteration number; start from 0.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (
                iteration + 1
        ) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter. (integer modulus; original used float 1e+5)
        if (iteration + 1) % 100000 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
Ejemplo n.º 26
0
def main(opts):

  nclass = len(alphabet) + 1
  model_name = 'E2E-MLT'
  net = OwnModel(attention=True, nclass=nclass)
  print("Using {0}".format(model_name))
  if opts.cuda:
    net.cuda()
  learning_rate = opts.base_lr
  optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)

  ### 第一种:只修改conv11的维度 
  # model_dict = net.state_dict()
  # if os.path.exists(opts.model):
  #     # 载入预训练模型
  #     print('loading pretrained model from %s' % opts.model)
  #     # pretrained_model = OwnModel(attention=True, nclass=7325)
  #     pretrained_model = ModelResNetSep2(attention=True, nclass=7500)
  #     pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
  #     pretrained_dict = pretrained_model.state_dict()
  #
  #     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'conv11' not in k and 'rnn' not in k}
  #     # 2. overwrite entries in the existing state dict
  #     model_dict.update(pretrained_dict)
  #     # 3. load the new state dict
  #     net.load_state_dict(model_dict)

  ### 第二种:直接接着前面训练
  if os.path.exists(opts.model):
    print('loading model from %s' % args.model)
    step_start, learning_rate = net_utils.load_net(args.model, net, optimizer)
  ### 
  
  step_start = 0
  net.train()

  converter = strLabelConverter(alphabet)
  ctc_loss = CTCLoss()

  e2edata = E2Edataset(train_list=opts.train_list)
  e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=4, shuffle=True, collate_fn=E2Ecollate)
  
  train_loss = 0
  bbox_loss, seg_loss, angle_loss = 0., 0., 0.
  cnt = 0
  ctc_loss_val = 0
  ctc_loss_val2 = 0
  box_loss_val = 0
  gt_g_target = 0
  gt_g_proc = 0
  
  
  for step in range(step_start, opts.max_iters):

    loss = 0

    # batch
    images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(data_generator)
    im_data = net_utils.np_to_variable(images.transpose(0, 3, 1, 2), is_cuda=opts.cuda)
    # im_data = torch.from_numpy(images).type(torch.FloatTensor).permute(0, 3, 1, 2).cuda()       # permute(0,3,1,2)和cuda的先后顺序有影响
    start = timeit.timeit()
    try:
      seg_pred, roi_pred, angle_pred, features = net(im_data)
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      continue
    end = timeit.timeit()
    
    # for EAST loss
    smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
    training_mask_var = net_utils.np_to_variable(training_masks, is_cuda=opts.cuda)
    angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4], is_cuda=opts.cuda)
    geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]], is_cuda=opts.cuda)
    
    try:
      loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred, angle_gt, roi_pred, geo_gt)
    except:
      import sys, traceback
      traceback.print_exc(file=sys.stdout)
      continue
      
    bbox_loss += net.box_loss_value.data.cpu().numpy() 
    seg_loss += net.segm_loss_value.data.cpu().numpy()
    angle_loss += net.angle_loss_value.data.cpu().numpy()  
    train_loss += loss.data.cpu().numpy()
    
       
    try:
      # 10000步之前都是用文字的标注区域训练的
      if step > 10000 or True: #this is just extra augumentation step ... in early stage just slows down training
    # ctcl, gt_target , gt_proc = process_boxes(images, im_data, seg_pred[0], roi_pred[0], angle_pred[0], score_maps, gt_idxs, gtso, lbso, features, net, ctc_loss, opts, converter, debug=opts.debug)
    ctcl= process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training=True)
    gt_target = 1
    gt_proc = 1

    ctc_loss_val += ctcl.data.cpu().numpy()[0]
    loss = ctcl
    gt_g_target = gt_target
    gt_g_proc = gt_proc
    train_loss += ctcl.item()
      
      # -训练ocr识别部分的时候,采用一个data_generater生成
      # imageso, labels, label_length = next(dg_ocr)          # 其中应该有对倾斜文本的矫正
      # im_data_ocr = net_utils.np_to_variable(imageso, is_cuda=opts.cuda).permute(0, 3, 1, 2)
      # features = net.forward_features(im_data_ocr)
      # labels_pred = net.forward_ocr(features)
      # probs_sizes =  torch.IntTensor( [(labels_pred.permute(2,0,1).size()[0])] * (labels_pred.permute(2,0,1).size()[1]) )
      # label_sizes = torch.IntTensor( torch.from_numpy(np.array(label_length)).int() )
      # labels = torch.IntTensor( torch.from_numpy(np.array(labels)).int() )
      # loss_ocr = ctc_loss(labels_pred.permute(2,0,1), labels, probs_sizes, label_sizes) / im_data_ocr.size(0) * 0.5
      # loss_ocr.backward()
      # ctc_loss_val2 += loss_ocr.item()

      net.zero_grad()
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    except:
Ejemplo n.º 27
0
 def _default_loss(self, yp, yp_lens, y, y_lens):  # -> ctc_loss: scalar tensor
     """Batch-averaged CTC loss for a batch of predictions.

     *yp* arrives as B * V * T; the CTC criterion expects T * B * V, so
     the axes are permuted before the loss call.
     """
     ctc = CTCLoss(size_average=True)
     return ctc(yp.permute(2, 0, 1), y, yp_lens, y_lens)
Ejemplo n.º 28
0
def main(cfg, cuda_avail=torch.cuda.is_available()):
    """Train a RawCTCNet basecaller with CTC loss.

    Builds train/valid data loaders, the convolutional CTC network, an
    optimizer + LR scheduler, and a beam-search decoder, then drives the
    training loop through a torchnet Engine whose hooks handle logging,
    per-epoch validation, decoding previews, and checkpointing.

    NOTE(review): this code uses the legacy ``Variable``/``volatile`` API
    (pre-PyTorch-0.4) and a warpctc-style CTCLoss call signature
    (acts, labels, act_lens, label_lens) -- confirm library versions.
    """
    ### flush cfg to output log file:
    tqdm.write(str(cfg), file=cfg['logfile'])
    tqdm.write('-' * 80)

    ### define function that returns a data loader:
    def get_iterator(mode='train'):
        # choose between train/valid data based on `mode`:
        # NOTE(review): any mode other than 'train'/'valid' leaves the
        # names below unbound and raises NameError further down.
        if mode == 'train':
            datasets = cfg['train_data_paths']
            pin_memory_flag = (cuda_avail and cfg['cuda'])
            num_workers_setting = 4
        if mode == 'valid':
            datasets = cfg['valid_data_paths']
            pin_memory_flag = False
            num_workers_setting = 1

        # form a (possibly concatenated) dataset:
        ds = SeqTensorDataset(torch.load(datasets[0][0]),
                              torch.load(datasets[0][1]),
                              torch.load(datasets[0][2]),
                              torch.load(datasets[0][3]))
        for dataset in datasets[1:]:
            ds += SeqTensorDataset(torch.load(dataset[0]),
                                   torch.load(dataset[1]),
                                   torch.load(dataset[2]),
                                   torch.load(dataset[3]))

        # return a loader that iterates over the dataset of choice; pagelock the memory location if GPU detected:
        return DataLoader(ds,
                          batch_size=cfg['batch_size'],
                          shuffle=True,
                          num_workers=num_workers_setting,
                          collate_fn=sequence_collate_fn,
                          pin_memory=pin_memory_flag)

    ### build RawCTCNet model:
    # 7 dilation rates (1..64) per stack, repeated `num_stacks` times.
    in_dim = 1
    layers = [(256, 256, d, 3)
              for d in [1, 2, 4, 8, 16, 32, 64]] * cfg['num_stacks']
    num_labels = 5
    out_dim = 512
    network = RawCTCNet(in_dim,
                        num_labels,
                        layers,
                        out_dim,
                        input_kw=1,
                        input_dil=1,
                        positions=True,
                        softmax=False,
                        causal=False,
                        batch_norm=True)
    print("Constructed network.")
    if (cuda_avail and cfg['cuda']):
        print("CUDA detected; placed network on GPU.")
        network.cuda()
    if cfg['model'] is not None:
        print("Loading model file...")
        try:
            network.load_state_dict(torch.load(cfg['model']))
        except:
            # NOTE(review): bare except -- on an incompatible checkpoint,
            # training silently continues with freshly initialized weights.
            print(
                "ERR: could not restore model. Check model datatype/dimensions."
            )

    ### build CTCLoss and model evaluation function:
    ctc_loss_fn = CTCLoss()
    print("Constructed CTC loss function.")
    maybe_gpu = lambda tsr, has_cuda: tsr if not has_cuda else tsr.cuda()

    #--- this function performs the gradient descent in synchronous batched mode:
    def batch_model_loss(sample):
        # unpack inputs and wrap as `torch.autograd.Variable`s:
        signals_, signal_lengths_, sequences_, sequence_lengths_ = sample
        signals = Variable(
            maybe_gpu(signals_.permute(0, 2, 1),
                      (cuda_avail and cfg['cuda'])))  # BxTxD => BxDxT
        signal_lengths = Variable(signal_lengths_)
        sequences = Variable(concat_labels(sequences_, sequence_lengths_))
        sequence_lengths = Variable(sequence_lengths_)
        # compute predicted labels:
        transcriptions = network(signals).permute(2, 0,
                                                  1)  # Permute: BxDxT => TxBxD
        # compute CTC loss and return:
        # (backward() runs here; the Engine's optimizer step happens outside)
        loss = ctc_loss_fn(transcriptions, sequences.int(),
                           signal_lengths.int(), sequence_lengths.int())
        loss.backward()
        return loss, transcriptions

    #--- for evaluation-mode, batch-parallel:
    def batch_model_eval(sample):
        # unpack inputs and wrap as `torch.autograd.Variable`s:
        # (volatile=True disables graph construction for inference)
        signals_, signal_lengths_, sequences_, sequence_lengths_ = sample
        signals = Variable(maybe_gpu(signals_.permute(0, 2, 1),
                                     (cuda_avail and cfg['cuda'])),
                           volatile=True)  # BxTxD => BxDxT
        signal_lengths = Variable(signal_lengths_, volatile=True)
        sequences = Variable(concat_labels(sequences_, sequence_lengths_),
                             volatile=True)
        sequence_lengths = Variable(sequence_lengths_, volatile=True)
        # compute predicted labels:
        transcriptions = network(signals).permute(2, 0,
                                                  1)  # Permute: BxDxT => TxBxD
        # compute CTC loss and return:
        loss = ctc_loss_fn(transcriptions, sequences.int(),
                           signal_lengths.int(), sequence_lengths.int())
        return loss, transcriptions

    #--- asynchronous gradient accumulation mode
    # compute target seqs/losses sequentially over each example, average gradients
    def async_model_loss(sample):
        # unpack inputs, optionally place on CUDA:
        signals_, signal_lengths_, sequences_, sequence_lengths_ = sample
        signals = maybe_gpu(signals_.permute(0, 2, 1),
                            (cuda_avail and cfg['cuda']))  # BxTxD => BxDxT

        # sequential compute over the batch:
        total_loss = 0.0
        transcriptions_list = []
        bsz = signals.size(0)
        for k in range(bsz):
            # fetch k-th input from batched sample and wrap as Variable:
            # (each example is truncated to its true length before the forward pass)
            sig_k_scalar = signal_lengths_[k]
            seq_k_scalar = sequence_lengths_[k]
            sig_k_length = Variable(torch.IntTensor([sig_k_scalar]))
            seq_k_length = Variable(torch.IntTensor([seq_k_scalar]))
            signal_k = Variable(signals[k, :, :sig_k_scalar].unsqueeze(0))
            sequence_k = Variable(sequences_[k, :seq_k_scalar].unsqueeze(0))

            # compute transcription output:
            trans_k = network(signal_k).permute(2, 0,
                                                1)  # Permute: 1xDxT => Tx1xD

            # compute normalized CTC loss and accumulate gradient:
            loss = ctc_loss_fn(trans_k, sequence_k.int(), sig_k_length.int(),
                               seq_k_length.int())
            loss.backward()
            total_loss += loss
            transcriptions_list.append(trans_k)

        # combine transcriptions back into a batch and return:
        # (shorter outputs remain zero-padded at the tail)
        max_length = max([t.size(0) for t in transcriptions_list])
        transcriptions = Variable(torch.zeros(max_length, bsz, num_labels))
        for j, tr in enumerate(transcriptions_list):
            transcriptions[0:tr.size(0), j, :] = tr[:, 0, :]
        return total_loss, transcriptions

    #--- asynchronous gradient accumulation mode
    # compute target seqs/losses sequentially over each example, average gradients
    def async_model_eval(sample):
        # unpack inputs, optionally place on CUDA:
        signals_, signal_lengths_, sequences_, sequence_lengths_ = sample
        signals = maybe_gpu(signals_.permute(0, 2, 1),
                            (cuda_avail and cfg['cuda']))  # BxTxD => BxDxT

        # sequential compute over the batch:
        total_loss = 0.0
        transcriptions_list = []
        bsz = signals.size(0)
        for k in range(bsz):
            # fetch k-th input from batched sample and wrap as Variable:
            sig_k_scalar = signal_lengths_[k]
            seq_k_scalar = sequence_lengths_[k]
            sig_k_length = Variable(torch.IntTensor([sig_k_scalar]),
                                    volatile=True)
            seq_k_length = Variable(torch.IntTensor([seq_k_scalar]),
                                    volatile=True)
            signal_k = Variable(signals[k, :, :sig_k_scalar].unsqueeze(0),
                                volatile=True)
            sequence_k = Variable(sequences_[k, :seq_k_scalar].unsqueeze(0),
                                  volatile=True)

            # compute transcription output:
            trans_k = network(signal_k).permute(2, 0,
                                                1)  # Permute: 1xDxT => Tx1xD

            # compute normalized CTC loss and accumulate gradient:
            # (no backward() here -- eval path only accumulates the loss value)
            loss = ctc_loss_fn(trans_k, sequence_k.int(), sig_k_length.int(),
                               seq_k_length.int())
            total_loss += loss
            transcriptions_list.append(trans_k)

        # combine transcriptions back into a batch and return:
        max_length = max([t.size(0) for t in transcriptions_list])
        transcriptions = Variable(torch.zeros(max_length, bsz, num_labels),
                                  volatile=True)
        for j, tr in enumerate(transcriptions_list):
            transcriptions[0:tr.size(0), j, :] = tr[:, 0, :]
        return total_loss, transcriptions

    #--- choose appropriate model loss/eval functions depending on command line argument:
    model_loss = async_model_loss if cfg['async'] else batch_model_loss
    model_eval = async_model_eval if cfg['async'] else batch_model_eval

    ### build optimizer and LR scheduler:
    if (cfg['optim'] == 'adamax'):
        opt = optim.Adamax(network.parameters(), lr=cfg['lr'])
    elif (cfg['optim'] == 'adam'):
        opt = optim.Adam(network.parameters(), lr=cfg['lr'])
    else:
        raise Exception("Optimizer not recognized!")
    sched = ReduceLROnPlateau(opt, mode='min', patience=5)
    print("Constructed {} optimizer.".format(cfg['optim']))

    ### build beam search decoder:
    # label 0 (' ') doubles as the CTC blank symbol.
    beam_labels = [' ', 'A', 'G', 'C', 'T']
    beam_blank_id = 0
    beam_decoder = CTCBeamDecoder(beam_labels,
                                  beam_width=100,
                                  blank_id=beam_blank_id,
                                  num_processes=4)
    print("Constructed CTC beam search decoder.")

    ### build engine, meters, and hooks:
    engine = Engine()
    loss_meter = tnt.meter.MovingAverageValueMeter(windowsize=5)
    print("Constructed engine. Running training loop...")

    #-- hook: reset all meters
    def reset_all_meters():
        loss_meter.reset()

    #-- hook: don't do anything for now when obtaining a data sample
    def on_sample(state):
        pass

    #-- hook: don't do anything on gradient update for now
    def on_update(state):
        pass

    #-- hook: update loggers at each forward pass
    def on_forward(state):
        # NOTE(review): `.data[0]` is the legacy (pre-0.4) scalar access;
        # modern PyTorch would use `.item()`.
        loss_meter.add(state['loss'].data[0])
        if (state['t'] % cfg['print_every'] == 0):
            tqdm.write("Step: {0} | Loss: {1}".format(state['t'],
                                                      state['loss'].data[0]),
                       file=cfg['logfile'])

    #-- hook: reset all meters at the start of the epoch
    def on_start_epoch(state):
        reset_all_meters()
        network.train()  # set to training mode for batch norm
        state['iterator'] = tqdm(state['iterator'])

    #-- hook: perform validation and beam-search-decoding at end of each epoch:
    def on_end_epoch(state):
        network.eval()  # set to validation mode for batch-norm
        # K steps of validation; average the loss:
        val_losses = []
        base_seqs = []
        val_data_iterator = get_iterator('valid')
        for k, val_sample in enumerate(val_data_iterator):
            if k > cfg['num_valid_steps']: break
            val_loss, transcriptions = model_eval(val_sample)
            val_losses.append(val_loss.data[0])
            sequences = val_sample[2]
            # mask out the padding & permute (TxBxD => BxTxD):
            scores = mask_padding(transcriptions.permute(1, 0, 2),
                                  val_sample[1],
                                  fill_logit_idx=0)
            logits = F.softmax(scores, dim=2)
            base_seqs.append((sequences, logits))
        avg_val_loss = np.mean(val_losses)
        # log to both logfile and stdout:
        tqdm.write("EPOCH {0} | Avg. Val Loss: {1}".format(
            state['epoch'], avg_val_loss),
                   file=cfg['logfile'])
        print("EPOCH {0} | Avg. Val Loss: {1}".format(state['epoch'],
                                                      avg_val_loss))

        # send average val. loss to learning rate scheduler:
        sched.step(avg_val_loss)

        # beam search decoding:
        # (wrapped in try-excepts to prevent a thrown error from aborting training)
        _nt_dict_ = {0: ' ', 1: 'A', 2: 'G', 3: 'C', 4: 'T'}

        def convert_to_string(toks, voc, num):
            # Map the first `num` token ids to nucleotide characters;
            # returns '' on any lookup failure.
            try:
                nt = ''.join([voc[t] for t in toks[0:num]])
            except:
                nt = ''
            return nt

        for true_seqs, logits in base_seqs:
            try:
                true_nts = labels2strings(true_seqs, lookup=_nt_dict_)
                amax_nts = labels2strings(argmax_decode(logits),
                                          lookup=_nt_dict_)
                beam_result, beam_scores, beam_times, beam_lengths = beam_decoder.decode(
                    logits.data)
                pred_nts = [
                    convert_to_string(beam_result[k][0], _nt_dict_,
                                      beam_lengths[k][0])
                    for k in range(len(beam_result))
                ]
                for i in range(min(len(true_nts), len(pred_nts))):
                    tqdm.write("True Seq: {0}".format(true_nts[i]),
                               file=cfg['logfile'])
                    tqdm.write("Beam Seq: {0}".format(pred_nts[i]),
                               file=cfg['logfile'])
                    tqdm.write("Amax Seq: {0}".format(amax_nts[i]),
                               file=cfg['logfile'])
                    tqdm.write(
                        ("- " * 10 + "Local Beam Alignment" + " -" * 10),
                        file=cfg['logfile'])
                    tqdm.write(ssw(true_nts[i], pred_nts[i]),
                               file=cfg['logfile'])
                    tqdm.write("= " * 40, file=cfg['logfile'])
            except:
                # NOTE(review): bare except by design -- a bad batch should
                # not abort the training run, only skip the preview.
                tqdm.write("(WARN: Could not parse batch; skipping...)",
                           file=cfg['logfile'])
                continue

        # save model:
        try:
            mdl_dtype = "cuda" if (cuda_avail and cfg['cuda']) else "cpu"
            mdl_path = os.path.join(
                cfg['save_dir'],
                "ctc_encoder.{0}.{1}.pth".format(state['epoch'], mdl_dtype))
            torch.save(network.state_dict(), mdl_path)
            tqdm.write("Saved model.", file=cfg['logfile'])
        except:
            print("Unable to serialize model; Moving on. Traceback:")
            traceback.print_exc()
            tqdm.write("Unable to serialize models. Moving on...",
                       file=cfg['logfile'])

        # reset all meters for next epoch:
        reset_all_meters()

    ### engine setup & training:
    engine.hooks['on_sample'] = on_sample
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_epoch'] = on_end_epoch
    engine.train(model_loss,
                 get_iterator('train'),
                 maxepoch=cfg['max_epochs'],
                 optimizer=opt)
Ejemplo n.º 29
0
    x: torch.utils.data.DataLoader(image_datasets[x],
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=4)
    for x in ['train', 'test']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
# Script basename without directory or extension (presumably used for
# naming outputs/checkpoints -- verify against later code).
file_name = __file__.split('/')[-1].split('.')[0]

#Create model and initialize/freeze weights
model_conv = CRNN(imgH, nc, nclass, nh)
print(model_conv)
model_conv = model_conv.cuda(DEVICE)

#Initialize optimizer and loss function
criterion = CTCLoss()
criterion = criterion.cuda()

# Select optimizer from config flags; RMSprop is the fallback when
# neither `adam` nor `adadelta` is set.
if adam:
    optimizer = optim.Adam(model_conv.parameters(),
                           lr=lr,
                           betas=(beta1, 0.999))
elif adadelta:
    optimizer = optim.Adadelta(model_conv.parameters(), lr=lr)
else:
    optimizer = optim.RMSprop(model_conv.parameters(), lr=lr)

#Train model
model_conv = train_model(model_conv,
                         criterion,
                         optimizer,
Ejemplo n.º 30
0
def main():
    """Train a CTC acoustic model (CNN/RNN) configured from an INI file.

    Reads model/data/training hyperparameters via ConfigParser, builds the
    CTC_Model and data loaders, trains with a manual decay-on-plateau
    learning-rate schedule, plots curves to Visdom, and finally saves the
    best model (its path is written back into the config file).

    NOTE(review): uses Python-2 idioms (`long`, module name `ConfigParser`);
    confirm the intended interpreter version.
    """
    args = parser.parse_args()
    cf = ConfigParser.ConfigParser()
    try:
        cf.read(args.conf)
    except:
        # NOTE(review): bare except; execution continues and the later
        # cf.get calls will fail if the config was not actually read.
        print("conf file not exists")

    # Seed from config when available; otherwise fall back to CUDA's seed.
    try:
        seed = cf.get('Training', 'seed')
        seed = long(seed)
    except:
        seed = torch.cuda.initial_seed()

    torch.manual_seed(seed)
    if USE_CUDA:
        torch.cuda.manual_seed_all(seed)

    logger = init_logger(os.path.join(args.log_dir, 'train_ctc_model.log'))

    #Define Model
    rnn_input_size = cf.getint('Model', 'rnn_input_size')
    rnn_hidden_size = cf.getint('Model', 'rnn_hidden_size')
    rnn_layers = cf.getint('Model', 'rnn_layers')
    rnn_type = RNN[cf.get('Model', 'rnn_type')]
    bidirectional = cf.getboolean('Model', 'bidirectional')
    batch_norm = cf.getboolean('Model', 'batch_norm')
    rnn_param = {
        "rnn_input_size": rnn_input_size,
        "rnn_hidden_size": rnn_hidden_size,
        "rnn_layers": rnn_layers,
        "rnn_type": rnn_type,
        "bidirectional": bidirectional,
        "batch_norm": batch_norm
    }

    num_class = cf.getint('Model', 'num_class')
    drop_out = cf.getfloat('Model', 'drop_out')
    add_cnn = cf.getboolean('Model', 'add_cnn')

    # CNN layer hyperparameters are stored as Python literals in the config.
    # NOTE(review): eval() on config strings executes arbitrary code --
    # acceptable only for trusted, local config files.
    cnn_param = {}
    layers = cf.getint('CNN', 'layers')
    channel = eval(cf.get('CNN', 'channel'))
    kernel_size = eval(cf.get('CNN', 'kernel_size'))
    stride = eval(cf.get('CNN', 'stride'))
    padding = eval(cf.get('CNN', 'padding'))
    pooling = eval(cf.get('CNN', 'pooling'))
    batch_norm = cf.getboolean('CNN', 'batch_norm')
    activation_function = activate_f[cf.get('CNN', 'activation_function')]

    # Assemble per-layer parameter lists: [channel, kernel, stride, padding, pooling].
    cnn_param['batch_norm'] = batch_norm
    cnn_param['activate_function'] = activation_function
    cnn_param["layer"] = []
    for layer in range(layers):
        layer_param = [
            channel[layer], kernel_size[layer], stride[layer], padding[layer]
        ]
        if pooling is not None:
            layer_param.append(pooling[layer])
        else:
            layer_param.append(None)
        cnn_param["layer"].append(layer_param)

    model = CTC_Model(rnn_param=rnn_param,
                      add_cnn=add_cnn,
                      cnn_param=cnn_param,
                      num_class=num_class,
                      drop_out=drop_out)
    #model.apply(xavier_uniform_init)
    # Debug print of only the first (top-level) module, then stop.
    for idx, m in enumerate(model.modules()):
        print(idx, m)
        break

    dataset = cf.get('Data', 'dataset')
    data_dir = cf.get('Data', 'data_dir')
    feature_type = cf.get('Data', 'feature_type')
    out_type = cf.get('Data', 'out_type')
    n_feats = cf.getint('Data', 'n_feats')
    mel = cf.getboolean('Data', 'mel')
    batch_size = cf.getint("Training", 'batch_size')

    #Data Loader
    train_dataset = myDataset(data_dir,
                              data_set='train',
                              feature_type=feature_type,
                              out_type=out_type,
                              n_feats=n_feats,
                              mel=mel)
    dev_dataset = myDataset(data_dir,
                            data_set="dev",
                            feature_type=feature_type,
                            out_type=out_type,
                            n_feats=n_feats,
                            mel=mel)
    if add_cnn:
        train_loader = myCNNDataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=4,
                                       pin_memory=False)
        dev_loader = myCNNDataLoader(dev_dataset,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=4,
                                     pin_memory=False)
    else:
        train_loader = myDataLoader(train_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=4,
                                    pin_memory=False)
        dev_loader = myDataLoader(dev_dataset,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=4,
                                  pin_memory=False)
    #decoder for dev set
    decoder = GreedyDecoder(dev_dataset.int2phone, space_idx=-1, blank_index=0)

    #Training
    init_lr = cf.getfloat('Training', 'init_lr')
    num_epoches = cf.getint('Training', 'num_epoches')
    end_adjust_acc = cf.getfloat('Training', 'end_adjust_acc')
    decay = cf.getfloat("Training", 'lr_decay')
    weight_decay = cf.getfloat("Training", 'weight_decay')

    params = {
        'num_epoches': num_epoches,
        'end_adjust_acc': end_adjust_acc,
        'mel': mel,
        'seed': seed,
        'decay': decay,
        'learning_rate': init_lr,
        'weight_decay': weight_decay,
        'batch_size': batch_size,
        'feature_type': feature_type,
        'n_feats': n_feats,
        'out_type': out_type
    }
    print(params)

    if USE_CUDA:
        model = model.cuda()

    loss_fn = CTCLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=init_lr,
                                 weight_decay=weight_decay)

    #visualization for training
    from visdom import Visdom
    viz = Visdom()
    if add_cnn:
        title = dataset + ' ' + feature_type + str(n_feats) + ' CNN_LSTM_CTC'
    else:
        title = dataset + ' ' + feature_type + str(n_feats) + ' LSTM_CTC'

    opts = [
        dict(title=title + " Loss", ylabel='Loss', xlabel='Epoch'),
        dict(title=title + " Loss on Dev", ylabel='DEV Loss', xlabel='Epoch'),
        dict(title=title + ' CER on DEV', ylabel='DEV CER', xlabel='Epoch')
    ]
    viz_window = [None, None, None]

    # Training-loop state: LR is decayed whenever the dev loss plateaus
    # (adjust_rate_count reaches 10); training stops after 8 decays.
    count = 0
    learning_rate = init_lr
    loss_best = 1000
    loss_best_true = 1000
    adjust_rate_flag = False
    stop_train = False
    adjust_time = 0
    acc_best = 0
    acc_best_true = 0
    start_time = time.time()
    loss_results = []
    dev_loss_results = []
    dev_cer_results = []

    while not stop_train:
        if count >= num_epoches:
            break
        count += 1

        if adjust_rate_flag:
            learning_rate *= decay
            adjust_rate_flag = False
            for param in optimizer.param_groups:
                param['lr'] *= decay

        print("Start training epoch: %d, learning_rate: %.5f" %
              (count, learning_rate))
        logger.info("Start training epoch: %d, learning_rate: %.5f" %
                    (count, learning_rate))

        loss = train(model,
                     train_loader,
                     loss_fn,
                     optimizer,
                     logger,
                     add_cnn=add_cnn,
                     print_every=20)
        loss_results.append(loss)
        acc, dev_loss = dev(model,
                            dev_loader,
                            loss_fn,
                            decoder,
                            logger,
                            add_cnn=add_cnn)
        print("loss on dev set is %.4f" % dev_loss)
        logger.info("loss on dev set is %.4f" % dev_loss)
        dev_loss_results.append(dev_loss)
        dev_cer_results.append(acc)

        #adjust learning rate by dev_loss
        # NOTE(review): if the very first epoch lands in the middle band
        # (within end_adjust_acc of the initial loss_best of 1000), the
        # `adjust_rate_count += 1` below hits an unbound name -- confirm.
        if dev_loss < (loss_best - end_adjust_acc):
            loss_best = dev_loss
            adjust_rate_count = 0
            model_state = copy.deepcopy(model.state_dict())
            op_state = copy.deepcopy(optimizer.state_dict())
        elif (dev_loss < loss_best + end_adjust_acc):
            adjust_rate_count += 1
            if dev_loss < loss_best and dev_loss < loss_best_true:
                loss_best_true = dev_loss
                model_state = copy.deepcopy(model.state_dict())
                op_state = copy.deepcopy(optimizer.state_dict())
        else:
            adjust_rate_count = 10

        # Track the best dev accuracy separately for the final checkpoint.
        if acc > acc_best:
            acc_best = acc
            best_model_state = copy.deepcopy(model.state_dict())
            best_op_state = copy.deepcopy(optimizer.state_dict())
        '''
        #adjust learning rate by dev_acc
        if acc > (acc_best + end_adjust_acc):
            acc_best = acc
            adjust_rate_count = 0
            loss_best = dev_loss
            model_state = copy.deepcopy(model.state_dict())
            op_state = copy.deepcopy(optimizer.state_dict())
        elif (acc > acc_best - end_adjust_acc):
            adjust_rate_count += 1
            if acc > acc_best and acc > acc_best_true:
                acc_best_true = acc
                loss_best = dev_loss
                model_state = copy.deepcopy(model.state_dict())
                op_state = copy.deepcopy(optimizer.state_dict())
        else:
            adjust_rate_count = 0
        #torch.save(model.state_dict(), model_path_reject)
        '''

        print("adjust_rate_count:" + str(adjust_rate_count))
        print('adjust_time:' + str(adjust_time))
        logger.info("adjust_rate_count:" + str(adjust_rate_count))
        logger.info('adjust_time:' + str(adjust_time))

        # Plateau reached: schedule an LR decay and roll back to the last
        # best model/optimizer state.
        if adjust_rate_count == 10:
            adjust_rate_flag = True
            adjust_time += 1
            adjust_rate_count = 0
            if loss_best > loss_best_true:
                loss_best = loss_best_true
            #if acc_best < acc_best_true:
            #    acc_best = acc_best_true
            model.load_state_dict(model_state)
            optimizer.load_state_dict(op_state)

        if adjust_time == 8:
            stop_train = True

        time_used = (time.time() - start_time) / 60
        print("epoch %d done, cv acc is: %.4f, time_used: %.4f minutes" %
              (count, acc, time_used))
        logger.info("epoch %d done, cv acc is: %.4f, time_used: %.4f minutes" %
                    (count, acc, time_used))

        # Refresh the three Visdom plots (loss / dev loss / dev CER).
        x_axis = range(count)
        y_axis = [
            loss_results[0:count], dev_loss_results[0:count],
            dev_cer_results[0:count]
        ]
        for x in range(len(viz_window)):
            if viz_window[x] is None:
                viz_window[x] = viz.line(
                    X=np.array(x_axis),
                    Y=np.array(y_axis[x]),
                    opts=opts[x],
                )
            else:
                viz.line(
                    X=np.array(x_axis),
                    Y=np.array(y_axis[x]),
                    win=viz_window[x],
                    update='replace',
                )

    print("End training, best cv loss is: %.4f, acc is: %.4f" %
          (loss_best, acc_best))
    logger.info("End training, best loss acc is: %.4f, acc is: %.4f" %
                (loss_best, acc_best))
    # Restore the best-accuracy checkpoint and persist it; the model path
    # is written back into the config file for later use.
    model.load_state_dict(best_model_state)
    optimizer.load_state_dict(best_op_state)
    best_path = os.path.join(args.log_dir,
                             'best_model' + '_cv' + str(acc_best) + '.pkl')
    cf.set('Model', 'model_file', best_path)
    cf.write(open(args.conf, 'w'))
    params['epoch'] = count

    torch.save(
        CTC_Model.save_package(model,
                               optimizer=optimizer,
                               epoch=params,
                               loss_results=loss_results,
                               dev_loss_results=dev_loss_results,
                               dev_cer_results=dev_cer_results), best_path)