示例#1
0
文件: rsi.py 项目: Allen1203/quant
class RSI(object):
    """Relative Strength Index computed incrementally from a value stream.

    The upward and downward moves between consecutive inputs are fed into two
    separate ``EMA`` smoothers; the index is derived from the ratio of the two
    smoothed values.  Each computed value can optionally be appended to a
    table created through :meth:`setupH5`.
    """

    def __init__(self, period):
        """Initialise the indicator.

        Args:
            period: passed unchanged to both ``EMA`` smoothers.
        """
        self.value = None  # latest RSI value; None until the first update()
        self.last = None   # previous input, used to compute the next move
        self.ema_u = EMA(period)  # smoother for upward moves
        self.ema_d = EMA(period)  # smoother for downward moves
        self.tbl = None    # optional output table, see setupH5()

    def setupH5(self, h5file, h5where, h5name):
        """Create the output table when all three arguments are provided.

        Args:
            h5file: object exposing ``createTable(where, name, description)``.
            h5where: location argument forwarded to ``createTable``.
            h5name: table name forwarded to ``createTable``.
        """
        # NOTE(review): ``createTable`` looks like the legacy PyTables
        # spelling -- confirm against the PyTables version in use.
        if h5file is not None and h5where is not None and h5name is not None:
            self.tbl = h5file.createTable(h5where, h5name, RSIData)

    def update(self, value, date=None):
        """Feed one new value and return the updated RSI.

        Args:
            value: the new input sample.
            date: optional object exposing ``date().toordinal()``; when given
                (and a table is configured) the result is appended to the
                table and flushed.

        Returns:
            The RSI after this sample; ``100.0`` whenever the smoothed
            downward move is exactly zero.
        """
        # Identity check: ``== None`` would dispatch to value.__eq__.
        if self.last is None:
            # First sample: compare it with itself so both moves are zero.
            self.last = value

        U = value - self.last
        D = self.last - value

        self.last = value

        # Keep only the positive side of the move; the other side becomes 0.
        if U > 0:
            D = 0
        elif D > 0:
            U = 0

        self.ema_u.update(U)
        self.ema_d.update(D)

        if self.ema_d.value == 0:
            # Avoids dividing by zero when there is no smoothed down move.
            self.value = 100.0
        else:
            rs = self.ema_u.value / self.ema_d.value
            self.value = 100.0 - (100.0 / (1 + rs))

        if self.tbl is not None and date:
            self.tbl.row["date"] = date.date().toordinal()
            self.tbl.row["value"] = self.value
            self.tbl.row.append()
            self.tbl.flush()

        return self.value
示例#2
0
class RSI(object):
    """Incremental Relative Strength Index.

    Gains and losses between successive inputs are smoothed by two ``EMA``
    instances, and the index is computed from the ratio of the smoothed gain
    to the smoothed loss.  Results may be logged to a table set up with
    :meth:`setupH5`.
    """

    def __init__(self, period):
        """Build the two smoothers with the given ``period``."""
        self.value = None  # last computed RSI, None before any update
        self.last = None   # last input seen
        self.ema_u = EMA(period)  # smoothed gains
        self.ema_d = EMA(period)  # smoothed losses
        self.tbl = None    # output table, populated by setupH5()

    def setupH5(self, h5file, h5where, h5name):
        """Create the logging table; a no-op if any argument is ``None``."""
        if h5file is not None and h5where is not None and h5name is not None:
            self.tbl = h5file.createTable(h5where, h5name, RSIData)

    def update(self, value, date=None):
        """Consume one sample and return the resulting RSI.

        Args:
            value: new input sample.
            date: optional; when truthy and a table exists, a row holding
                ``date.date().toordinal()`` and the RSI is appended and the
                table is flushed.

        Returns:
            The updated RSI (``100.0`` when the smoothed loss is zero).
        """
        # ``is None`` rather than ``== None`` so that the check never goes
        # through the sample type's own equality operator.
        if self.last is None:
            self.last = value  # first sample yields a zero move

        U = value - self.last
        D = self.last - value

        self.last = value

        # A move is either a gain or a loss; zero out the opposite side.
        if U > 0:
            D = 0
        elif D > 0:
            U = 0

        self.ema_u.update(U)
        self.ema_d.update(D)

        if self.ema_d.value == 0:
            # No smoothed loss: report the maximum instead of dividing by 0.
            self.value = 100.0
        else:
            rs = self.ema_u.value / self.ema_d.value
            self.value = 100.0 - (100.0 / (1 + rs))

        if self.tbl is not None and date:
            self.tbl.row["date"] = date.date().toordinal()
            self.tbl.row["value"] = self.value
            self.tbl.row.append()
            self.tbl.flush()

        return self.value
def main():
    """Train a P-VAE classifier on the MIMIC-III time-series data.

    Parses the command-line options, builds the data loaders, encoder,
    decoder, classifier and optimizer, then runs the training loop.  Logs,
    the latest checkpoint and optional periodic checkpoints are written
    under ``results/mimic3-pvae/<prefix>_<timestamp>``.

    Relies on a module-level ``device`` being defined.  Exits through
    ``parser.error`` when ``--dec-ref`` is not a multiple of the decoder
    upsampling factor.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--data', default='mimic3.npz', help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=200,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        help='batch size')
    # Use smaller test batch size to accommodate more importance samples
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=32,
                        help='batch size for validation and test set')
    parser.add_argument('--train-k',
                        type=int,
                        default=8,
                        help='number of importance weights for training')
    parser.add_argument('--test-k',
                        type=int,
                        default=50,
                        help='number of importance weights for evaluation')
    parser.add_argument('--flow',
                        type=int,
                        default=2,
                        help='number of IAF layers')
    parser.add_argument('--lr',
                        type=float,
                        default=2e-4,
                        help='global learning rate')
    parser.add_argument('--enc-lr',
                        type=float,
                        default=1e-4,
                        help='encoder learning rate')
    parser.add_argument('--dec-lr',
                        type=float,
                        default=1e-4,
                        help='decoder learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=-1,
                        help='min learning rate for LR scheduler. '
                        '-1 to disable annealing')
    parser.add_argument('--wd', type=float, default=1e-3, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--cls',
                        type=float,
                        default=200,
                        help='classification weight')
    parser.add_argument('--clsdep',
                        type=int,
                        default=1,
                        help='number of layers for classifier')
    parser.add_argument('--ts',
                        type=float,
                        default=1,
                        help='log-likelihood weight for ELBO')
    parser.add_argument('--kl',
                        type=float,
                        default=.1,
                        help='KL weight for ELBO')
    parser.add_argument('--eval-interval',
                        type=int,
                        default=1,
                        help='AUC evaluation interval. '
                        '0 to disable evaluation.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pvae',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--sigma',
                        type=float,
                        default=.2,
                        help='standard deviation for Gaussian likelihood')
    parser.add_argument('--dec-ch',
                        default='8-16-16',
                        help='decoder architecture')
    parser.add_argument('--enc-ch',
                        default='64-32-32-16',
                        help='encoder architecture')
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, normalize continuous convolutional '
                        'layer using mean pooling')
    parser.add_argument('--no-cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconv-ref',
                        type=int,
                        default=98,
                        help='number of evenly-spaced reference locations '
                        'for continuous convolutional layer')
    parser.add_argument('--dec-ref',
                        type=int,
                        default=128,
                        help='number of evenly-spaced reference locations '
                        'for decoder')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=0,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')

    args = parser.parse_args()

    nz = args.nz

    epochs = args.epoch
    eval_interval = args.eval_interval
    save_interval = args.save_interval

    # Draw a seed when none is given so that it can be logged and reused.
    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    max_time = 5
    cconv_ref = args.cconv_ref
    overlap = args.overlap
    train_dataset, val_dataset, test_dataset = time_series.split_data(
        args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch_size,
                            shuffle=False,
                            collate_fn=val_dataset.collate_fn)

    test_loader = DataLoader(test_dataset,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             collate_fn=test_dataset.collate_fn)

    in_channels, seq_len = train_dataset.data.shape[1:]

    # Channel lists are given as dash-separated integers, e.g. "8-16-16".
    dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
    enc_channels = [int(c) for c in args.enc_ch.split('-')]

    out_channels = enc_channels[0]

    squash = torch.sigmoid
    if args.rescale:
        squash = torch.tanh

    dec_ch_up = 2**(len(dec_channels) - 2)
    # Explicit check instead of ``assert``: assertions are removed under
    # ``python -O`` and the floor division below would then silently
    # produce a wrong initial length.
    if args.dec_ref % dec_ch_up != 0:
        parser.error(
            f'--dec-ref={args.dec_ref} is not divided by {dec_ch_up}.')
    dec_len0 = args.dec_ref // dec_ch_up
    grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)

    decoder = Decoder(grid_decoder, max_time=max_time,
                      dec_ref=args.dec_ref).to(device)

    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=overlap,
                             kernel_size=args.comp,
                             norm=args.cconv_norm).to(device)
    encoder = Encoder(cconv, nz, enc_channels, args.flow).to(device)

    classifier = Classifier(nz, args.clsdep).to(device)

    pvae = PVAE(encoder, decoder, classifier, args.sigma, args.cls).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pvae, args.ema_decay, args.ema)

    # Three optimizer groups: grid decoder and grid encoder get their own
    # learning rates, everything else uses the global ``--lr``.
    other_params = [
        param for name, param in pvae.named_parameters()
        if not (name.startswith('decoder.grid_decoder')
                or name.startswith('encoder.grid_encoder'))
    ]
    params = [
        {
            'params': decoder.grid_decoder.parameters(),
            'lr': args.dec_lr
        },
        {
            'params': encoder.grid_encoder.parameters(),
            'lr': args.enc_lr
        },
        {
            'params': other_params
        },
    ]

    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'mimic3-pvae' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    # Record the run configuration next to the logs.
    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)
    with (log_dir / 'params.txt').open('w') as f:

        def print_params_count(module, name):
            try:  # sum counts if module is a list
                params_count = sum(count_parameters(m) for m in module)
            except TypeError:
                params_count = count_parameters(module)
            print(f'{name} {params_count}', file=f)

        print_params_count(grid_decoder, 'grid_decoder')
        print_params_count(decoder, 'decoder')
        print_params_count(cconv, 'cconv')
        print_params_count(encoder, 'encoder')
        print_params_count(classifier, 'classifier')
        print_params_count(pvae, 'pvae')
        print_params_count(pvae, 'total')

    tracker = Tracker(log_dir, n_train_batch)
    evaluator = Evaluator(pvae,
                          val_loader,
                          test_loader,
                          log_dir,
                          eval_args={'iw_samples': args.test_k})
    start = time.time()

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()
        for (val, idx, mask, y, _, cconv_graph) in train_loader:
            optimizer.zero_grad()
            loss, _, _, loss_info = pvae(val, idx, mask, y, cconv_graph,
                                         args.train_k, args.ts, args.kl)
            loss.backward()
            optimizer.step()

            if ema:
                ema.update()

            # Accumulate every reported loss term over the epoch.
            for loss_name, loss_val in loss_info.items():
                loss_breakdown[loss_name] += loss_val

        if scheduler:
            scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
            if ema:
                # Evaluate with the EMA state applied, then put the
                # training state back.
                ema.apply()
                evaluator.evaluate(epoch)
                ema.restore()
            else:
                evaluator.evaluate(epoch)

        model_dict = {
            'pvae': pvae.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        # The latest checkpoint is overwritten every epoch; numbered copies
        # are kept only at the requested interval.
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
示例#4
0
def train_val_model(pipeline_cfg, model_cfg, train_cfg):
    """Train a BiDAF model and validate it after every epoch.

    An exponential moving average (EMA) of the trainable parameters is
    maintained during training; validation runs with the EMA weights swapped
    in and the raw weights restored afterwards.  Whenever the validation F1
    improves, the result dict (including a copy of the model state) is saved
    to ``train_cfg['ckpoint_file']``.

    Args:
        pipeline_cfg: keyword arguments for ``DataPipeline``.
        model_cfg: keyword arguments for ``BiDAF``.  If
            ``model_cfg['cxt_emb_pretrained']`` is not ``None`` it is replaced
            in place by the object loaded from that path.
        train_cfg: mapping read for ``'exp_decay_rate'``, ``'lr'``,
            ``'num_epochs'``, ``'batch_per_disp'`` and ``'ckpoint_file'``.
    """
    data_pipeline = DataPipeline(**pipeline_cfg)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if model_cfg['cxt_emb_pretrained'] is not None:
        model_cfg['cxt_emb_pretrained'] = torch.load(
            model_cfg['cxt_emb_pretrained'])
    bidaf = BiDAF(word_emb=data_pipeline.word_type.vocab.vectors, **model_cfg)
    ema = EMA(train_cfg['exp_decay_rate'])
    for name, param in bidaf.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    parameters = filter(lambda p: p.requires_grad, bidaf.parameters())
    optimizer = optim.Adadelta(parameters, lr=train_cfg['lr'])
    criterion = nn.CrossEntropyLoss()

    result = {'best_f1': 0.0, 'best_model': None}

    num_epochs = train_cfg['num_epochs']
    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        for phase in ['train', 'val']:
            # Only referenced by the disabled answer dump further below.
            val_answers = dict()
            val_f1 = 0
            val_em = 0
            val_cnt = 0
            val_r = 0

            if phase == 'train':
                bidaf.train()
            else:
                bidaf.eval()
                # Stash the raw weights and evaluate with the EMA weights.
                backup_params = EMA(0)
                for name, param in bidaf.named_parameters():
                    if param.requires_grad:
                        backup_params.register(name, param.data)
                        param.data.copy_(ema.get(name))

            with torch.set_grad_enabled(phase == 'train'):
                for batch_num, batch in enumerate(
                        data_pipeline.data_iterators[phase]):
                    optimizer.zero_grad()
                    p1, p2 = bidaf(batch)
                    loss = criterion(p1, batch.s_idx) + criterion(
                        p2, batch.e_idx)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        for name, param in bidaf.named_parameters():
                            if param.requires_grad:
                                ema.update(name, param.data)
                        if batch_num % train_cfg['batch_per_disp'] == 0:
                            batch_loss = loss.item()
                            print('batch %d: loss %.3f' %
                                  (batch_num, batch_loss))

                    if phase == 'val':
                        batch_size, c_len = p1.size()
                        val_cnt += batch_size
                        ls = nn.LogSoftmax(dim=1)
                        # -inf strictly below the diagonal rules out spans
                        # whose start index is greater than the end index.
                        mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1). \
                            unsqueeze(0).expand(batch_size, -1, -1)
                        score = (ls(p1).unsqueeze(2) +
                                 ls(p2).unsqueeze(1)) + mask
                        # Best start for every end, then the best end.
                        score, s_idx = score.max(dim=1)
                        score, e_idx = score.max(dim=1)
                        # squeeze(1) drops only the gathered column; a bare
                        # squeeze() would also drop the batch dimension for a
                        # one-example batch and make s_idx[i] fail below.
                        s_idx = torch.gather(s_idx, 1,
                                             e_idx.view(-1, 1)).squeeze(1)

                        for i in range(batch_size):
                            answer = (s_idx[i], e_idx[i])
                            gt = (batch.s_idx[i], batch.e_idx[i])
                            val_f1 += f1_score(answer, gt)
                            val_em += exact_match_score(answer, gt)
                            val_r += r_score(answer, gt)

            if phase == 'val':
                # Convert the accumulated sums to percentages.
                val_f1 = val_f1 * 100 / val_cnt
                val_em = val_em * 100 / val_cnt
                val_r = val_r * 100 / val_cnt
                print('Epoch %d: %s f1 %.3f | %s em %.3f |  %s rouge %.3f' %
                      (epoch, phase, val_f1, phase, val_em, phase, val_r))
                if val_f1 > result['best_f1']:
                    # The snapshot is taken while the EMA weights are still
                    # loaded into the model.
                    result['best_f1'] = val_f1
                    result['best_em'] = val_em
                    result['best_model'] = copy.deepcopy(bidaf.state_dict())
                    torch.save(result, train_cfg['ckpoint_file'])
                    # with open(train_cfg['val_answers'], 'w', encoding='utf-8') as f:
                    #     print(json.dumps(val_answers), file=f)
                # Restore the raw (non-averaged) weights for further training.
                for name, param in bidaf.named_parameters():
                    if param.requires_grad:
                        param.data.copy_(backup_params.get(name))
示例#5
0
def main():
    """Train a P-BiGAN on the toy time-series dataset.

    Parses the command-line options, loads the data file (generating the
    default toy dataset when that file is missing), builds the encoder,
    decoder and critic, and runs adversarial training.  Logs, the latest
    checkpoint and optional periodic checkpoints are written under
    ``results/toy-pbigan/<prefix>_<timestamp>``.

    Relies on a module-level ``device`` being defined.

    Raises:
        FileNotFoundError: if a non-default ``--data`` file does not exist.
    """
    parser = argparse.ArgumentParser()

    default_dataset = 'toy-data.npz'
    parser.add_argument('--data', default=default_dataset, help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=1000,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        help='batch size')
    parser.add_argument('--lr',
                        type=float,
                        default=8e-5,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr',
                        type=float,
                        default=1e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=5e-5,
                        help='min encoder/decoder learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr',
                        type=float,
                        default=7e-5,
                        help='min discriminator learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=0, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--no-norm-trans',
                        action='store_true',
                        help='if set, use Gaussian posterior without '
                        'transformation')
    parser.add_argument('--plot-interval',
                        type=int,
                        default=1,
                        help='plot interval. 0 to disable plotting.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae',
                        type=float,
                        default=.2,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss',
                        default='smooth_l1',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=-1,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd',
                        type=float,
                        default=1,
                        help='MMD strength for latent variable')

    # squash is off when rescale is off
    parser.add_argument('--squash',
                        dest='squash',
                        action='store_const',
                        const=True,
                        default=True,
                        help='bound the generated time series value '
                        'using tanh')
    parser.add_argument('--no-squash',
                        dest='squash',
                        action='store_const',
                        const=False)

    # rescale to [-1, 1]
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)

    args = parser.parse_args()

    batch_size = args.batch_size
    nz = args.nz

    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval

    try:
        npz = np.load(args.data)
        train_data = npz['data']
        train_time = npz['time']
        train_mask = npz['mask']
    except FileNotFoundError:
        # Only the default dataset can be regenerated; a missing
        # user-supplied file is an error.
        if args.data != default_dataset:
            raise
        # Generate the default toy dataset from scratch
        train_data, train_time, train_mask, _, _ = gen_data(
            n_samples=10000,
            seq_len=200,
            max_time=1,
            poisson_rate=50,
            obs_span_rate=.25,
            save_file=default_dataset)

    _, in_channels, seq_len = train_data.shape
    train_time *= train_mask

    # Draw a seed when none is given so that it can be logged and reused.
    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Scale time
    max_time = 5
    train_time *= max_time

    squash = None
    rescaler = None
    if args.rescale:
        rescaler = Rescaler(train_data)
        train_data = rescaler.rescale(train_data)
        if args.squash:
            squash = torch.tanh

    out_channels = 64
    cconv_ref = 98

    train_dataset = TimeSeries(train_data,
                               train_time,
                               train_mask,
                               label=None,
                               max_time=max_time,
                               cconv_ref=cconv_ref,
                               overlap_rate=args.overlap,
                               device=device)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    # Second, independently shuffled pass over the same dataset; it supplies
    # the time points and masks used for generation.
    time_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True,
                             collate_fn=train_dataset.collate_fn)

    test_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             collate_fn=train_dataset.collate_fn)

    grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
    decoder = Decoder(grid_decoder, max_time=max_time).to(device)

    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=args.overlap,
                             kernel_size=args.comp,
                             norm=True).to(device)
    encoder = Encoder(cconv, nz, not args.no_norm_trans).to(device)

    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic_cconv = ContinuousConv1D(in_channels,
                                    out_channels,
                                    max_time,
                                    cconv_ref,
                                    overlap_rate=args.overlap,
                                    kernel_size=args.comp,
                                    norm=True).to(device)
    critic = ConvCritic(critic_cconv, nz).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    optimizer = optim.Adam(pbigan.parameters(),
                           lr=args.lr,
                           weight_decay=args.wd)
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.dis_lr,
                                  weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'toy-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    # Record the run configuration next to the logs.
    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    tracker = Tracker(log_dir, n_train_batch)
    visualizer = Visualizer(encoder, decoder, batch_size, max_time,
                            test_loader, rescaler, output_dir, device)
    start = time.time()

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        # Reset per epoch so the tracker receives this epoch's duration
        # rather than the time elapsed since training started.
        epoch_start = time.time()

        for ((val, idx, mask, _, cconv_graph),
             (_, idx_t, mask_t, index, _)) in zip(train_loader, time_loader):

            z_enc, x_recon, z_gen, x_gen, ae_loss = pbigan(
                val, idx, mask, cconv_graph, idx_t, mask_t)

            cconv_graph_gen = train_dataset.make_graph(x_gen, idx_t, mask_t,
                                                       index)

            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)

            # Critic step.  The graph is retained because the same critic
            # outputs are reused for the generator loss below.
            D_loss = gan_loss(real, fake, 1, 0)

            critic_optimizer.zero_grad()
            D_loss.backward(retain_graph=True)
            critic_optimizer.step()

            # NOTE(review): ``real``/``fake`` were computed before
            # critic_optimizer.step() updated the critic in place; newer
            # PyTorch releases may reject the backward pass below for that
            # reason -- confirm against the targeted PyTorch version.
            G_loss = gan_loss(real, fake, 0, 1)

            mmd_loss = mmd(z_enc, z_gen)

            loss = G_loss + ae_loss * args.ae + mmd_loss * args.mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            if ema:
                # Plot with the EMA state applied, then put the training
                # state back.
                ema.apply()
                visualizer.plot(epoch)
                ema.restore()
            else:
                visualizer.plot(epoch)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        # The latest checkpoint is overwritten every epoch; numbered copies
        # are kept only at the requested interval.
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
def main():
    """Parse CLI options, build a P-BiGAN with a classifier head, and train it.

    The function reads the dataset named by ``--data``, constructs the
    encoder/decoder/classifier (wrapped in ``PBiGAN``) and a ``ConvCritic``,
    and then alternates critic and generator updates for ``--epoch`` epochs.
    Logs, the seed, the parsed arguments, parameter counts and checkpoints are
    written under ``results/mimic3-pbigan/<prefix>_<timestamp>/``.

    NOTE(review): ``device`` is used below but never assigned in this
    function; it is presumably a module-level global -- confirm.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--data', default='mimic3.npz', help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=500,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        help='batch size')
    # Use smaller test batch size to accommodate more importance samples
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=32,
                        help='batch size for validation and test set')
    parser.add_argument('--lr',
                        type=float,
                        default=2e-4,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr',
                        type=float,
                        default=3e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=1e-4,
                        help='min encoder/decoder learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr',
                        type=float,
                        default=1.5e-4,
                        help='min discriminator learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--cls',
                        type=float,
                        default=1,
                        help='classification weight')
    parser.add_argument('--clsdep',
                        type=int,
                        default=1,
                        help='number of layers for classifier')
    parser.add_argument('--eval-interval',
                        type=int,
                        default=1,
                        help='AUC evaluation interval. '
                        '0 to disable evaluation.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae',
                        type=float,
                        default=1,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss',
                        default='mse',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--dec-ch',
                        default='8-16-16',
                        help='decoder architecture')
    parser.add_argument('--enc-ch',
                        default='64-32-32-16',
                        help='encoder architecture')
    parser.add_argument('--dis-ch',
                        default=None,
                        help='discriminator architecture. Use encoder '
                        'architecture if unspecified.')
    # --rescale / --no-rescale share one destination; rescaling is on by
    # default, so only --no-rescale actually changes anything.
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)
    # Same on/off pairing as above, for the continuous-conv normalization.
    parser.add_argument('--cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, normalize continuous convolutional '
                        'layer using mean pooling')
    parser.add_argument('--no-cconvnorm',
                        dest='cconv_norm',
                        action='store_const',
                        const=False)
    parser.add_argument('--cconv-ref',
                        type=int,
                        default=98,
                        help='number of evenly-spaced reference locations '
                        'for continuous convolutional layer')
    parser.add_argument('--dec-ref',
                        type=int,
                        default=128,
                        help='number of evenly-spaced reference locations '
                        'for decoder')
    parser.add_argument('--trans',
                        type=int,
                        default=2,
                        help='number of encoder layers')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=0,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd',
                        type=float,
                        default=1,
                        help='MMD strength for latent variable')

    args = parser.parse_args()

    nz = args.nz

    epochs = args.epoch
    eval_interval = args.eval_interval
    save_interval = args.save_interval

    # When no seed is given, draw one at random so that the value written to
    # seed.txt below is still enough to reproduce the run.
    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    max_time = 5
    cconv_ref = args.cconv_ref
    overlap = args.overlap
    train_dataset, val_dataset, test_dataset = time_series.split_data(
        args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    # Second, independently shuffled loader over the same training set. The
    # training loop only consumes its idx/mask/index fields (as idx_t, mask_t
    # and index), pairing each data batch with another batch's time points.
    time_loader = DataLoader(train_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True,
                             collate_fn=train_dataset.collate_fn)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch_size,
                            shuffle=False,
                            collate_fn=val_dataset.collate_fn)

    test_loader = DataLoader(test_dataset,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             collate_fn=test_dataset.collate_fn)

    in_channels, seq_len = train_dataset.data.shape[1:]

    if args.dis_ch is None:
        args.dis_ch = args.enc_ch

    # Architectures are given as dash-separated channel counts; the decoder
    # additionally ends with the number of data channels.
    dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
    enc_channels = [int(c) for c in args.enc_ch.split('-')]
    dis_channels = [int(c) for c in args.dis_ch.split('-')]

    out_channels = enc_channels[0]

    # Output squashing follows the time rescaling switch: tanh when
    # --rescale is active, sigmoid otherwise.
    squash = torch.sigmoid
    if args.rescale:
        squash = torch.tanh

    # dec_ref must be divisible by 2**(number of decoder stages - 2);
    # presumably each such stage doubles the sequence length -- TODO confirm
    # against GridDecoder.
    dec_ch_up = 2**(len(dec_channels) - 2)
    assert args.dec_ref % dec_ch_up == 0, (
        f'--dec-ref={args.dec_ref} is not divided by {dec_ch_up}.')
    dec_len0 = args.dec_ref // dec_ch_up
    grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)

    decoder = Decoder(grid_decoder, max_time=max_time,
                      dec_ref=args.dec_ref).to(device)
    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=overlap,
                             kernel_size=args.comp,
                             norm=args.cconv_norm).to(device)
    encoder = Encoder(cconv, nz, enc_channels, args.trans).to(device)

    classifier = Classifier(nz, args.clsdep).to(device)

    pbigan = PBiGAN(encoder, decoder, classifier,
                    ae_loss=args.aeloss).to(device)

    # EMA of the generator-side weights is optional (--ema -1 disables it).
    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    # The critic gets its own continuous-conv front end, configured like the
    # encoder's but not sharing the module instance.
    critic_cconv = ContinuousConv1D(in_channels,
                                    out_channels,
                                    max_time,
                                    cconv_ref,
                                    overlap_rate=overlap,
                                    kernel_size=args.comp,
                                    norm=args.cconv_norm).to(device)
    critic_embed = 32
    critic = ConvCritic(critic_cconv, nz, dis_channels,
                        critic_embed).to(device)

    optimizer = optim.Adam(pbigan.parameters(),
                           lr=args.lr,
                           betas=(0, .999),
                           weight_decay=args.wd)
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.dis_lr,
                                  betas=(0, .999),
                                  weight_decay=args.wd)

    # Both schedulers are checked for truthiness before stepping below, so
    # make_scheduler may return a falsy value (see the "-1 to disable
    # annealing" option help).
    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'mimic3-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    # Record everything needed to identify and reproduce the run.
    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    # gpu.txt is opened in append mode, unlike the other log files.
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)
    with (log_dir / 'params.txt').open('w') as f:

        def print_params_count(module, name):
            """Write ``name`` and the parameter count of ``module`` to f."""
            try:  # sum counts if module is a list
                params_count = sum(count_parameters(m) for m in module)
            except TypeError:
                params_count = count_parameters(module)
            print(f'{name} {params_count}', file=f)

        print_params_count(grid_decoder, 'grid_decoder')
        print_params_count(decoder, 'decoder')
        print_params_count(cconv, 'cconv')
        print_params_count(encoder, 'encoder')
        print_params_count(classifier, 'classifier')
        print_params_count(pbigan, 'pbigan')
        print_params_count(critic, 'critic')
        print_params_count([pbigan, critic], 'total')

    tracker = Tracker(log_dir, n_train_batch)
    evaluator = Evaluator(pbigan, val_loader, test_loader, log_dir)
    start = time.time()
    epoch_start = start

    batch_size = args.batch_size

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        epoch_start = time.time()

        # From epoch 40 onward the classification weight is forced to 200,
        # overriding --cls. This mutates args, so checkpoints saved from then
        # on carry cls=200 while args.txt keeps the original value.
        if epoch >= 40:
            args.cls = 200

        for ((val, idx, mask, y, _, cconv_graph),
             (_, idx_t, mask_t, _, index, _)) in zip(train_loader,
                                                     time_loader):

            z_enc, x_recon, z_gen, x_gen, ae_loss, cls_loss = pbigan(
                val, idx, mask, y, cconv_graph, idx_t, mask_t)

            cconv_graph_gen = train_dataset.make_graph(x_gen, idx_t, mask_t,
                                                       index)

            # ---- critic update ----
            # Don't need pbigan.requires_grad_(False);
            # critic takes as input only the detached tensors.
            real = critic(cconv_graph, batch_size, z_enc.detach())
            # Only element 2 of the generated graph is detached here;
            # presumably it is the only entry holding tensors derived from
            # x_gen -- TODO confirm against make_graph.
            detached_graph = [[cat_y.detach() for cat_y in x] if i == 2 else x
                              for i, x in enumerate(cconv_graph_gen)]
            fake = critic(detached_graph, batch_size, z_gen.detach())

            D_loss = gan_loss(real, fake, 1, 0)

            critic_optimizer.zero_grad()
            D_loss.backward()
            critic_optimizer.step()

            # ---- generator (encoder/decoder/classifier) update ----
            # Freeze the critic so the backward pass below does not
            # accumulate gradients in its parameters, then score the
            # non-detached tensors with the freshly updated critic.
            for p in critic.parameters():
                p.requires_grad_(False)
            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)

            # Same loss with the real/fake targets swapped (0, 1).
            G_loss = gan_loss(real, fake, 0, 1)

            mmd_loss = mmd(z_enc, z_gen)

            loss = (G_loss + ae_loss * args.ae + cls_loss * args.cls +
                    mmd_loss * args.mmd)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Re-enable critic gradients for the next iteration's critic step.
            for p in critic.parameters():
                p.requires_grad_(True)

            if ema:
                ema.update()

            # Per-epoch sums of each loss term, handed to the tracker below.
            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['CLS'] += cls_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        # Evaluate with the EMA weights when EMA is enabled, then put the
        # raw training weights back.
        if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
            if ema:
                ema.apply()
                evaluator.evaluate(epoch)
                ema.restore()
            else:
                evaluator.evaluate(epoch)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        # log/model.pth is overwritten every epoch (latest state); numbered
        # snapshots are written only every save_interval epochs and are named
        # after the zero-based epoch index.
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
示例#7
0
def train_bidaf(args, data):
    """Train a BiDAF model and return a copy of the best-scoring snapshot.

    The model is optimised with Adadelta on the sum of the cross-entropy
    losses of the two outputs of ``model(batch)`` against ``batch.s_idx`` and
    ``batch.e_idx``. An exponential moving average of every trainable
    parameter is maintained through ``ema``. Every ``args.print_freq``
    iterations the model is scored with ``test(model, ema, args, data)``, the
    metrics are written to TensorBoard, and the model is deep-copied whenever
    dev F1 improves.

    Args:
        args: Namespace providing ``gpu``, ``exp_decay_rate``,
            ``learning_rate``, ``model_time``, ``epoch`` and ``print_freq``.
        data: Dataset wrapper providing ``WORD.vocab.vectors`` and
            ``train_iter`` (an iterator exposing an ``epoch`` attribute).

    Returns:
        A deep copy of the model taken at the evaluation with the highest dev
        F1. If training stops before any evaluation took place, a deep copy
        of the model in its final state is returned instead.
    """
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.WORD.vocab.vectors).to(device)

    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter(logdir='runs/' + args.model_time)

    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1
    # Bound up front: the loop may finish without ever reaching an
    # evaluation step, and the function must still return a model.
    best_model = None

    iterator = data.train_iter
    try:
        for i, batch in tqdm(enumerate(iterator)):
            present_epoch = int(iterator.epoch)
            # The iterator is not bounded by itself here; stop once the
            # requested number of epochs has been consumed.
            if present_epoch == args.epoch:
                break
            if present_epoch > last_epoch:
                print('epoch:', present_epoch + 1)
            last_epoch = present_epoch

            p1, p2 = model(batch)

            optimizer.zero_grad()
            batch_loss = (criterion(p1, batch.s_idx) +
                          criterion(p2, batch.e_idx))
            # `loss` accumulates over one print_freq window and is reset
            # after each report.
            loss += batch_loss.item()
            batch_loss.backward()
            optimizer.step()

            for name, param in model.named_parameters():
                if param.requires_grad:
                    ema.update(name, param.data)

            if (i + 1) % args.print_freq == 0:
                dev_loss, dev_exact, dev_f1 = test(model, ema, args, data)
                c = (i + 1) // args.print_freq

                writer.add_scalar('loss/train', loss, c)
                writer.add_scalar('loss/dev', dev_loss, c)
                writer.add_scalar('exact_match/dev', dev_exact, c)
                writer.add_scalar('f1/dev', dev_f1, c)
                print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f}'
                      f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}')

                if dev_f1 > max_dev_f1:
                    max_dev_f1 = dev_f1
                    max_dev_exact = dev_exact
                    best_model = copy.deepcopy(model)

                loss = 0
                # Presumably test() switches the model to eval mode; put it
                # back into training mode before the next batch.
                model.train()
    finally:
        # Flush and release the TensorBoard writer even if training fails.
        writer.close()

    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')

    if best_model is None:
        # No evaluation was ever run (fewer than print_freq iterations, or
        # args.epoch reached immediately): fall back to the final weights.
        best_model = copy.deepcopy(model)
    return best_model
示例#8
0
class SNTGRunLoop(object):
    def __init__(self,
                 net,
                 dataloader=None,
                 params=None,
                 update_fn=None,
                 eval_loader=None,
                 test_loader=None,
                 has_cuda=True):
        if has_cuda:
            device = torch.device("cuda:0")
        else:
            device = torch.device('cpu')
        self.net = net.to(device)
        self.loader = dataloader
        self.eval_loader = eval_loader
        self.test_loader = test_loader
        self.params = params
        self.device = device
        # self.net.to(device)
        if params is not None:
            n_data, num_classes = params['n_data'], params['num_classes']
            n_eval_data, batch_size = params['n_eval_data'], params[
                'batch_size']
            self.ensemble_pred = torch.zeros((n_data, num_classes),
                                             device=device)
            self.target_pred = torch.zeros((n_data, num_classes),
                                           device=device)
            t_one = torch.ones(())
            self.epoch_pred = t_one.new_empty((n_data, num_classes),
                                              dtype=torch.float32,
                                              device=device)
            self.epoch_mask = t_one.new_empty((n_data),
                                              dtype=torch.float32,
                                              device=device)
            self.train_epoch_loss = \
                t_one.new_empty((n_data // batch_size, 4),
                                dtype=torch.float32, device=device)
            self.train_epoch_acc = \
                t_one.new_empty((n_data // batch_size), dtype=torch.float32,
                                device=device)
            self.eval_epoch_loss = \
                t_one.new_empty((n_eval_data // batch_size, 2),
                                dtype=torch.float32, device=device)
            self.eval_epoch_acc = \
                t_one.new_empty((n_eval_data // batch_size, 2),
                                dtype=torch.float32, device=device)
            self.optimizer = opt.Adam(self.net.parameters())
            self.update_fn = update_fn
            self.ema = EMA(params['polyak_decay'], self.net, has_cuda)
            self.unsup_weight = 0.0
        # self.loss_fn = nn.CrossEntropyLoss()

    def train(self):
        # labeled_loss = nn.CrossEntropyLoss()
        train_losses, train_accs = [], []
        eval_losses, eval_accs = [], []
        ema_eval_losses, ema_eval_accs = [], []
        for epoch in range(self.params['num_epochs']):
            # training phase
            self.net.train()
            train_time = -time.time()
            self.epoch_pred.zero_()
            self.epoch_mask.zero_()
            # self.epoch_loss.zero_()
            self.unsup_weight = self.update_fn(self.optimizer, epoch)

            for i, data_batched in enumerate(self.loader, 0):
                images, is_lens, mask, indices = \
                    data_batched['image'], data_batched['is_lens'], \
                    data_batched['mask'], data_batched['index']

                targets = torch.index_select(self.target_pred, 0, indices)
                # print(f"y value dimension:{is_lens.size()}")

                self.optimizer.zero_grad()

                outputs, h_x = self.net(images)
                # print(f"output dimension: {outputs.size()}")
                predicts = F.softmax(outputs, dim=1)

                # update for ensemble
                for k, j in enumerate(indices):
                    self.epoch_pred[j] = predicts[k]
                    self.epoch_mask[j] = 1.0

                # labeled loss
                labeled_mask = mask.eq(0)
                # loss = self.loss_fn(
                # outputs[labeled_mask], is_lens[labeled_mask])
                # labeled loss with binary entropy with logits, use one_hot
                one_hot = torch.zeros(
                    len(is_lens[labeled_mask]), is_lens[labeled_mask].max()+1,
                    device=self.device) \
                    .scatter_(1, is_lens[labeled_mask].unsqueeze(1), 1.)
                loss = F.binary_cross_entropy_with_logits(
                    outputs[labeled_mask], one_hot)
                # one_hot = torch.zeros(
                #     len(is_lens), is_lens.max() + 1, device=self.device) \
                #     .scatter_(1, is_lens.unsqueeze(1), 1.)
                # loss = F.binary_cross_entropy_with_logits(outputs, one_hot)
                # print(loss.item())

                self.train_epoch_acc[i] = \
                    torch.mean(torch.argmax(
                        outputs[labeled_mask], 1).eq(is_lens[labeled_mask])
                    .float()).item()
                # train_acc = torch.mean(
                #     torch.argmax(outputs, 1).eq(is_lens).float())
                self.train_epoch_loss[i, 0] = loss.item()

                # unlabeled loss
                unlabeled_loss = torch.mean((predicts - targets)**2)
                self.train_epoch_loss[i, 1] = unlabeled_loss.item()
                loss += unlabeled_loss * self.unsup_weight

                # SNTG loss
                if self.params['embed']:
                    half = int(h_x.size()[0] // 2)
                    eucd2 = torch.mean((h_x[:half] - h_x[half:])**2, dim=1)
                    eucd = torch.sqrt(eucd2)
                    target_hard = torch.argmax(targets, dim=1).int()
                    merged_tar = torch.where(mask == 0, target_hard,
                                             is_lens.int())
                    neighbor_bool = torch.eq(merged_tar[:half],
                                             merged_tar[half:])
                    eucd_y = torch.where(eucd < 1.0, (1.0 - eucd)**2,
                                         torch.zeros_like(eucd))
                    embed_losses = torch.where(neighbor_bool, eucd2, eucd_y)
                    embed_loss = torch.mean(embed_losses)
                    self.train_epoch_loss[i, 2] = embed_loss.item()
                    loss += embed_loss * \
                        self.unsup_weight * self.params['embed_coeff']
                self.train_epoch_loss[i, 3] = loss.item()
                loss.backward()
                self.optimizer.step()
                self.ema.update()

            self.ensemble_pred = \
                self.params['pred_decay'] * self.ensemble_pred + \
                (1 - self.params['pred_decay']) * self.epoch_pred
            self.targets_pred = self.ensemble_pred / \
                (1.0 - self.params['pred_decay'] ** (epoch + 1))
            loss_mean = torch.mean(self.train_epoch_loss, 0)
            train_losses.append(loss_mean[3].item())
            acc_mean = torch.mean(self.train_epoch_acc)
            train_accs.append(acc_mean.item())
            print(f"epoch {epoch}, time cosumed: {time.time() + train_time}, "
                  f"labeled loss: {loss_mean[0].item()}, "
                  f"unlabeled loss: {loss_mean[1].item()}, "
                  f"SNTG loss: {loss_mean[2].item()}, "
                  f"total loss: {loss_mean[3].item()}")
            # print(f"epoch {epoch}, time consumed: {time.time() + train_time}, "
            #       f"labeled loss: {loss_mean[0].item()}")

            # eval phase
            if self.eval_loader is not None:
                # none ema evaluation
                self.net.eval()
                for i, data_batched in enumerate(self.eval_loader, 0):
                    images, is_lens = data_batched['image'], \
                        data_batched['is_lens']
                    # currently h_x in evalization is not used
                    eval_logits, _ = self.ema(images)
                    self.eval_epoch_acc[i, 0] = torch.mean(
                        torch.argmax(eval_logits,
                                     1).eq(is_lens).float()).item()
                    # print(f"ema evaluation accuracy: {ema_eval_acc.item()}")
                    eval_lens = torch.zeros(
                        len(is_lens), is_lens.max()+1,
                        device=self.device) \
                        .scatter_(1, is_lens.unsqueeze(1), 1.)
                    # eval_loss = self.loss_fn(eval_logits, is_lens)
                    self.eval_epoch_loss[i, 0] = \
                        F.binary_cross_entropy_with_logits(
                        eval_logits, eval_lens).item()
                    # break
                    eval_logits, _ = self.net(images)
                    self.eval_epoch_acc[i, 1] = torch.mean(
                        torch.argmax(eval_logits,
                                     1).eq(is_lens).float()).item()
                    # print(f"evaluation accuracy: {eval_acc.item()}")
                    self.eval_epoch_loss[i, 1] = \
                        F.binary_cross_entropy_with_logits(
                        eval_logits, eval_lens).item()
                loss_mean = torch.mean(self.eval_epoch_loss, 0)
                acc_mean = torch.mean(self.eval_epoch_acc, 0)
                ema_eval_accs.append(acc_mean[0].item())
                ema_eval_losses.append(loss_mean[0].item())
                eval_accs.append(acc_mean[1].item())
                eval_losses.append(loss_mean[1].item())
                print(f"ema accuracy: {acc_mean[0].item()}"
                      f"normal accuracy: {acc_mean[1].item()}")

        return train_losses, train_accs, eval_losses, eval_accs, \
            ema_eval_losses, ema_eval_accs

    def test(self):
        self.net.eval()
        with torch.no_grad():
            for i, data_batched in enumerate(self.test_loader, 0):
                images, is_lens = data_batched['image'], data_batched[
                    'is_lens']
                start = time.time()
                test_logits, _ = self.net(images)
                end = time.time()
                result = torch.argmax(F.softmax(test_logits, dim=1), dim=1)
                accuracy = torch.mean(result.eq(is_lens).float()).item()
                # return roc_curve(is_lens, test_logits)
                return result.tolist(), is_lens.tolist(), end - start, \
                    accuracy

    def test_origin(self):
        self.net.eval()
        with torch.no_grad():
            for i, data_batched in enumerate(self.test_loader, 0):
                images, is_lens = data_batched['image'], data_batched[
                    'is_lens']
                test_logits, _ = self.net(images)
                return test_logits, is_lens