示例#1
0
 def drop_check_point(self, checkpoint_name, args):
     """Best-effort save of the current training state to ``checkpoint_name``.

     The first dict handed to ``utils.save_checkpoint`` holds the epoch
     counter (``args.start_epoch``) plus the model and optimizer state dicts.
     The second holds ``self.track_list`` and ``vars(args)``. When the save
     succeeds, ``self.checkpoint_name`` is updated to ``checkpoint_name``.
     Any exception is printed and swallowed, leaving ``self.checkpoint_name``
     untouched.

     :param checkpoint_name: destination forwarded to ``utils.save_checkpoint``
     :param args: argparse-style namespace (needs ``start_epoch``)
     """
     try:
         saver = utils.save_checkpoint
         model_state = {
             'epoch': args.start_epoch,
             'state_dict': self.ner_model.state_dict(),
             'optimizer': self.optimizer.state_dict()
         }
         run_info = {
             'track_list': self.track_list,
             'args': vars(args)
         }
         saver(model_state, run_info, checkpoint_name)
         # Only record the new path once the save call has returned.
         self.checkpoint_name = checkpoint_name
     except Exception as err:
         print(err)
示例#2
0
 def save_model(self, file):
     """Save the model checkpoint to ``self.args.checkpoint + 'cwlm_lstm_crf'``.

     The first dict passed to ``utils.save_checkpoint`` contains the epoch
     counter, the model and optimizer state dicts, ``f_map``, ``l_map``, the
     character map (stored under ``'c_map'``) and ``in_doc_words``. The
     second contains ``self.track_list`` and ``vars(self.args)``.

     :param file: accepted but not used; the destination is always derived
         from ``self.args.checkpoint``. NOTE(review): confirm whether callers
         expect ``file`` to be honoured.
     """
     cfg = self.args
     payload = {
         'epoch': cfg.start_epoch,
         'state_dict': self.ner_model.state_dict(),
         'optimizer': self.optimizer.state_dict(),
         'f_map': self.f_map,
         'l_map': self.l_map,
         'c_map': self.char_map,
         'in_doc_words': self.in_doc_words
     }
     metadata = {
         'track_list': self.track_list,
         'args': vars(cfg)
     }
     destination = cfg.checkpoint + 'cwlm_lstm_crf'
     utils.save_checkpoint(payload, metadata, destination)
示例#3
0
        start_time = time.time()

        log.info(f'Epoch {epoch+1} training')
        train_loss = train(model, device, training_loader, optimizer,
                           criterion, clip)
        log.info(f'\nEpoch {epoch + 1} validation')
        valid_loss, bleu_score = eval(model, device, valid_loader, criterion)

        train_loss_list.append(train_loss)
        valid_loss_list.append(valid_loss)

        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        #     if valid_loss < best_valid_loss:
        #         best_valid_loss = valid_loss
        save_checkpoint(model_path / stage / f'decoder/model0epoch{epoch}',
                        epoch, model, optimizer, valid_loss_list,
                        train_loss_list)

        log.info(
            f'\nEpoch: {epoch + 1:02} completed | Time: {epoch_mins}m {epoch_secs}s'
        )
        log.info(
            f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}'
        )
        log.info(
            f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f} | Val. BLEU {bleu_score}\n\n'
        )
示例#4
0
def train(args):
    """Train a duration model and save a checkpoint after every epoch.

    Builds train/dev loaders from ``DurationDataset`` and constructs the model
    selected by ``args.model_type`` ("Transformer", "LSTM" or "DNN"). The
    model is optionally warm-started from ``args.initmodel``. With
    ``args.resume`` it is instead restored from the newest
    ``epoch_<N>.pth.tar`` in ``args.model_save_dir``. Training then runs
    epochs ``start_epoch + 1 .. args.max_epochs`` inclusive. After each
    epoch the model is validated and ``epoch_<N>.pth.tar`` is written.

    :param args: parsed command-line namespace holding all hyper-parameters
    :raises ValueError: on an unsupported model type, optimizer or loss
    :raises FileNotFoundError: if ``args.resume`` is set and no
        ``epoch_<N>.pth.tar`` exists in ``args.model_save_dir``
    :raises KeyError: if a loaded checkpoint holds a parameter the model lacks
    """
    import re  # local import keeps this block self-contained

    if args.gpu > 0 and torch.cuda.is_available():
        cvd = use_single_gpu()
        print(f"GPU {cvd} is used")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_set = DurationDataset(duration_file=args.train,
                                max_len=args.max_len)

    dev_set = DurationDataset(duration_file=args.val,
                              max_len=args.max_len)

    collate_fn = DurationCollator(args.max_len,
                                  model_type=args.model_type,
                                  context=args.context)
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.batchsize,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               collate_fn=collate_fn,
                                               pin_memory=True)
    dev_loader = torch.utils.data.DataLoader(dataset=dev_set,
                                             batch_size=args.batchsize,
                                             shuffle=True,
                                             num_workers=args.num_workers,
                                             collate_fn=collate_fn,
                                             pin_memory=True)
    # prepare model
    if args.model_type == "Transformer":
        model = TransformerDuration(dim_feedforward=args.dim_feedforward,
                                    phone_size=args.phone_size,
                                    embed_size=args.embedding_size,
                                    d_model=args.hidden_size,
                                    dropout=args.dropout,
                                    d_output=1,
                                    nhead=args.nhead,
                                    num_block=args.num_block,
                                    local_gaussian=args.local_gaussian,
                                    pos_enc=True)
    elif args.model_type == "LSTM":
        model = LSTMDuration(phone_size=args.phone_size,
                             embed_size=args.embedding_size,
                             d_model=args.hidden_size,
                             dropout=args.dropout,
                             d_output=1,
                             num_block=args.num_block)
    elif args.model_type == "DNN":
        model = DNNDuration(input_size=args.context * 2 + 1,
                            d_model=args.hidden_size,
                            d_output=1)
    else:
        raise ValueError('Not Support Model Type %s' % args.model_type)
    print(model)
    model = model.to(device)

    # ``start_epoch`` is the last epoch already completed. The training loop
    # runs start_epoch + 1 .. max_epochs, so a fresh run must start from 0;
    # otherwise epoch 1 is silently skipped.
    model_load_dir = ""
    start_epoch = 0
    if args.initmodel != '':
        model_load_dir = args.initmodel
    if args.resume:
        # Resume from the newest "epoch_<N>.pth.tar". Any other entry in the
        # directory (e.g. the tensorboard "log" dir) is ignored.
        saved_epochs = []
        for name in os.listdir(args.model_save_dir):
            match = re.fullmatch(r"epoch_(\d+)\.pth\.tar", name)
            if match:
                saved_epochs.append(int(match.group(1)))
        if not saved_epochs:
            raise FileNotFoundError(
                "No epoch_<N>.pth.tar checkpoint found in {}".format(
                    args.model_save_dir))
        start_epoch = max(saved_epochs)
        model_load_dir = "{}/epoch_{}.pth.tar".format(args.model_save_dir,
                                                      start_epoch)

    # load weights for pre-trained model
    # NOTE(review): only 'state_dict' is restored; the optimizer state saved
    # for noam checkpoints is not reloaded on resume.
    if model_load_dir != '':
        model_load = torch.load(model_load_dir, map_location=device)
        loading_dict = model_load['state_dict']
        model_dict = model.state_dict()
        state_dict_new = {}
        para_list = []
        for k, v in loading_dict.items():
            if k not in model_dict:
                raise KeyError("Parameter {} from {} does not exist in the "
                               "model".format(k, model_load_dir))
            # Keep only tensors whose shape matches; mismatches are reported.
            if model_dict[k].size() == v.size():
                state_dict_new[k] = v
            else:
                para_list.append(k)
        print("Total {} parameters, Loaded {} parameters".format(
            len(loading_dict), len(state_dict_new)))
        if len(para_list) > 0:
            print("Not loading {} because of different sizes".format(
                ", ".join(para_list)))
        model_dict.update(state_dict_new)
        model.load_state_dict(model_dict)
        print("Loaded checkpoint {}".format(model_load_dir))
        print("")

    # setup optimizer
    if args.optimizer == 'noam':
        optimizer = ScheduledOptim(torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
            betas=(0.9, 0.98),
            eps=1e-09),
            args.hidden_size,
            args.noam_warmup_steps,
            args.noam_scale)
    elif args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     betas=(0.9, 0.98),
                                     eps=1e-09)
    else:
        raise ValueError('Not Support Optimizer')

    # Setup tensorboard logger
    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        logger = SummaryWriter("{}/log".format(args.model_save_dir))
    else:
        logger = None

    if args.loss == "l1":
        loss = MaskedLoss("l1")
    elif args.loss == "mse":
        loss = MaskedLoss("mse")
    else:
        raise ValueError("Not Support Loss Type")

    # Training. The logger is closed even if an epoch raises.
    try:
        for epoch in range(start_epoch + 1, args.max_epochs + 1):
            start_t_train = time.time()
            train_info = train_one_epoch(train_loader, model, device,
                                         optimizer, loss, args)
            end_t_train = time.time()
            if args.optimizer == "noam":
                # ScheduledOptim wraps the Adam optimizer in ``_optimizer``.
                print(
                    'Train epoch: {:04d}, lr: {:.6f}, '
                    'loss: {:.4f}, time: {:.2f}s'.format(
                        epoch, optimizer._optimizer.param_groups[0]['lr'],
                        train_info['loss'], end_t_train - start_t_train))
            else:
                print(
                    'Train epoch: {:04d}, '
                    'loss: {:.4f}, time: {:.2f}s'.format(
                        epoch,
                        train_info['loss'], end_t_train - start_t_train))

            start_t_dev = time.time()
            dev_info = validate(dev_loader, model, device, loss, args)
            end_t_dev = time.time()

            print("Valid loss: {:.4f}, time: {:.2f}s".format(
                dev_info['loss'], end_t_dev - start_t_dev))
            print("")
            sys.stdout.flush()

            os.makedirs(args.model_save_dir, exist_ok=True)
            if args.optimizer == "noam":
                save_checkpoint({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer._optimizer.state_dict(),
                }, "{}/epoch_{}.pth.tar".format(args.model_save_dir, epoch))
            else:
                save_checkpoint({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                }, "{}/epoch_{}.pth.tar".format(args.model_save_dir, epoch))

            # record training and validation information
            if args.use_tfboard:
                record_info(train_info, dev_info, epoch, logger)
    finally:
        if logger is not None:
            logger.close()
示例#5
0
                print(
                    '(loss: %.4f, epoch: %d, dev F1 = %.4f, dev acc = %.4f, F1 on test = %.4f, acc on test= %.4f), saving...' %
                    (epoch_loss,
                     args.start_epoch,
                     dev_f1,
                     dev_acc,
                     test_f1,
                     test_acc))

                try:
                    utils.save_checkpoint({
                        'epoch': args.start_epoch,
                        'state_dict': ner_model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'f_map': f_map,
                        'l_map': l_map,
                        'c_map': c_map,
                        'in_doc_words': in_doc_words
                    }, {'track_list': track_list,
                        'args': vars(args)
                        }, args.checkpoint + 'cwlm_lstm_crf')
                except Exception as inst:
                    print(inst)

            else:
                patience_count += 1
                print('(loss: %.4f, epoch: %d, dev F1 = %.4f, dev acc = %.4f)' %
                      (epoch_loss,
                       args.start_epoch,
                       dev_f1,
                       dev_acc))
示例#6
0
            })

            print(
                '(loss: %.4f, epoch: %d, dev F1 = %.4f, dev pre = %.4f, dev rec = %.4f, F1 on test = %.4f, pre on test = %.4f, rec on test = %.4f), saving...'
                % (epoch_loss, args.start_epoch, dev_f1, dev_pre, dev_rec,
                   test_f1, test_pre, test_rec))

            try:
                utils.save_checkpoint(
                    {
                        'epoch': args.start_epoch,
                        'state_dict': ner_model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'f_map': f_map,
                        'l_map': l_map,
                        'a_map': a_map,
                        'ner_map': ner_map,
                        'char_map': char_map,
                        'singleton': singleton
                    }, {
                        'track_list': track_list,
                        'args': vars(args)
                    },
                    args.checkpoint + '/dev=' + str(round(best_f1 * 100, 2)))
            except Exception as inst:
                print(inst)

        else:
            patience_count += 1
            print('(loss: %.4f, epoch: %d, dev F1 = %.4f)' %
                  (epoch_loss, args.start_epoch, dev_f1))
            track_list.append({'loss': epoch_loss, 'dev_f1': dev_f1})
示例#7
0
def train(args):
    """Train an SVS model (``SVSDataset`` + ``GLU_Transformer``).

    Builds train/dev loaders and checks that ``args.feat_dim`` equals the
    feature dimension (``shape[1]`` of item 3) of the first dev sample. The
    model is optionally warm-started from ``args.initmodel``. Training runs
    epochs ``1 .. args.max_epochs`` with a Noam-scheduled Adam optimizer.
    After every epoch the model is validated and
    ``epoch_<N>.pth.tar`` is saved under ``args.model_save_dir``.

    :param args: parsed command-line namespace holding all hyper-parameters
    :raises ValueError: on a feature-dimension mismatch or an unsupported
        model type, optimizer or loss
    :raises KeyError: if ``args.initmodel`` holds a parameter the model lacks
    """
    if args.gpu > 0 and torch.cuda.is_available():
        cvd = use_single_gpu()
        print(f"GPU {cvd} is used")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_set = SVSDataset(align_root_path=args.train_align,
                           pitch_beat_root_path=args.train_pitch,
                           wav_root_path=args.train_wav,
                           char_max_len=args.char_max_len,
                           max_len=args.num_frames,
                           sr=args.sampling_rate,
                           preemphasis=args.preemphasis,
                           frame_shift=args.frame_shift,
                           frame_length=args.frame_length,
                           n_mels=args.n_mels,
                           power=args.power,
                           max_db=args.max_db,
                           ref_db=args.ref_db)

    dev_set = SVSDataset(align_root_path=args.val_align,
                         pitch_beat_root_path=args.val_pitch,
                         wav_root_path=args.val_wav,
                         char_max_len=args.char_max_len,
                         max_len=args.num_frames,
                         sr=args.sampling_rate,
                         preemphasis=args.preemphasis,
                         frame_shift=args.frame_shift,
                         frame_length=args.frame_length,
                         n_mels=args.n_mels,
                         power=args.power,
                         max_db=args.max_db,
                         ref_db=args.ref_db)
    collate_fn_svs = SVSCollator(args.num_frames, args.char_max_len)
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.batchsize,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               collate_fn=collate_fn_svs,
                                               pin_memory=True)
    dev_loader = torch.utils.data.DataLoader(dataset=dev_set,
                                             batch_size=args.batchsize,
                                             shuffle=True,
                                             num_workers=args.num_workers,
                                             collate_fn=collate_fn_svs,
                                             pin_memory=True)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    data_feat_dim = dev_set[0][3].shape[1]
    if args.feat_dim != data_feat_dim:
        raise ValueError(
            "args.feat_dim ({}) does not match the feature dimension of the "
            "dev data ({})".format(args.feat_dim, data_feat_dim))

    # prepare model
    if args.model_type == "GLU_Transformer":
        model = GLU_Transformer(phone_size=args.phone_size,
                                embed_size=args.embedding_size,
                                hidden_size=args.hidden_size,
                                glu_num_layers=args.glu_num_layers,
                                dropout=args.dropout,
                                output_dim=args.feat_dim,
                                dec_nhead=args.dec_nhead,
                                dec_num_block=args.dec_num_block,
                                device=device)
    else:
        raise ValueError('Not Support Model Type %s' % args.model_type)
    print(model)
    model = model.to(device)

    # load weights for pre-trained model
    if args.initmodel != '':
        pretrain = torch.load(args.initmodel, map_location=device)
        pretrain_dict = pretrain['state_dict']
        model_dict = model.state_dict()
        state_dict_new = {}
        para_list = []
        for k, v in pretrain_dict.items():
            if k not in model_dict:
                raise KeyError("Parameter {} from {} does not exist in the "
                               "model".format(k, args.initmodel))
            # Keep only tensors whose shape matches; mismatches are reported.
            if model_dict[k].size() == v.size():
                state_dict_new[k] = v
            else:
                para_list.append(k)
        print("Total {} parameters, Loaded {} parameters".format(
            len(pretrain_dict), len(state_dict_new)))
        if len(para_list) > 0:
            print("Not loading {} because of different sizes".format(
                ", ".join(para_list)))
        model_dict.update(state_dict_new)
        model.load_state_dict(model_dict)
        print("Loaded checkpoint {}".format(args.initmodel))
        print("")

    # setup optimizer
    if args.optimizer == 'noam':
        optimizer = ScheduledOptim(
            torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             betas=(0.9, 0.98),
                             eps=1e-09), args.hidden_size,
            args.noam_warmup_steps, args.noam_scale)
    else:
        raise ValueError('Not Support Optimizer')

    # Setup tensorboard logger
    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        logger = SummaryWriter("{}/log".format(args.model_save_dir))
    else:
        logger = None

    if args.loss == "l1":
        loss = MaskedLoss("l1")
    elif args.loss == "mse":
        loss = MaskedLoss("mse")
    else:
        raise ValueError("Not Support Loss Type")

    # Training. The logger is closed even if an epoch raises.
    try:
        for epoch in range(1, 1 + args.max_epochs):
            start_t_train = time.time()
            train_info = train_one_epoch(train_loader, model, device,
                                         optimizer, loss, args)
            end_t_train = time.time()

            # ScheduledOptim wraps the Adam optimizer in ``_optimizer``.
            print('Train epoch: {:04d}, lr: {:.6f}, '
                  'loss: {:.4f}, time: {:.2f}s'.format(
                      epoch, optimizer._optimizer.param_groups[0]['lr'],
                      train_info['loss'], end_t_train - start_t_train))

            start_t_dev = time.time()
            dev_info = validate(dev_loader, model, device, loss, args)
            end_t_dev = time.time()

            print("Epoch: {:04d}, Valid loss: {:.4f}, time: {:.2f}s".format(
                epoch, dev_info['loss'], end_t_dev - start_t_dev))
            print("")
            sys.stdout.flush()

            os.makedirs(args.model_save_dir, exist_ok=True)

            save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer._optimizer.state_dict(),
                }, "{}/epoch_{}.pth.tar".format(args.model_save_dir, epoch))

            # record training and validation information
            if args.use_tfboard:
                record_info(train_info, dev_info, epoch, logger)
    finally:
        if logger is not None:
            logger.close()
示例#8
0
def train_one_epoch(model,
                    optimizer,
                    scheduler,
                    data_loader,
                    device,
                    epoch,
                    output_dir,
                    tensorboard=False,
                    print_freq=10):
    """
    Run one training epoch, then save a checkpoint and optional TensorBoard logs.

    The scheduler is not stepped here; only its state is saved in the checkpoint.

    :param model: (nn.Module): model that, in train mode, returns a dict of losses
    :param optimizer: optimizer instance (``zero_grad``/``step``/``param_groups``)
    :param scheduler: learning-rate scheduler (only ``state_dict()`` is used)
    :param data_loader: iterable yielding ``(images, targets)`` batches
    :param device: device the images/targets are moved to
    :param epoch: int, number of the current epoch
    :param output_dir: directory for ``faster_rcnn_<epoch>.pth`` and log files
    :param tensorboard: if true, write this epoch's losses to ``output_dir``
    :param print_freq: int, how many iterations between statistic prints
    :raises ValueError: if ``data_loader`` yields no batches
    """
    # set model to train mode
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    # Stay None until at least one batch has been processed.
    loss_value = None
    loss_dict_reduced = {}

    # through all data
    for images, targets in metric_logger.log_every(data_loader, print_freq,
                                                   header):
        # move images and targets to the target device
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # forward through model; the model returns a dictionary of losses
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    if loss_value is None:
        raise ValueError(
            "data_loader yielded no batches for epoch {}".format(epoch))

    # create checkpoint after each epoch; 'losses' is the last batch's loss
    save_name = os.path.join(output_dir, 'faster_rcnn_{}.pth'.format(epoch))
    save_checkpoint(
        {
            'start_epoch': epoch + 1,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'losses': loss_value
        }, save_name)
    print('save model: {}'.format(save_name))

    # save metric after each epoch. The writer is opened only now, so it
    # cannot leak if training raises. ``global_step=epoch`` keeps epochs from
    # being written to the same step.
    if tensorboard:
        logger = SummaryWriter(output_dir)
        try:
            logger.add_scalars(
                'train/losses',
                {k: v.item() for k, v in loss_dict_reduced.items()},
                global_step=epoch)
            logger.add_scalar('train/loss_value', loss_value,
                              global_step=epoch)
        finally:
            logger.close()
    print('metric save {}'.format(output_dir))
示例#9
0
                       dev_rec, test_f1, test_pre, test_rec))

                if args.output_annotation:  #NEW
                    print('annotating')
                    with open('output' + str(file_no) + '.txt', 'w') as fout:
                        predictor.output_batch(ner_model, test_word[file_no],
                                               fout, file_no)

                try:
                    utils.save_checkpoint(
                        {
                            'epoch': args.start_epoch,
                            'state_dict': ner_model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'f_map': f_map,
                            'l_map': l_map,
                            'c_map': char_map,
                            'in_doc_words': in_doc_words
                        }, {
                            'track_list': track_list,
                            'args': vars(args)
                        }, args.checkpoint + 'cwlm_lstm_crf')
                except Exception as inst:
                    print(inst)

            else:
                patience_count += 1
                print(
                    '(loss: %.4f, epoch: %d, dataset: %d, dev F1 = %.4f, dev pre = %.4f, dev rec = %.4f)'
                    % (epoch_loss, args.start_epoch, file_no, dev_f1, dev_pre,
                       dev_rec))
示例#10
0
                print(
                    '(loss: %.4f, epoch: %d, dev F1 = %.4f, dev acc = %.4f, F1 on test = %.4f, acc on test= %.4f), saving...' %
                    (epoch_loss,
                     args.start_epoch,
                     dev_f1,
                     dev_acc,
                     test_f1,
                     test_acc))

                try:
                    utils.save_checkpoint({
                        'epoch': args.start_epoch,
                        'state_dict': ner_model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'f_map': f_map,
                        'l_map': l_map,
                    }, {'track_list': track_list,
                        'args': vars(args)
                        }, args.checkpoint + '_lstm')
                except Exception as inst:
                    print(inst)

            else:
                patience_count += 1
                print('(loss: %.4f, epoch: %d, dev F1 = %.4f, dev acc = %.4f)' %
                      (epoch_loss,
                       args.start_epoch,
                       dev_f1,
                       dev_acc))
                track_list.append({'loss': epoch_loss, 'dev_f1': dev_f1, 'dev_acc': dev_acc})
示例#11
0
                {'loss': epoch_loss, 'dev_f1': dev_f1, 'dev_acc': dev_acc, 'test_f1': test_f1,
                 'test_acc': test_acc})

            print '(loss: %.4f, epoch: %d, dev F1 = %.4f, dev acc = %.4f, F1 on test = %.4f, acc on test= %.4f), saving...' % (epoch_loss,
                 cur_epoch,
                 dev_f1,
                 dev_acc,
                 test_f1,
                 test_acc)

            try:
                utils.save_checkpoint({
                    'epoch': cur_epoch,
                    'state_dict': ner_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'w_map': w_map,
                    'l_map': l_map,
                    'c_map': c_map,
                }, {'track_list': track_list,
                    'args': vars(args)
                    }, args.checkpoint + 'lstm_ti')
            except Exception as inst:
                print(inst)

        else:
            patience_count += 1
            print '(loss: %.4f, epoch: %d, dev F1 = %.4f, dev acc = %.4f)' % (epoch_loss,
                   cur_epoch,
                   dev_f1,
                   dev_acc)
            track_list.append({'loss': epoch_loss, 'dev_f1': dev_f1, 'dev_acc': dev_acc})
         'test_pre': test_pre,
         'test_rec': test_rec,
         'test_acc': test_acc
     })
     print(
         '(loss: %.4f, epoch: %d, dev F1 = %.4f, dev pre = %.4f, dev rec = %.4f, dev acc = %.4f, F1 on test = %.4f, pre on test= %.4f, rec on test= %.4f, acc on test= %.4f), saving...'
         % (epoch_loss, args.start_epoch, dev_f1, dev_pre, dev_rec,
            dev_acc, test_f1, test_pre, test_rec, test_acc))
     try:
         utils.save_checkpoint(
             {
                 'epoch': args.start_epoch,
                 'state_dict': ner_model.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'f_map': f_map,
                 'lexicon_f_map': lexicon_f_map,
                 'bichar_f_map': bichar_f_map,
                 'l_map': l_map,
                 'bichar': args.bichar,
             }, {
                 'track_list': track_list,
                 'args': vars(args)
             }, args.checkpoint + 'lattice_word_seg')
     except Exception as inst:
         print(inst)
 else:
     patience_count += 1
     print(
         '(loss: %.4f, epoch: %d, dev F1 = %.4f, dev pre = %.4f, dev rec = %.4f, dev acc = %.4f)'
         % (epoch_loss, args.start_epoch, dev_f1, dev_pre, dev_rec,
            dev_acc))
     track_list.append({
示例#13
0
def main(ap: argparse.ArgumentParser):
    """
    Main runner to execute training/validation of model.

    Registers the command-line options on ``ap`` and parses them. If
    ``--mode`` contains "train", it trains and saves a model. Otherwise, if
    it contains "predict", it runs prediction on the test track.

    :param ap: Argument parser holding all command line input
    :raises ValueError: if the train and validation tracks are both given but
        differ, or if ``--mode`` contains neither "train" nor "predict"
    """
    ap.add_argument("--mode",
                    type=str,
                    default="train",
                    help="Mode to run from command line")
    ap.add_argument("--lr",
                    type=float,
                    default=2e-5,
                    help="Learning rate for optimizer")
    ap.add_argument("--eps",
                    type=float,
                    default=1e-8,
                    help="Epsilon for optimizer")
    ap.add_argument("--epochs",
                    type=int,
                    default=5,
                    help="Number of epochs to train for")
    ap.add_argument("--batchsize",
                    type=int,
                    default=5,
                    help="Batch size for data iteration")
    ap.add_argument("--hiddensize",
                    type=int,
                    default=256,
                    help="Size for hidden layer")
    ap.add_argument("--traindata",
                    type=str,
                    default="Computer Science, BS-General",
                    help="Track for train data")
    ap.add_argument("--validdata",
                    type=str,
                    default="Computer Science, BS-General",
                    help="Track for validation data")
    ap.add_argument("--testdata",
                    type=str,
                    default="Computer Science, BS-General",
                    help="Track for test data")
    ap.add_argument("--device",
                    type=str,
                    default="cpu",
                    help="Device to train on")
    args = ap.parse_args()

    engine = loadEngine()

    if 'train' in args.mode:

        # if one track is empty, fill it in with the other;
        # if both are given and different, do not continue
        train_track = args.traindata
        valid_track = args.validdata
        if train_track != valid_track:
            if not train_track:
                train_track = valid_track
            elif not valid_track:
                valid_track = train_track
            else:
                raise ValueError('Invalid track argument.')
        # NOTE: data should be formatted in (inputs, expected) format
        # TODO: get data and feed it in

        # load in data based on track
        train_dataset = TrackDataset(engine, train_track)
        valid_dataset = TrackDataset(engine, valid_track)

        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=args.batchsize,
                                                   shuffle=True)

        valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                                   batch_size=args.batchsize,
                                                   shuffle=False)

        # Use the reconciled track: args.traindata may be empty here.
        model = setup_model(engine, train_track, args.hiddensize)
        loss, total_time = train(args, model, train_loader, valid_loader)

        save_checkpoint('saved_models/' + train_track + '-model', model,
                        args.epochs, loss, total_time)

    elif 'predict' in args.mode:
        # This predict option is only to be used for inference
        # Do not use this in the actual application
        # NOTE: data should be formatted in (inputs) format
        test_track = args.testdata
        test_dataset = TrackDataset(engine, test_track)
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                  batch_size=1,
                                                  shuffle=False)

        # `model` must be created here: nothing on this path defines it.
        # NOTE(review): no trained weights are loaded in this function.
        # Confirm that predict() restores them (e.g. from the file written
        # by save_checkpoint in train mode).
        model = setup_model(engine, test_track, args.hiddensize)
        predict(model, args, test_loader)

    else:
        raise ValueError('Unsupported mode: {}'.format(args.mode))
                '(loss: %.4f, epoch: %d, WS: dev F1 = %.4f, dev pre = %.4f, dev rec = %.4f, F1 on test = %.4f, pre on test = %.4f, rec on test = %.4f), saving...'
                % (epoch_loss, args.start_epoch, ws_dev_f1, ws_dev_pre,
                   ws_dev_rec, ws_test_f1, ws_test_pre, ws_test_rec))

            print(
                '(loss: %.4f, epoch: %d, POS: dev F1 = %.4f, dev pre = %.4f, dev rec = %.4f, F1 on test = %.4f, pre on test = %.4f, rec on test = %.4f), saving...'
                % (epoch_loss, args.start_epoch, pos_dev_f1, pos_dev_pre,
                   pos_dev_rec, pos_test_f1, pos_test_pre, pos_test_rec))

            try:
                utils.save_checkpoint(
                    {
                        'epoch': args.start_epoch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'alphabet': alphabet,
                        'static_alphabet': static_alphabet,
                    }, {
                        'track_list': track_list,
                        'args': vars(args)
                    }, args.checkpoint + 'ws_pos')
            except Exception as inst:
                print(inst)

        else:
            patience_count += 1
            print(
                '(loss: %.4f, epoch: %d, ws dev F1 = %.4f, pos dev F1 = %.4f)'
                % (epoch_loss, args.start_epoch, ws_dev_f1, pos_dev_f1))
            track_list.append({
                'loss': epoch_loss,
                     test_f1,
                     test_pre,
                     test_rec))

                if args.output_annotation: #NEW
                    print('annotating')
                    with open(output_directory + 'output'+str(file_no)+'.txt', 'w') as fout:
                        predictor_list[file_no].output_batch(ner_model, test_word[file_no], fout, file_no)

                try:
                    utils.save_checkpoint({
                        'epoch': args.start_epoch,
                        'state_dict': ner_model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'f_map': f_map,
                        'l_map': label_maps[file_no],
                        'c_map': char_map,
                        'in_doc_words': in_doc_words
                    }, {'track_list': track_list,
                        'args': vars(args)
                        }, args.checkpoint + 'cwlm_lstm_crf' + str(file_no))
                except Exception as inst:
                    print(inst)

            else:
                patience_count += 1
                print('(loss: %.4f, epoch: %d, dataset: %d, dev F1 = %.4f, dev pre = %.4f, dev rec = %.4f)' %
                      (epoch_loss,
                       args.start_epoch,
                       file_no,
                       dev_f1,
     
     ### training
     train(epoch, opt)
     
     ### validate
     if epoch % opt.val_every_epoch == 0:
         print('\n------------ VALIDATION START ------------\n')
         start = time.time()
         bleu4, lang_stats = validate(opt)
         end = time.time()
         print('\nepoch {}, time = {:.3f}, BLEU-4(nltk) = {:.4f}'.format(epoch, end - start, bleu4))
         for m,s in lang_stats.items():
             print('\t%s: %.3f'%(m, s))
         
         current_score = lang_stats['CIDEr']
         
             
         # Check if there was an improvement
         is_best = current_score > val_best_score
         val_best_score = max(current_score, val_best_score)
         if not is_best:
             epochs_since_improvement += 1
             print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement))
         else:
             epochs_since_improvement = 0
 
         # Save checkpoint
         utils.save_checkpoint(opt.save_dir, epoch, epochs_since_improvement, current_score, model, optimizer,
                               opt, is_best)
         print('\n------------ VALIDATION END ------------\n')
     
示例#17
0
def _build_svs_model(args, device):
    """Instantiate the SVS model selected by ``args.model_type``.

    Only the ``args`` attributes needed by the chosen model are read, so
    options belonging to other model families may be absent.

    Raises:
        ValueError: if ``args.model_type`` is not a supported model name.
    """

    def transformer_kwargs(with_stats):
        # Keyword arguments shared by every Transformer-family model; the
        # ``*_norm`` variants additionally take the statistics files.
        kwargs = dict(phone_size=args.phone_size,
                      embed_size=args.embedding_size,
                      hidden_size=args.hidden_size,
                      glu_num_layers=args.glu_num_layers,
                      dropout=args.dropout,
                      output_dim=args.feat_dim,
                      dec_nhead=args.dec_nhead,
                      dec_num_block=args.dec_num_block,
                      n_mels=args.n_mels,
                      double_mel_loss=args.double_mel_loss,
                      local_gaussian=args.local_gaussian,
                      device=device)
        if with_stats:
            kwargs.update(stats_file=args.stats_file,
                          stats_mel_file=args.stats_mel_file)
        return kwargs

    model_type = args.model_type
    if model_type == "GLU_Transformer":
        return GLU_TransformerSVS(**transformer_kwargs(False))
    if model_type == "LSTM":
        return LSTMSVS(phone_size=args.phone_size,
                       embed_size=args.embedding_size,
                       d_model=args.hidden_size,
                       num_layers=args.num_rnn_layers,
                       dropout=args.dropout,
                       d_output=args.feat_dim,
                       n_mels=args.n_mels,
                       device=device,
                       use_asr_post=args.use_asr_post)
    if model_type == "PureTransformer":
        return TransformerSVS(**transformer_kwargs(False))
    if model_type == "PureTransformer_noGLU_norm":
        return Transformer_noGLUSVS_norm(**transformer_kwargs(True))
    if model_type == "PureTransformer_norm":
        return TransformerSVS_norm(**transformer_kwargs(True))
    if model_type == "GLU_Transformer_norm":
        return GLU_TransformerSVS_norm(**transformer_kwargs(True))
    raise ValueError('Not Support Model Type %s' % model_type)


def _latest_epoch_in(filenames):
    """Return the largest ``N`` among names shaped ``epoch_N.pth.tar``, or -1.

    Names that do not match that exact pattern are ignored instead of
    crashing the integer parse, and an empty listing yields -1.
    """
    prefix, suffix = "epoch_", ".pth.tar"
    epochs = []
    for name in filenames:
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        number = name[len(prefix):-len(suffix)]
        if number.isdecimal():
            epochs.append(int(number))
    return max(epochs, default=-1)


def _load_pretrain_encoder_weights(model, path, device):
    """Copy matching parameters from a checkpoint's ``'model'`` entry into ``model``.

    Only parameters whose (prefix-stripped) name exists in ``model`` with an
    identical size are copied; everything else keeps its current value.
    """
    pretrain = torch.load(path, map_location=device)
    pretrain_dict = pretrain['model']
    model_dict = model.state_dict()
    state_dict_new = {}
    loaded = 0
    for k, v in pretrain_dict.items():
        # Drop the first 7 characters of every key; presumably a "module."
        # prefix left by a DataParallel-wrapped model. TODO confirm.
        k_new = k[7:]
        if k_new in model_dict and model_dict[k_new].size() == v.size():
            loaded += 1
            state_dict_new[k_new] = v
    model_dict.update(state_dict_new)
    model.load_state_dict(model_dict)
    print(f"Load {loaded} layers total. Load pretrain encoder success !")


def _load_init_model_weights(model, path, device):
    """Load the checkpoint's ``'state_dict'`` into ``model``, skipping size mismatches.

    Raises:
        KeyError: if the checkpoint holds a parameter name ``model`` lacks.
    """
    model_load = torch.load(path, map_location=device)
    loading_dict = model_load['state_dict']
    model_dict = model.state_dict()
    state_dict_new = {}
    para_list = []
    for k, v in loading_dict.items():
        # Explicit check instead of `assert`, which is stripped under -O.
        if k not in model_dict:
            raise KeyError(
                "Parameter {} in checkpoint {} does not exist in the model".format(
                    k, path))
        if model_dict[k].size() == v.size():
            state_dict_new[k] = v
        else:
            para_list.append(k)
    print("Total {} parameters, Loaded {} parameters".format(
        len(loading_dict), len(state_dict_new)))
    if para_list:
        print("Not loading {} because of different sizes".format(
            ", ".join(para_list)))
    model_dict.update(state_dict_new)
    model.load_state_dict(model_dict)
    # Report the file actually loaded (the init model OR a resume checkpoint).
    print("Loaded checkpoint {}".format(path))
    print("")


def train(args):
    """Train an SVS model, validating and checkpointing after every epoch.

    Builds train/dev ``SVSDataset`` loaders, the model chosen by
    ``args.model_type``, optionally loads pretrained encoder weights
    (``args.pretrain_encoder``) and initial weights (``args.initmodel`` or,
    with ``args.resume``, the newest ``epoch_N.pth.tar`` in
    ``args.model_save_dir``), then trains epochs ``start_epoch + 1`` through
    ``args.max_epochs``. Each epoch is saved as
    ``{args.model_save_dir}/epoch_{epoch}.pth.tar``.

    If ``args.collect_stats`` is set, only ``collect_stats`` runs and the
    process exits.

    Raises:
        ValueError: on an unsupported model type, optimizer or loss, or when
            ``args.feat_dim`` disagrees with the dataset features.
    """
    if args.gpu > 0 and torch.cuda.is_available():
        cvd = use_single_gpu()
        print(f"GPU {cvd} is used")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Everything except the three root paths is shared by train and dev sets.
    dataset_kwargs = dict(char_max_len=args.char_max_len,
                          max_len=args.num_frames,
                          sr=args.sampling_rate,
                          preemphasis=args.preemphasis,
                          nfft=args.nfft,
                          frame_shift=args.frame_shift,
                          frame_length=args.frame_length,
                          n_mels=args.n_mels,
                          power=args.power,
                          max_db=args.max_db,
                          ref_db=args.ref_db,
                          sing_quality=args.sing_quality,
                          standard=args.standard)
    train_set = SVSDataset(align_root_path=args.train_align,
                           pitch_beat_root_path=args.train_pitch,
                           wav_root_path=args.train_wav,
                           **dataset_kwargs)
    dev_set = SVSDataset(align_root_path=args.val_align,
                         pitch_beat_root_path=args.val_pitch,
                         wav_root_path=args.val_wav,
                         **dataset_kwargs)
    collate_fn_svs = SVSCollator(args.num_frames, args.char_max_len,
                                 args.use_asr_post, args.phone_size,
                                 args.n_mels)
    loader_kwargs = dict(batch_size=args.batchsize,
                         num_workers=args.num_workers,
                         collate_fn=collate_fn_svs,
                         pin_memory=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               shuffle=True,
                                               **loader_kwargs)
    dev_loader = torch.utils.data.DataLoader(dataset=dev_set,
                                             shuffle=False,
                                             **loader_kwargs)

    # Item index 3 is compared against feat_dim on axis 1; presumably a
    # (frames, feat_dim) spectrogram. TODO confirm against SVSDataset.
    data_feat_dim = dev_set[0][3].shape[1]
    if args.feat_dim != data_feat_dim:
        raise ValueError("args.feat_dim ({}) does not match dataset feature "
                         "dim ({})".format(args.feat_dim, data_feat_dim))
    if args.collect_stats:
        collect_stats(train_loader, args)
        print(f"collect_stats finished !")
        sys.exit()

    # prepare model
    model = _build_svs_model(args, device)
    print(model)
    model = model.to(device)

    model_load_dir = ""
    pretrain_encoder_dir = ""
    # ``start_epoch`` is the last completed epoch; training runs from
    # start_epoch + 1. NOTE(review): with this default a fresh run starts at
    # epoch 2 -- confirm whether 0 was intended (changing it shifts checkpoint
    # numbering).
    start_epoch = 1
    if args.pretrain_encoder != '':
        pretrain_encoder_dir = args.pretrain_encoder
    if args.initmodel != '':
        model_load_dir = args.initmodel
    if args.resume and os.path.exists(args.model_save_dir):
        latest_epoch = _latest_epoch_in(os.listdir(args.model_save_dir))
        # Only override the defaults when there is a checkpoint to resume
        # from; otherwise behave like a fresh run (keeping --initmodel).
        if latest_epoch >= 0:
            start_epoch = latest_epoch
            model_load_dir = "{}/epoch_{}.pth.tar".format(
                args.model_save_dir, latest_epoch)

    # load encoder parm from Transformer-TTS
    if pretrain_encoder_dir != '':
        _load_pretrain_encoder_weights(model, pretrain_encoder_dir, device)

    # load weights for pre-trained model
    if model_load_dir != '':
        _load_init_model_weights(model, model_load_dir, device)

    # setup optimizer
    if args.optimizer == 'noam':
        optimizer = ScheduledOptim(
            torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             betas=(0.9, 0.98),
                             eps=1e-09), args.hidden_size,
            args.noam_warmup_steps, args.noam_scale)
    elif args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     betas=(0.9, 0.98),
                                     eps=1e-09)
    else:
        raise ValueError('Not Support Optimizer')

    # Setup tensorboard logger
    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        logger = SummaryWriter("{}/log".format(args.model_save_dir))
    else:
        logger = None

    try:
        if args.loss in ("l1", "mse"):
            loss = MaskedLoss(args.loss, mask_free=args.mask_free)
        else:
            raise ValueError("Not Support Loss Type")

        if args.perceptual_loss > 0:
            win_length = int(args.sampling_rate * args.frame_length)
            psd_dict, bark_num = cal_psd2bark_dict(fs=args.sampling_rate,
                                                   win_len=win_length)
            sf = cal_spread_function(bark_num)
            loss_perceptual_entropy = PerceptualEntropy(bark_num, sf,
                                                        args.sampling_rate,
                                                        win_length, psd_dict)
        else:
            loss_perceptual_entropy = None

        # Training
        for epoch in range(start_epoch + 1, 1 + args.max_epochs):
            start_t_train = time.time()
            train_info = train_one_epoch(train_loader, model, device,
                                         optimizer, loss,
                                         loss_perceptual_entropy, epoch, args)
            end_t_train = time.time()

            out_log = 'Train epoch: {:04d}, '.format(epoch)
            if args.optimizer == "noam":
                out_log += 'lr: {:.6f}, '.format(
                    optimizer._optimizer.param_groups[0]['lr'])
            out_log += 'loss: {:.4f}, spec_loss: {:.4f}, '.format(
                train_info['loss'], train_info['spec_loss'])
            if args.n_mels > 0:
                out_log += 'mel_loss: {:.4f}, '.format(train_info['mel_loss'])
            if args.perceptual_loss > 0:
                out_log += 'pe_loss: {:.4f}, '.format(train_info['pe_loss'])
            print("{} time: {:.2f}s".format(out_log,
                                            end_t_train - start_t_train))

            start_t_dev = time.time()
            dev_info = validate(dev_loader, model, device, loss,
                                loss_perceptual_entropy, epoch, args)
            end_t_dev = time.time()

            dev_log = 'Dev epoch: {:04d}, loss: {:.4f}, spec_loss: {:.4f}, '.format(
                epoch, dev_info['loss'], dev_info['spec_loss'])
            if args.n_mels > 0:
                dev_log += 'mel_loss: {:.4f}, '.format(dev_info['mel_loss'])
            if args.perceptual_loss > 0:
                dev_log += 'pe_loss: {:.4f}, '.format(dev_info['pe_loss'])
            # Dev time only: previously measured from the start of training.
            print("{} time: {:.2f}s".format(dev_log, end_t_dev - start_t_dev))

            print("")
            sys.stdout.flush()

            os.makedirs(args.model_save_dir, exist_ok=True)

            checkpoint = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
            }
            if args.optimizer == "noam":
                # ScheduledOptim wraps the real torch optimizer.
                checkpoint['optimizer'] = optimizer._optimizer.state_dict()
            save_checkpoint(
                checkpoint,
                "{}/epoch_{}.pth.tar".format(args.model_save_dir, epoch))

            # record training and validation information
            if args.use_tfboard:
                record_info(train_info, dev_info, epoch, logger)
    finally:
        # Close the writer even if training raises.
        if logger is not None:
            logger.close()
示例#18
0
            test_f1_crf, test_pre_crf, test_rec_crf, test_acc_crf, test_f1_scrf, test_pre_scrf, test_rec_scrf, test_acc_scrf, test_f1_jnt, test_pre_jnt, test_rec_jnt, test_acc_jnt = \
                        evaluator.calc_score(model, test_dataset_loader)

            best_test_f1_crf = test_f1_crf
            best_test_f1_scrf = test_f1_scrf

            best_dev_f1_jnt = dev_f1_jnt
            best_test_f1_jnt = test_f1_jnt

            try:
                utils.save_checkpoint({
                        'epoch': args.start_epoch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'f_map': f_map,
                        'c_map': c_map,
                        'SCRF_l_map': SCRF_l_map,
                        'CRF_l_map': CRF_l_map,
                        'in_doc_words': in_doc_words,
                        'ALLOW_SPANLEN': args.allowspan
                    }, {'args': vars(args)
                        }, args.checkpoint + str(seed))
            except Exception as inst:
                    print(inst)

        else:
            early_stop_epochs += 1

        print('best_test_f1_crf is: %.4f' % (best_test_f1_crf))
        print('best_test_f1_scrf is: %.4f' % (best_test_f1_scrf))
        print('best_test_f1_jnt is: %.4f' % (best_test_f1_jnt))
示例#19
0
        log_value('dialect', stop['dialect'], epoch)
        log_value('age', stop['age'], epoch)
        log_value('height', stop['height'], epoch)

        log_value('stop_criterion_id', criterion['id'], epoch)
        log_value('stop_criterion_gender', criterion['gender'], epoch)
        log_value('stop_criterion_dialect', criterion['dialect'], epoch)
        log_value('stop_criterion_age', criterion['age'], epoch)
        log_value('stop_criterion_height', criterion['height'], epoch)

    # Checkpointing
    save_checkpoint(
        {
            'args': args,
            'epoch': epoch,
            'netE_state_dict': netE.state_dict(),
            'netG_state_dict': netG.state_dict(),
            'netD_state_dict': netD.state_dict()
        }, os.path.join(args.outf, 'checkpoints'),
        'checkpoint_epoch_{:d}.pth.tar'.format(epoch))

    # Delete old checkpoint to save space
    new_record_fn = os.path.join(args.outf, 'checkpoints',
                                 'checkpoint_epoch_{:d}.pth.tar'.format(epoch))
    if os.path.exists(old_record_fn) and os.path.exists(new_record_fn):
        os.remove(old_record_fn)
    old_record_fn = new_record_fn

# Write log
# test_err_meter.save('test_err', os.path.join(args.outf, 'records'), 'test_err.tsv')