def test_quadrature_acquisition_gradient_values(aq):
    """Compare the acquisition's analytic gradients with finite differences."""
    # Acquisition value at x (first output column).
    def value(x):
        return aq.evaluate(x)[:, 0]

    # Transposed gradient returned alongside the value.
    def gradient(x):
        return aq.evaluate_with_gradients(x)[1].T

    check_grad(value,
               gradient,
               in_shape=(3, 2),
               bounds=aq.model.X.shape[1] * [(-3, 3)])
Example #2
0
def test_warped_model_gradient_values(model, data):
    """Check prediction-mean and prediction-variance gradients numerically."""
    shape = (5, data.D)

    # index 0 -> mean, index 1 -> variance; same two checks, same order
    # as running them back to back.
    for idx in (0, 1):
        func = lambda z, i=idx: model.predict(z)[i][:, 0]
        dfunc = lambda z, i=idx: model.get_prediction_gradients(z)[i].T
        check_grad(func, dfunc, in_shape=shape, bounds=data.dat_bounds)
Example #3
0
def train(model, criterion, optimizer, epoch, train_losses):
    """Run one training epoch with truncated back-propagation through time.

    Args:
        model: network called as model([input, spkr], target_feat, start).
        criterion: loss taking (output, target[0], target[1]).
        optimizer: optimizer over model.parameters().
        epoch: current epoch number (used for progress display only).
        train_losses: list; the epoch-average loss is appended to it.

    Returns:
        The average per-batch loss over the epoch.

    NOTE(review): relies on module-level `train_loader`, `args`, `wrap`,
    `TBPTTIter`, `check_grad`, and (when args.visualize) `vis`.
    """
    total = 0  # running sum of per-batch average losses
    model.train()
    train_enum = tqdm(train_loader, desc='Train epoch %d' % epoch)

    for full_txt, full_feat, spkr in train_enum:
        # Split the full sequence into truncated-BPTT chunks.
        batch_iter = TBPTTIter(full_txt, full_feat, spkr, args.seq_len)
        batch_total = 0

        for txt, feat, spkr, start in batch_iter:
            input = wrap(txt)
            target = wrap(feat)
            spkr = wrap(spkr)

            # Zero gradients only when a new sequence starts; gradients
            # accumulate across TBPTT chunks of the same sequence.
            if start:
                optimizer.zero_grad()

            # Forward
            output, _ = model([input, spkr], target[0], start)
            loss = criterion(output, target[0], target[1])

            # Backward; skip the update entirely on non-finite or
            # oversized gradients.
            loss.backward()
            if check_grad(model.parameters(), args.clip_grad,
                          args.ignore_grad):
                logging.info('Not a finite gradient or too big, ignoring.')
                optimizer.zero_grad()
                continue
            optimizer.step()

            # loss.item() replaces the deprecated loss.data[0], which
            # raises on 0-dim tensors in PyTorch >= 0.4.
            batch_total += loss.item()

        batch_total = batch_total / len(batch_iter)
        total += batch_total
        train_enum.set_description('Train (loss %.2f) epoch %d' %
                                   (batch_total, epoch))

    avg = total / len(train_loader)
    train_losses.append(avg)
    if args.visualize:
        vis.line(Y=np.asarray(train_losses),
                 X=torch.arange(1, 1 + len(train_losses)),
                 opts=dict(title="Train"),
                 win='Train loss ' + args.expName)

    logging.info('====> Train set loss: {:.4f}'.format(avg))

    return avg
Example #4
0
def train(net, criterion, optimizer, train_losses, train_params,
          train_loss_log_file, dataloader, cuda_available):
    """Train `net` for one pass over `dataloader`.

    Args:
        net: network mapping input batches to outputs.
        criterion: loss taking (outputs, target, lengths).
        optimizer: optimizer over net.parameters().
        train_losses: list; running-average losses are appended to it.
        train_params: dict with 'clip_grad', 'ignore_grad', 'print_every'.
        train_loss_log_file: open file object; losses are logged to it.
        dataloader: yields (inputs, targets, lengths) batches.
        cuda_available: move batches to GPU when True.
    """
    total_loss = 0
    num_trained = 0

    net.train()

    for i_batch, (b_x, b_y, lengths) in enumerate(dataloader):

        optimizer.zero_grad()
        input, target = b_x, b_y

        batch_size = b_x.size(0)

        if cuda_available:
            # `async` became a reserved word in Python 3.7; the kwarg was
            # renamed to `non_blocking` in PyTorch 0.4.
            input = b_x.cuda(non_blocking=True)
            target = b_y.cuda(non_blocking=True)
            lengths = lengths.cuda(non_blocking=True)

        target = Variable(target)
        input = Variable(input)
        lengths = Variable(lengths)

        outputs = net(input)
        loss = criterion(outputs, target, lengths)
        loss.backward()

        # Skip the update on non-finite or oversized gradients.
        if check_grad(net.parameters(), train_params['clip_grad'],
                      train_params['ignore_grad']):
            optimizer.zero_grad()
            continue

        # BUG FIX: the original computed gradients but never applied them.
        optimizer.step()

        # loss.item() replaces deprecated loss.data[0] (PyTorch >= 0.4);
        # normalise by batch size as before.
        total_loss += (loss.item() / batch_size)
        num_trained += 1

        if num_trained % train_params['print_every'] == 0:
            avg_loss = total_loss / train_params['print_every']
            print(num_trained, " ) loss is ", avg_loss)

            train_losses.append(avg_loss)
            train_loss_log_file.writelines(
                '====> Train set loss: {:.4f}'.format(avg_loss))
            train_loss_log_file.flush()
            total_loss = 0
Example #5
0
def test_qkernel_gradient_values(kernel_embedding):
    """Compare every analytic kernel gradient with finite differences."""
    emukit_qkernel, x1, x2, N, M, D, dat_bounds = kernel_embedding

    np.random.seed(42)
    x1 = sample_uniform(in_shape=(N, D), bounds=dat_bounds)
    x2 = sample_uniform(in_shape=(M, D), bounds=dat_bounds)

    # Each case: (input shape, kernel function, analytic gradient).
    cases = [
        # dKdiag_dx
        (x1.shape,
         lambda x: np.diag(emukit_qkernel.K(x, x)),
         lambda x: emukit_qkernel.dKdiag_dx(x)),
        # dK_dx1
        (x1.shape,
         lambda x: emukit_qkernel.K(x, x2),
         lambda x: emukit_qkernel.dK_dx1(x, x2)),
        # dK_dx2
        (x2.shape,
         lambda x: emukit_qkernel.K(x1, x),
         lambda x: emukit_qkernel.dK_dx2(x1, x)),
        # dqK_dx
        (x2.shape,
         lambda x: emukit_qkernel.qK(x),
         lambda x: emukit_qkernel.dqK_dx(x)),
        # dKq_dx
        (x1.shape,
         lambda x: emukit_qkernel.Kq(x).T,
         lambda x: emukit_qkernel.dKq_dx(x).T),
    ]

    # Same five checks in the same order as the original.
    for in_shape, func, dfunc in cases:
        check_grad(func, dfunc, in_shape, dat_bounds)
Example #6
0
def train(net, criterion, optimizer, train_losses, train_params,
          train_loss_log_file, dataloader, cuda_available):
    """Train the spectrogram model for one pass over `dataloader`.

    Args:
        net: network mapping a spectrogram batch to outputs.
        criterion: loss taking (outputs, flattened target spectrogram).
        optimizer: optimizer over net.parameters().
        train_losses: list; running-average losses are appended to it.
        train_params: dict with 'clip_grad', 'ignore_grad', 'print_every'.
        train_loss_log_file: open file object; losses are logged to it.
        dataloader: yields spectrogram batches.
        cuda_available: move batches to GPU when True.
    """
    total_loss = 0
    num_trained = 0

    net.train()

    for i_batch, sample_batch in enumerate(dataloader):

        optimizer.zero_grad()
        music_spec = sample_batch
        if cuda_available:
            # `async` is reserved in Python >= 3.7; the kwarg was renamed
            # to `non_blocking` in PyTorch 0.4.
            music_spec = music_spec.cuda(non_blocking=True)

        # Target is the flattened input: the net reconstructs its input.
        target_spec = Variable(music_spec.view(-1))
        music_spec = Variable(music_spec)

        outputs = net(music_spec)
        loss = criterion(outputs, target_spec)
        loss.backward()

        # Skip the update on non-finite or oversized gradients.
        if check_grad(net.parameters(), train_params['clip_grad'],
                      train_params['ignore_grad']):
            print('Not a finite gradient or too big, ignoring.')
            optimizer.zero_grad()
            continue

        optimizer.step()
        # loss.item() replaces deprecated loss.data[0] (PyTorch >= 0.4).
        total_loss += loss.item()
        num_trained += 1

        if num_trained % train_params['print_every'] == 0:
            avg_loss = total_loss / train_params['print_every']
            print(num_trained, " ) loss is ", avg_loss)

            train_losses.append(avg_loss)
            train_loss_log_file.writelines(
                '====> Train set loss: {:.4f}'.format(avg_loss))
            train_loss_log_file.flush()
            total_loss = 0
Example #7
0
def train():
    """Train the WaveNet autoencoder end to end.

    Restores a checkpoint when requested, trains for `num_epochs`, logs
    losses to `loss_log.log`, and saves checkpoints every
    `check_point_every` epochs, evicting the oldest checkpoint once the
    `max_check_points` cap is reached.
    """
    cuda_available = torch.cuda.is_available()
    train_params, model_params, dataset_params = get_arguments()
    net = WavenetAutoencoder(**model_params)
    epoch_trained = 0
    if train_params['restore_model']:
        net = load_model(net, train_params['restore_dir'],
                         train_params['restore_model'])
        if net is None:
            print("Initialize network and train from scratch.")
            net = WavenetAutoencoder(**model_params)
        else:
            # NOTE(review): the epoch count is not recovered from the
            # checkpoint name, so saved-model numbering restarts at 0.
            epoch_trained = 0
    dataloader = audio_data_loader(**dataset_params)

    if not cuda_available:
        warnings.warn(
            "Cuda is not avalable, can not train model using multi-gpu.")
    if cuda_available:
        # Remove train_params "device_ids" for single GPU.
        if train_params["device_ids"]:
            batch_size = dataset_params["batch_size"]
            num_gpu = len(train_params["device_ids"])
            assert batch_size % num_gpu == 0
            net = nn.DataParallel(net, device_ids=train_params['device_ids'])
        torch.backends.cudnn.benchmark = True
        net = net.cuda()

    optimizer = get_optimizer(net, train_params['optimizer'],
                              train_params['learning_rate'],
                              train_params['momentum'])

    loss_func = nn.CrossEntropyLoss()
    if cuda_available:
        loss_func = loss_func.cuda()
    if not os.path.exists(train_params['log_dir']):
        os.makedirs(train_params['log_dir'])
    if not os.path.exists(train_params['restore_dir']):
        os.makedirs(train_params['restore_dir'])
    loss_log_file = open(train_params['log_dir'] + 'loss_log.log', 'a')
    store_log_file = open(train_params['log_dir'] + 'store_log.log', 'a')

    # Resume the global step counter from the last line of the loss log.
    # (The redundant f.close() after the `with` block was removed.)
    total_loss = 0
    with open(train_params['log_dir'] + 'loss_log.log', 'r') as f:
        lines = f.readlines()
        if len(lines) > 0:
            num_trained = int(lines[-1].split(' ')[2])
        else:
            num_trained = 0

    # Record training start time.
    time = str(datetime.now())
    line = 'Training Started at' + str(time) + ' !!! \n'
    loss_log_file.writelines(line)
    loss_log_file.flush()

    for epoch in range(train_params['num_epochs']):
        net.train()
        for i_batch, sample_batch in enumerate(dataloader):

            optimizer.zero_grad()
            music_piece = sample_batch['audio_piece']
            target_piece = sample_batch['audio_target']
            if cuda_available:
                # `async` is reserved since Python 3.7; the kwarg was
                # renamed to `non_blocking` in PyTorch 0.4.
                music_piece = music_piece.cuda(non_blocking=True)
                target_piece = target_piece.cuda(non_blocking=True)
            print("music_piece size = ", music_piece.size())
            music_piece = Variable(music_piece)
            target_piece = Variable(target_piece.view(-1))
            outputs = net(music_piece)

            print('target size = ', target_piece.data.size())
            print('outputs size = ', outputs.data.size())

            loss = loss_func(outputs, target_piece)
            print("loss is ", loss)

            loss.backward()
            # Skip the update on non-finite or oversized gradients.
            if check_grad(net.parameters(), train_params['clip_grad'],
                          train_params['ignore_grad']):
                print('Not a finite gradient or too big, ignoring.')
                optimizer.zero_grad()
                continue

            optimizer.step()
            # loss.item() replaces deprecated loss.data[0] (PyTorch >= 0.4).
            total_loss += loss.item()

            print(num_trained)
            num_trained += 1

            if num_trained % train_params['print_every'] == 0:
                avg_loss = total_loss / train_params['print_every']
                line = 'Average loss is ' + str(avg_loss) + '\n'
                loss_log_file.writelines(line)
                loss_log_file.flush()
                total_loss = 0

        if (epoch + 1) % train_params['check_point_every'] == 0:
            stored_models = glob.glob(train_params['restore_dir'] + '*.model')

            def model_index(path):
                # Checkpoints are named '<prefix>NN.model'; the epoch
                # number starts at character 7 of the basename.
                stem = os.path.split(os.path.splitext(path)[0])[-1]
                return int(stem[7:])

            # BUG FIX: the original called sorted(..., keys=...) — a
            # TypeError ('keys' is not a valid keyword). Also evict
            # whenever at/over the cap, not only when exactly equal.
            if len(stored_models) >= train_params['max_check_points']:
                stored_models.sort(key=model_index)
                os.remove(stored_models[0])
            print(epoch_trained)
            save_model(net, epoch_trained + epoch + 1,
                       train_params['restore_dir'])
            line = 'Epoch' + str(epoch_trained + epoch + 1) + 'model saved!'
            store_log_file.writelines(line)
            store_log_file.flush()

    # Record training end time and release log handles.
    time = str(datetime.now())
    line = 'Training Ended at' + str(time) + ' !!! \n'
    loss_log_file.writelines(line)
    loss_log_file.flush()
    loss_log_file.close()
    store_log_file.close()
Example #8
0
File: model.py  Project: ElimsV/DD2424
    # NOTE(review): fragment — the enclosing function/script header is not
    # visible in this chunk; `data_loader`, `char2int`, `file_data`, `RNN`
    # and `check_grad` are presumably defined above. TODO confirm.
    # text = ''
    # for i in int_list:
    #     text = text + int2char[i]
    # print(text)

    ###### test: compute loss ######
    # init hidden state (zeros; random init kept for reference)
    # h0 = np.random.standard_normal([rnn_net.m, 1])
    h0 = np.zeros([5, 1])
    seq_len = 25
    # init rnn network
    rnn_net = RNN(data_loader.K, h0, seq_len, data_loader.unique_chars)

    # One-hot encode the first seq_len characters as inputs and the same
    # sequence shifted by one character as next-character targets.
    X_onehot = np.zeros([rnn_net.K, seq_len])
    target_onehot = np.zeros([rnn_net.K, seq_len])
    X_int = [char2int[ch] for ch in file_data[:seq_len]]
    target_int = [char2int[ch] for ch in file_data[1:seq_len + 1]]
    X_onehot[X_int, range(seq_len)] = 1
    target_onehot[target_int, range(seq_len)] = 1

    # prob = rnn_net.predict_prob(X_onehot)
    # loss = rnn_net.compute_loss(prob, target_onehot)
    # print(loss)

    ###### test: gradient computation ######
    # grads, loss = rnn_net.backward_pass(X_onehot, target_onehot)
    # for grad in grads:
    #     print(grad)

    # Numerically verify the RNN's analytic gradients.
    check_grad(rnn_net, X_onehot, target_onehot)
Example #9
0
def test_measure_gradient_values(measure):
    """Compare the measure's analytic density gradient with finite differences."""
    dim, measure, bounds = measure.D, measure.measure, measure.dat_bounds

    def density(x):
        return measure.compute_density(x)

    def density_gradient(x):
        return measure.compute_density_gradient(x).T

    check_grad(density, density_gradient, in_shape=(3, dim), bounds=bounds)
Example #10
0
def train(train_loader_source, train_loader_source_batches,
          train_loader_target, train_loader_target_batches, feature_extractor,
          class_classifier, domain_classifier, criterion_y, criterion_d,
          optimizer, epoch, args):
    """
    Train DANN for one minibatch (one source + one target batch), not a
    full pass over the loaders.

    Parameters
    ----------
    train_loader_source: torch.utils.data.DataLoader
        Used to reset train_loader_source_batches when its enumerate is
        exhausted.
    train_loader_source_batches: enumerate
        Each element contains one batch of source data.
    train_loader_target: torch.utils.data.DataLoader
        Used to reset train_loader_target_batches when its enumerate is
        exhausted.
    train_loader_target_batches: enumerate
        Each element contains one batch of target data.
    feature_extractor, class_classifier, domain_classifier: torch.nn.Module
        The three DANN sub-networks.
    criterion_y: a loss class instance from torch.nn
        Criterion of the label predictor.
    criterion_d: a loss class instance from torch.nn
        Criterion of the domain classifier.
    optimizer: an optimizer from torch.optim
        Optimizer over all trained parameters.
    epoch: int
        The current epoch; advanced here when the counting loader wraps.
    args: Namespace
        Arguments that main.py receives.

    Returns
    -------
    (train_loader_source_batches, train_loader_target_batches, epoch,
     pred_acc1_source, loss_C, loss_G, new_epoch_flag)
    """
    feature_extractor.train()
    class_classifier.train()
    domain_classifier.train()

    adjust_learning_rate(optimizer, epoch, args)
    for param_group in optimizer.param_groups:
        print(param_group['lr'])
    end = time.time()

    # DANN assumes target labels are unavailable; only source labels are
    # used for the classification loss.
    new_epoch_flag = False
    try:
        _, (inputs_source,
            labels_source) = next(train_loader_source_batches)
    except StopIteration:
        # Source loader exhausted: restart it, counting an epoch if the
        # epoch counter follows the source dataset.
        if args.epoch_count_dataset == 'source':
            epoch = epoch + 1
            new_epoch_flag = True
        train_loader_source_batches = enumerate(train_loader_source)
        _, (inputs_source,
            labels_source) = next(train_loader_source_batches)

    try:
        _, (inputs_target, _) = next(train_loader_target_batches)
    except StopIteration:
        if args.epoch_count_dataset == 'target':
            epoch = epoch + 1
            new_epoch_flag = True
        train_loader_target_batches = enumerate(train_loader_target)
        _, (inputs_target, _) = next(train_loader_target_batches)

    if torch.cuda.is_available():
        # `async` is reserved since Python 3.7; the kwarg was renamed to
        # `non_blocking` in PyTorch 0.4.
        inputs_source = inputs_source.cuda(non_blocking=True)
        labels_source = labels_source.cuda(non_blocking=True)
        inputs_target = inputs_target.cuda(non_blocking=True)
    # Only the label Variable is used below (for accuracy); the networks
    # are fed the raw tensors directly, as in the original.
    labels_source_var = Variable(labels_source)
    data_time_train.update(time.time() - end)

    # Forward: shared feature extractor on both domains.
    feature_source = feature_extractor(inputs_source)
    feature_target = feature_extractor(inputs_target)

    # Classification loss on the labelled source domain only.
    outputs_source = class_classifier(feature_source)
    outputs_target = class_classifier(feature_target)
    loss_C = criterion_y(outputs_source, labels_source)

    # Domain labels: 0 = source, 1 = target.
    if torch.cuda.is_available():
        source_labels = Variable(
            torch.zeros(
                (inputs_source.size()[0])).type(torch.LongTensor).cuda())
        target_labels = Variable(
            torch.ones(
                (inputs_target.size()[0])).type(torch.LongTensor).cuda())
    else:
        source_labels = Variable(
            torch.zeros((inputs_source.size()[0])).type(torch.LongTensor))
        target_labels = Variable(
            torch.ones((inputs_target.size()[0])).type(torch.LongTensor))

    # Domain loss; the gradient-reversal constant anneals from 0 toward 1
    # over the epochs.
    p = float(epoch) / args.epochs
    constant = 2. / (1. + np.exp(-args.gamma * p)) - 1
    preds_source = domain_classifier(feature_source, constant)
    preds_target = domain_classifier(feature_target, constant)
    domain_loss_source = criterion_d(preds_source, source_labels)
    domain_loss_target = criterion_d(preds_target, target_labels)
    loss_G = domain_loss_target + domain_loss_source

    loss = loss_C + loss_G
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    grad_mean_extractor = check_grad(feature_extractor)
    grad_mean_class = check_grad(class_classifier)
    grad_mean_domain = check_grad(domain_classifier)

    if new_epoch_flag or epoch == 0:
        with torch.no_grad():
            writer = SummaryWriter(log_dir=args.log)
            o_minimum, o_maximum, o_medium = analyze_output(outputs_target)
            writer.add_scalars(
                'data/output_analysis', {
                    'o_minimum': o_minimum,
                    'o_maximum': o_maximum,
                    'o_medium': o_medium
                }, epoch)
            writer.add_scalars(
                'data/scalar_group', {
                    'grad_mean_extractor': grad_mean_extractor,
                    # BUG FIX: the original logged grad_mean_extractor
                    # under this key.
                    'grad_mean_class': grad_mean_class,
                    'grad_mean_domain': grad_mean_domain
                }, epoch)
            writer.add_scalars('data/insight', {
                'loss_C': loss_C,
                'domain_loss': loss_G,
                'loss': loss
            }, epoch)
            writer.close()

    # Measure accuracy and record losses.
    pred_acc1_source, pred_acc5_source = accuracy(outputs_source,
                                                  labels_source_var,
                                                  topk=(1, 5))
    losses_C_train.update(loss_C.data)
    losses_G_train.update(loss_G.data)
    loss_total_train = loss  # identical to loss_C + loss_G
    losses_total_train.update(loss_total_train.data)
    top1_source_train.update(pred_acc1_source)
    top5_source_train.update(pred_acc5_source)
    batch_time_train.update(time.time() - end)

    if epoch % args.print_freq == 0:
        print(
            'Tr epoch [{0}/{1}]\t'
            'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
            'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
            'S@1 {top1_source.val:.3f} ({top1_source.avg:.3f})\t'
            'S@5 {top5_source.val:.3f} ({top5_source.avg:.3f})\t'
            'loss_C {loss_C_train.val:.4f} ({loss_C_train.avg:.4f})\t'
            'loss_G {loss_G_train.val:.4f} ({loss_G_train.avg:.4f})\t'.format(
                epoch,
                args.epochs,
                batch_time=batch_time_train,
                data_time=data_time_train,
                top1_source=top1_source_train,
                top5_source=top5_source_train,
                loss_C_train=losses_C_train,
                loss_G_train=losses_G_train))
    if epoch % args.record_freq == (args.record_freq - 1):
        if not os.path.isdir(args.log):
            os.mkdir(args.log)
        with open(os.path.join(args.log, 'log.txt'), 'a+') as fp:
            fp.write('\n')
            fp.write(
                'Tr:epoch: %d, loss_total: %4f,'
                'top1_source acc: %3f, top5_source acc: %3f, loss_C: %4f, loss_G: %4f'
                % (epoch, losses_total_train.avg, top1_source_train.avg,
                   top5_source_train.avg, losses_C_train.avg,
                   losses_G_train.avg))

    return train_loader_source_batches, train_loader_target_batches, epoch, pred_acc1_source, loss_C, loss_G, new_epoch_flag
Example #11
0
File: main.py  Project: bliunlpr/speaker
def train(opt, model, optimizer):
    """Train the speaker-embedding model over the module-level `train_loader`,
    periodically checkpointing and evaluating.

    NOTE(review): relies on module-level `train_loader` and `device`, on
    helpers (`similarity`, `similarity_segment`, `loss_cal`,
    `loss_cal_segment`, `penalty_loss_cal`, `penalty_seq_loss_cal`,
    `evaluate`) and on `utils`; `opt` supplies all hyper-parameters —
    confirm against the caller.
    """
    model.train()
    total_steps = opt.total_steps
    # Running averages for the individual loss terms.
    losses = utils.AverageMeter()
    embedding_losses = utils.AverageMeter()
    embedding_segment_losses = utils.AverageMeter()
    penalty_losses = utils.AverageMeter()
    lr = opt.lr

    for i, (data) in enumerate(train_loader, start=0):
        if opt.seq_training == 'true':
            # Sequence training: batches carry per-utterance lengths and
            # segment-level outputs in addition to utterance embeddings.
            feature_input, seq_len, spk_ids = data
            seq_len = seq_len.squeeze(0).to(device)
            if opt.model_type == 'lstm':
                feature_input = feature_input.squeeze(0).to(device)
            elif opt.model_type == 'cnn':
                feature_input = feature_input.transpose(1, 2).transpose(
                    0, 1).to(device)
            else:
                raise Exception('wrong model_type {}'.format(opt.model_type))

            outputs, attention_matrix, segment_outputs = model(
                feature_input, seq_len)
            sim_matrix = similarity(outputs, model.w, model.b, opt)
            embedding_loss = opt.embedding_loss_lamda * loss_cal(
                sim_matrix, opt)

            # Optional segment-level embedding loss: on averaged segment
            # outputs, on all segments, or disabled.
            if opt.segment_type == 'average':
                sim_matrix = similarity(segment_outputs, model.w, model.b, opt)
                embedding_loss_segment = opt.segment_loss_lamda * loss_cal(
                    sim_matrix, opt)
                embedding_segment_losses.update(embedding_loss_segment.item())
            elif opt.segment_type == 'all':
                sim_matrix = similarity_segment(segment_outputs, seq_len,
                                                model.w, model.b, opt)
                embedding_loss_segment = opt.segment_loss_lamda * loss_cal_segment(
                    sim_matrix, seq_len, opt)
                embedding_segment_losses.update(embedding_loss_segment.item())
            else:
                embedding_loss_segment = 0

            # Optional attention-penalty term (multi-head attention only).
            if opt.train_type == 'multi_attention':
                penalty_loss = opt.penalty_loss_lamda * penalty_seq_loss_cal(
                    attention_matrix, device)
                penalty_losses.update(penalty_loss.item())
            else:
                penalty_loss = 0
            loss = embedding_loss + penalty_loss + embedding_loss_segment
        else:
            # Fixed-length training: no sequence lengths / segment outputs.
            feature_input, spk_ids = data
            if opt.model_type == 'lstm':
                feature_input = feature_input.squeeze(0).to(device)
                outputs, attention_matrix = model(feature_input)
                sim_matrix = similarity(outputs, model.w, model.b, opt)
                embedding_loss = opt.embedding_loss_lamda * loss_cal(
                    sim_matrix, opt)
                if opt.train_type == 'multi_attention':
                    penalty_loss = opt.penalty_loss_lamda * penalty_loss_cal(
                        attention_matrix, device)
                    penalty_losses.update(penalty_loss.item())
                else:
                    penalty_loss = 0
                loss = embedding_loss + penalty_loss
            elif opt.model_type == 'cnn':
                feature_input = feature_input.transpose(1, 2).transpose(
                    0, 1).to(device)
                outputs = model(feature_input)
                sim_matrix = similarity(outputs, model.w, model.b, opt)
                embedding_loss = opt.embedding_loss_lamda * loss_cal(
                    sim_matrix, opt)
                loss = embedding_loss
            else:
                raise Exception('wrong model_type {}'.format(opt.model_type))

        # Backward; the whole update is skipped on non-finite or oversized
        # gradients.
        optimizer.zero_grad()
        loss.backward()
        if utils.check_grad(model.parameters(), opt.clip_grad,
                            opt.ignore_grad):
            logging.info('Not a finite gradient or too big, ignoring.')
            optimizer.zero_grad()
            continue
        optimizer.step()

        losses.update(loss.item())
        embedding_losses.update(embedding_loss.item())

        # Periodic logging + "latest" checkpoint.
        if total_steps % opt.print_freq == 0:
            logging.info(
                '==> Train set steps {} lr: {:.6f}, loss: {:.4f} [ embedding: {:.4f}, embedding_segment: {:.4f}, penalty_loss {:.4f}]'
                .format(total_steps, lr, losses.avg, embedding_losses.avg,
                        embedding_segment_losses.avg, penalty_losses.avg))
            state = {
                'state_dict': model.state_dict(),
                'opt': opt,
                'learning_rate': lr,
                'total_steps': total_steps
            }
            filename = 'latest'
            utils.save_checkpoint(state, opt.expr_dir, filename=filename)

        # Periodic evaluation: compute EER, decay the learning rate, save a
        # named checkpoint, and reset the running meters.
        if total_steps % opt.validate_freq == 0:
            EER = evaluate(opt, model)
            lr = utils.adjust_learning_rate_by_factor(optimizer, lr,
                                                      opt.lr_reduce_factor)
            state = {
                'state_dict': model.state_dict(),
                'opt': opt,
                'learning_rate': lr,
                'total_steps': total_steps
            }
            filename = 'steps-{}_lr-{:.6f}_EER-{:.4f}.pth'.format(
                total_steps, lr, EER)
            utils.save_checkpoint(state, opt.expr_dir, filename=filename)
            model.train()
            losses.reset()
            embedding_losses.reset()
            embedding_segment_losses.reset()
            penalty_losses.reset()
        total_steps += 1
        if total_steps > opt.training_total_steps:
            logging.info(
                'finish training, total_steps is  {}'.format(total_steps))
            break