Example #1
    def build_model(self, load_model=False):
        labeldist = self.labeldist
        ls_weight = self.config['ls_weight']

        self.model = cc_model(
            E2E(input_dim=self.config['input_dim'],
                enc_hidden_dim=self.config['enc_hidden_dim'],
                enc_n_layers=self.config['enc_n_layers'],
                subsample=self.config['subsample'],
                enc_output_dim=self.config['enc_output_dim'],
                dropout_rate=self.config['dropout_rate'],
                dec_hidden_dim=self.config['dec_hidden_dim'],
                att_dim=self.config['att_dim'],
                conv_channels=self.config['conv_channels'],
                conv_kernel_size=self.config['conv_kernel_size'],
                att_odim=self.config['att_odim'],
                output_dim=len(self.vocab),
                embedding_dim=self.config['embedding_dim'],
                ls_weight=ls_weight,
                labeldist=labeldist,
                pad=self.vocab['<PAD>'],
                bos=self.vocab['<BOS>'],
                eos=self.vocab['<EOS>']))
        print(self.model)
        self.model.float()
        self.gen_opt = torch.optim.Adam(
            self.model.parameters(),
            lr=self.config['learning_rate'],
            weight_decay=self.config['weight_decay'])
        if load_model:
            self.load_model(self.config['load_model_path'],
                            self.config['load_optimizer'])
        return
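For orientation, a minimal sketch of how this builder might be driven. Only the config keys are taken from the method above; every value, and the surrounding Trainer name, are illustrative assumptions rather than the project's real settings.

# Hypothetical driver for build_model above; only the config keys come from
# the method, every value and the Trainer name are assumptions.
config = {
    'input_dim': 40, 'enc_hidden_dim': 512, 'enc_n_layers': 3,
    'subsample': [1, 2, 2], 'enc_output_dim': 512, 'dropout_rate': 0.3,
    'dec_hidden_dim': 512, 'att_dim': 300, 'conv_channels': 10,
    'conv_kernel_size': 100, 'att_odim': 320, 'embedding_dim': 256,
    'ls_weight': 0.1, 'learning_rate': 1e-3, 'weight_decay': 1e-6,
    'load_model_path': 'ckpt/best.pth', 'load_optimizer': False,
}
trainer = Trainer(config)             # hypothetical owner of build_model
trainer.build_model(load_model=False)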
Example #2
def main():

    create_unigram_model()
    create_json()

    model = E2E(idim=80, odim=5002, args=config.ModelArgs())

    if config.use_cuda:
        model = model.cuda()
        model = nn.DataParallel(model)

    optimizer = get_std_opt(model, config.train_param['adim'],
                            config.train_param['transformer_warmup_steps'],
                            config.train_param['lr'])

    if config.resume['restart']:
        checkpoint = torch.load(config.resume['model_path'])
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch_start = checkpoint['epoch'] + 1
        losses = checkpoint['Losses']
        print(
            'Resuming from epoch:', epoch_start,
            '| Avg. Loss: {0:.4f} | '
            'Avg Loss_Att: {1:.4f} | '
            'Avg Loss_CTC: {2:.4f}'.format(np.mean(losses[0]),
                                           np.mean(losses[1]),
                                           np.mean(losses[2])))
    else:
        epoch_start = 0

    train(epoch_start, model, optimizer)
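The resume branch above fixes the checkpoint layout: a dict with 'model', 'optimizer', 'epoch' and 'Losses' keys, where 'Losses' unpacks into per-batch histories for the total, attention and CTC losses. A sketch of the matching save call inside train() (the history variable names are assumptions):

# Sketch of the torch.save call implied by the resume logic; loss_hist,
# att_hist and ctc_hist are assumed per-batch loss lists kept by train().
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': epoch,
    'Losses': [loss_hist, att_hist, ctc_hist],
}, config.resume['model_path'])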
Example #3
File: main.py Project: Vanova/ASSERT
def forward_pass(_run, pretrained_model_id, test_batch_size, data_files, model_params, eer_criteria, class_labels):
    """ forward pass dev and eval data to trained model  """

    use_cuda = torch.cuda.is_available()  # fall back to CPU when CUDA is unavailable
    device = torch.device("cuda" if use_cuda else "cpu")
    tmp = torch.tensor([2]).to(device)  # unused; presumably just initializes the device early

    # MODEL_SELECT == 4 is the RNN model
    use_rnn = model_params['MODEL_SELECT'] == 4

    # model was trained with a focal loss objective
    if model_params['FOCAL_GAMMA']:
        print('using focal loss objective')
        focal_obj = FocalLoss(gamma=model_params['FOCAL_GAMMA'])
    else:
        focal_obj = None

    kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}

    # create model
    model = E2E(**model_params).to(device)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('===> Model total parameter: {}'.format(num_params))

    if pretrained_model_id:
        pretrain_pth = 'snapshots/' + str(pretrained_model_id) + '/model_best.pth.tar'
        #pretrain_pth = '../pretrained/pa/senet34_py3'
        if os.path.isfile(pretrain_pth):
            print("===> loading checkpoint '{}'".format(pretrain_pth))
            # load onto CPU regardless of where the checkpoint was saved
            checkpoint = torch.load(pretrain_pth, map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            print("===> loaded checkpoint '{}'"
                  .format(pretrain_pth))
        else:
            print("===> no checkpoint found at '{}'".format(pretrain_pth))
            exit()

    # Data loading code (class analysis for multi-class classification only)
    val_data = SpoofDatsetSystemID(data_files['dev_scp'], data_files['dev_utt2index'], binary_class=False)
    # val_data = SpoofDatsetEval(data_files['dev_scp'])
    # eval_data = SpoofDatsetEval(data_files['eval_scp'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=test_batch_size,
                                             shuffle=False, **kwargs)
    # eval_loader = torch.utils.data.DataLoader(
    #     eval_data, batch_size=test_batch_size, shuffle=False, **kwargs)

    # forward pass for dev
    print("===> forward pass for dev set")
    score_file_pth = os.path.join(data_files['scoring_dir'], str(pretrained_model_id) + '-dev_scores.txt')
    print("===> dev scoring file saved at: '{}'".format(score_file_pth))
    prediction(val_loader, model, device, score_file_pth, data_files['dev_utt2systemID'], use_rnn, focal_obj)
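A note on the checkpoint load in this example: map_location=lambda storage, loc: storage keeps every storage where it is deserialized, i.e. on CPU, so a GPU-trained checkpoint opens on a CPU-only machine. The string form is an equivalent, more common spelling in standard PyTorch:

# Equivalent CPU-loading spellings; the second is the usual shorthand
# for the lambda used in the example above.
checkpoint = torch.load(pretrain_pth, map_location=lambda storage, loc: storage)
checkpoint = torch.load(pretrain_pth, map_location='cpu')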
Example #4
File: main.py Project: Vanova/ASSERT
def post(_run, pretrained, test_batch_size, data_files, model_params, eer_criteria, class_labels):
    """ what are the classes that performed well, and the classes that did not perform well """

    use_cuda = torch.cuda.is_available()  # fall back to CPU when CUDA is unavailable
    device = torch.device("cuda" if use_cuda else "cpu")
    tmp = torch.tensor([2]).to(device)  # unused; presumably just initializes the device early

    # MODEL_SELECT == 4 is the RNN model
    use_rnn = model_params['MODEL_SELECT'] == 4

    kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}

    # create model
    model = E2E(**model_params).to(device)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('===> Model total parameter: {}'.format(num_params))

    if pretrained:
        if os.path.isfile(pretrained):
            print("===> loading checkpoint '{}'".format(pretrained))
            checkpoint = torch.load(pretrained, map_location=lambda storage, loc: storage)  # load for cpu
            if eer_criteria:
                best_eer = checkpoint['best_eer']
            else:
                best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            print("===> loaded checkpoint '{}' (epoch {})"
                  .format(pretrained, checkpoint['epoch']))
        else:
            print("===> no checkpoint found at '{}'".format(pretrained))

    # Data loading code (class analysis for multi-class classification only)
    train_data = SpoofDatsetSystemID(data_files['train_scp'], data_files['train_utt2index'], binary_class=False)
    val_data = SpoofDatsetSystemID(data_files['dev_scp'], data_files['dev_utt2index'], binary_class=False)
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=test_batch_size, shuffle=False, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        val_data, batch_size=test_batch_size, shuffle=False, **kwargs)

    # class analysis for train and dev 
    print("===> class analysis for train set")
    class_analysis(train_loader, model, device, use_rnn, class_labels, use_focal=model_params['FOCAL_GAMMA'])
    print("===> class analysis for dev set")
    class_analysis(val_loader, model, device, use_rnn, class_labels, use_focal=model_params['FOCAL_GAMMA'])
Example #5
def main():

    args = config.ModelArgs()
    args.report_cer = True
    args.report_wer = True

    char_list = get_char_list()

    model = E2E(idim=80, odim=5002, args=args, char_list=char_list)

    if config.use_cuda:
        model = model.cuda()
        model = nn.DataParallel(model)

    checkpoint = torch.load(config.test_model)
    model.load_state_dict(checkpoint['model'])

    test(model)
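Note that Example #5 loads checkpoint['model'] straight into the model, so on a CPU-only machine a checkpoint saved from an nn.DataParallel wrapper would fail with 'module.'-prefixed keys; Example #7 below is the same entry point extended to handle exactly that case.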
Example #6
    def build_model(self, load_model=False):
        self.model = cc(E2E(input_dim=self.config['input_dim'],
            enc_hidden_dim=self.config['enc_hidden_dim'],
            enc_n_layers=self.config['enc_n_layers'],
            subsample=self.config['subsample'],
            dropout_rate=self.config['dropout_rate'],
            dec_hidden_dim=self.config['dec_hidden_dim'],
            att_dim=self.config['att_dim'],
            conv_channels=self.config['conv_channels'],
            conv_kernel_size=self.config['conv_kernel_size'],
            att_odim=self.config['att_odim'],
            output_dim=len(self.vocab),
            embedding_dim=self.config['embedding_dim'],
            ls_weight=self.config['ls_weight'],
            labeldist=self.labeldist,
            pad=self.vocab['<PAD>'],
            bos=self.vocab['<BOS>'],
            eos=self.vocab['<EOS>']
            ))
        print(self.model)
        self.gen_opt = torch.optim.Adam(self.model.parameters(), lr=self.config['learning_rate'], 
                weight_decay=self.config['weight_decay'], amsgrad=True)
        if load_model:
            self.load_model(self.config['load_model_path'], self.config['load_optimizer'])
        print(self.gen_opt)

        self.judge = cc(LM(
                output_dim=len(self.vocab),
                embedding_dim=self.config['dis_embedding_dim'],
                hidden_dim=self.config['dis_hidden_dim'],
                dropout_rate=self.config['dis_dropout_rate'],
                n_layers=self.config['dis_layers'],
                bos=self.vocab['<BOS>'],
                eos=self.vocab['<EOS>'],
                pad=self.vocab['<PAD>'],
                ls_weight=self.config['ls_weight'],
                labeldist=self.unlab_labeldist
            ))
        print(self.judge)
        self.dis_opt = torch.optim.Adam(
                filter(lambda p: p.requires_grad, self.judge.parameters()), 
                lr=self.config['d_learning_rate']) 
        return
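The discriminator optimizer above is built over filter(lambda p: p.requires_grad, ...), which only matters once some judge parameters are frozen. A minimal sketch of that situation, with a hypothetical frozen embedding submodule; the names judge, judge.embedding, and the config dict are assumptions:

# Sketch only: freeze a (hypothetical) pre-trained embedding inside the
# judge, then hand Adam just the parameters that still require gradients.
for p in judge.embedding.parameters():
    p.requires_grad = False
dis_opt = torch.optim.Adam(
    (p for p in judge.parameters() if p.requires_grad),
    lr=config['d_learning_rate'])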
Example #7
def main():

    args = config.ModelArgs()
    args.report_cer = True
    args.report_wer = True

    char_list = get_char_list()

    model = E2E(idim=80, odim=5002, args=args, char_list=char_list)

    if config.use_cuda:
        model = model.cuda()
        model = nn.DataParallel(model)
        checkpoint = torch.load(config.test_model)
    else:
        # the checkpoint was saved from an nn.DataParallel model, so every
        # state-dict key carries a 'module.' prefix (7 chars); strip it so
        # the keys match the bare model
        moduleCheckpoint = torch.load(config.test_model, map_location='cpu')
        checkpoint = {'model': {}}
        for key in moduleCheckpoint['model']:
            checkpoint['model'][key[7:]] = moduleCheckpoint['model'][key]

    model.load_state_dict(checkpoint['model'])

    test(model)
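The CPU branch above strips the 'module.' prefix that nn.DataParallel adds to every state-dict key. A reusable sketch of the same idea (not project code) that tolerates both wrapped and unwrapped checkpoints:

# Sketch: normalize a state dict saved from an nn.DataParallel model.
def strip_data_parallel(state_dict):
    prefix = 'module.'
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}

# usage: model.load_state_dict(strip_data_parallel(raw_ckpt['model']))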
Example #8
File: main.py Project: Vanova/ASSERT
def work(_run, pretrained, batch_size, test_batch_size, epochs, start_epoch, log_interval, n_warmup_steps, data_files,
         model_params, eer_criteria, leave_one_out):
    global best_acc1
    global best_eer

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    tmp = torch.tensor([2]).to(device)  # unused; presumably just initializes the device early

    # model is trained for binary classification (flag for the dataloader)
    binary_class = model_params['NUM_SPOOF_CLASS'] == 2

    # MODEL_SELECT == 4 is the RNN model
    use_rnn = model_params['MODEL_SELECT'] == 4

    if model_params['FOCAL_GAMMA']:
        print('training with focal loss')
        focal_obj = FocalLoss(gamma=model_params['FOCAL_GAMMA'])
    else:
        focal_obj = None

    kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}

    # create model
    model = E2E(**model_params).to(device)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('===> Model total parameter: {}'.format(num_params))

    # Wrap model for multi-GPUs, if necessary
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        print('multi-gpu')
        model = nn.DataParallel(model).cuda()

    # define the optimizer (the loss objective, focal_obj, was set up above)
    optimizer = ScheduledOptim(
        torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            betas=(0.9, 0.98), eps=1e-09, weight_decay=1e-4, amsgrad=True),
        n_warmup_steps)

    # optionally resume from a checkpoint
    if pretrained:
        if os.path.isfile(pretrained):
            print("===> loading checkpoint '{}'".format(pretrained))
            checkpoint = torch.load(pretrained)
            start_epoch = checkpoint['epoch']
            if eer_criteria and 'best_eer' in checkpoint:
                best_eer = checkpoint['best_eer']
            else:
                best_acc1 = checkpoint['best_acc1']
                print("===> Best accuracy: %f" % best_acc1)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("===> loaded checkpoint '{}' (epoch {})"
                  .format(pretrained, checkpoint['epoch']))
        else:
            print("===> no checkpoint found at '{}'".format(pretrained))

    cudnn.benchmark = True  # let cudnn auto-tune conv algorithms for fixed-size inputs

    # Data loading code
    train_data = SpoofDatsetSystemID(data_files['train_scp'], data_files['train_utt2index'],
                                     rnd_nslides=True,
                                     binary_class=binary_class,
                                     leave_one_out=leave_one_out)
    val_data = SpoofDatsetSystemID(data_files['dev_scp'], data_files['dev_utt2index'],
                                   binary_class=binary_class,
                                   leave_one_out=leave_one_out)
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        val_data, batch_size=test_batch_size, shuffle=False, **kwargs)
    if leave_one_out:
        eval_data = SpoofDatsetSystemID(data_files['dev_scp'], data_files['dev_utt2index'], binary_class, False)
        eval_loader = torch.utils.data.DataLoader(
            eval_data, batch_size=test_batch_size, shuffle=False, **kwargs)
    else:
        eval_loader = val_loader

    best_epoch = 0
    early_stopping, max_patience = 0, 5  # for early stopping
    for epoch in range(start_epoch, start_epoch + epochs):

        # train for one epoch
        train(train_loader, model, optimizer, epoch, device, log_interval, use_rnn, focal_obj)

        # evaluate on validation set
        acc1 = validate(val_loader, data_files['dev_utt2systemID'], model, device, log_interval,
                        use_rnn, eer_criteria, focal_obj)

        # remember best acc@1/eer and save checkpoint
        if eer_criteria:
            is_best = acc1 < best_eer
            best_eer = min(acc1, best_eer)
        else:
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

        # adjust learning rate + early stopping 
        if is_best:
            early_stopping = 0
            best_epoch = epoch + 1
        else:
            early_stopping += 1
            if epoch - best_epoch > 2:
                optimizer.increase_delta()
                best_epoch = epoch + 1
        if early_stopping == max_patience:
            break

        # save model
        if not is_best:
            continue
        state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        if eer_criteria:
            state['best_eer'] = best_eer
        else:
            state['best_acc1'] = best_acc1
        save_checkpoint(state, is_best,
                        "snapshots/" + str(_run._id),
                        str(epoch) + ('_%.3f' % acc1) + ".pth.tar")

    # load best model 
    best_model_pth = os.path.join('snapshots', str(_run._id), 'model_best.pth.tar')
    score_file_pth = os.path.join(data_files['scoring_dir'], str(_run._id) + '-scores.txt')
    print("===> loading best model for scoring: '{}'".format(best_model_pth))
    checkpoint = torch.load(best_model_pth)
    model.load_state_dict(checkpoint['state_dict'])

    # compute EER 
    print("===> scoring file saved at: '{}'".format(score_file_pth))
    prediction(eval_loader, model, device, score_file_pth, data_files['dev_utt2systemID'], use_rnn, focal_obj)
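save_checkpoint itself is not shown in this file. A common implementation consistent with how it is called above and with the model_best.pth.tar read-back (an assumption, not the project's actual code):

import os
import shutil
import torch

# Assumed shape of save_checkpoint: write the state under snapshot_dir and,
# when this epoch is the best so far, also copy it to model_best.pth.tar.
def save_checkpoint(state, is_best, snapshot_dir, filename):
    os.makedirs(snapshot_dir, exist_ok=True)
    path = os.path.join(snapshot_dir, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(snapshot_dir, 'model_best.pth.tar'))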