Example #1
 def __init__(self, args: argparse.Namespace):
     """Initialize a models, tokenizer and config."""
     super().__init__()
     self.args = args
     if isinstance(args, argparse.Namespace):
         self.save_hyperparameters(args)
     self.bert_dir = args.bert_path
     self.bert_config = BertConfig.from_pretrained(args.bert_path)
     self.model = GlyceBertForMaskedLM(self.bert_config)
     self.loss_fn = CrossEntropyLoss(reduction="none")
     self.acc = MaskedAccuracy(num_classes=self.bert_config.vocab_size)
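For reference, a minimal sketch of how a CrossEntropyLoss with reduction="none" is typically combined with a mask for masked-LM training; the tensor names and shapes below are illustrative assumptions, not part of the example above.

import torch
from torch.nn import CrossEntropyLoss

loss_fn = CrossEntropyLoss(reduction="none")

def masked_lm_loss(logits, labels, mask):
    # logits: (batch, seq_len, vocab); labels/mask: (batch, seq_len);
    # mask is 1 at positions whose token was masked for prediction.
    per_token = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
    mask = mask.view(-1).float()
    # average only over the masked positions
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)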
Example #2
    def train(self, train_dataset, test_dataset, model):
        weights = make_weights_for_balanced_classes(train_dataset.targets)
        sampler = WeightedRandomSampler(weights, len(weights))
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=8,
        )
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            num_workers=8,
        )

        criterion_triplet = OnlineTripleLoss(
            margin=self.triplet_margin,
            sampling_strategy=self.triplet_sampling_strategy,
        )
        criterion_classifier = CrossEntropyLoss()

        optimizer_triplet = Adam(
            params=model.feature_extractor.parameters(),
            lr=self.learning_rate_triplet,
        )
        optimizer_classifier = Adam(
            params=model.classifier.parameters(),
            lr=self.learning_rate_classify,
        )
        print("Training with Triplet loss")
        for i in range(self.epochs_triplet):
            self._train_epoch_triplet(
                model,
                train_dataloader,
                optimizer_triplet,
                criterion_triplet,
                i + 1,
            )
            save_embedding_umap(model, train_dataloader, test_dataloader,
                                self.exp_folder, i + 1)
        print("Training the classifier")
        for i in range(self.epochs_classifier):
            self._train_epoch_classify(
                model,
                train_dataloader,
                optimizer_classifier,
                criterion_classifier,
                i + 1,
            )
            self._test_epoch_(model, test_dataloader, criterion_classifier,
                              i + 1)
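make_weights_for_balanced_classes is not shown in this snippet; a common implementation (an assumption here, not the project's code) weights each sample by the inverse frequency of its class so that the WeightedRandomSampler draws classes roughly uniformly:

from collections import Counter

def make_weights_for_balanced_classes(targets):
    # targets: list of integer class labels, one per sample;
    # each sample gets the inverse frequency of its class as weight
    counts = Counter(targets)
    return [1.0 / counts[t] for t in targets]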
Example #3
 def __init__(self, args: argparse.Namespace):
     """Initialize a model, tokenizer and config."""
     super().__init__()
     self.args = args
     if isinstance(args, argparse.Namespace):
         self.save_hyperparameters(args)
     self.bert_dir = args.bert_path
     self.model = ExplainableModel(self.bert_dir)
     self.tokenizer = RobertaTokenizer.from_pretrained(self.bert_dir)
     self.loss_fn = CrossEntropyLoss()
     self.train_acc = pl.metrics.Accuracy()
     self.valid_acc = pl.metrics.Accuracy()
     self.output = []
     self.check_data = []
Example #4
 def __init__(
     self,
     args: argparse.Namespace
 ):
     """Initialize a model, tokenizer and config."""
     super().__init__()
     self.args = args
     if isinstance(args, argparse.Namespace):
         self.save_hyperparameters(args)
     self.bert_dir = args.bert_path
     self.model = RobertaForSequenceClassification.from_pretrained(self.bert_dir)
     self.tokenizer = RobertaTokenizer.from_pretrained(self.bert_dir)
     self.loss_fn = CrossEntropyLoss()
     self.train_acc = pl.metrics.Accuracy()
     self.valid_acc = pl.metrics.Accuracy()
Example #5
    def __init__(self, args: argparse.Namespace):
        """Initialize a models, tokenizer and config."""
        super().__init__()
        self.args = args
        if isinstance(args, argparse.Namespace):
            self.save_hyperparameters(args)
        self.bert_dir = args.bert_path
        self.bert_config = BertConfig.from_pretrained(
            self.bert_dir, output_hidden_states=False)
        self.model = GlyceBertForSequenceClassification.from_pretrained(
            self.bert_dir)

        self.loss_fn = CrossEntropyLoss()
        self.acc = pl.metrics.Accuracy(num_classes=self.bert_config.num_labels)
        gpus_string = self.args.gpus if not self.args.gpus.endswith(
            ',') else self.args.gpus[:-1]
        self.num_gpus = len(gpus_string.split(","))
Example #6
 def __init__(self, args: argparse.Namespace):
     """Initialize a models, tokenizer and config."""
     super().__init__()
     self.args = args
     if isinstance(args, argparse.Namespace):
         self.save_hyperparameters(args)
     self.bert_dir = args.bert_path
     self.bert_config = BertConfig.from_pretrained(args.bert_path)
     if self.args.mode == 'glyce':
         self.model = GlyceBertForMaskedLM.from_pretrained(self.bert_dir)
     else:
         self.model = BertForMaskedLM.from_pretrained(self.bert_dir)
     self.loss_fn = CrossEntropyLoss(reduction="none")
     self.acc = MaskedAccuracy(num_classes=self.bert_config.vocab_size)
     gpus_string = self.args.gpus if not self.args.gpus.endswith(
         ',') else self.args.gpus[:-1]
     self.num_gpus = len(gpus_string.split(","))
Example #7
    def compute_loss(self, logits, labels):
        if self.loss_type == "ce":
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, 2), labels.view(-1))
        elif self.loss_type == "focal":
            loss_fct = FocalLoss(gamma=self.args.focal_gamma, reduction="mean")
            loss = loss_fct(logits.view(-1, 2), labels.view(-1))
        elif self.loss_type == "dice":
            loss_fct = DiceLoss(with_logits=True,
                                smooth=self.args.dice_smooth,
                                ohem_ratio=self.args.dice_ohem,
                                alpha=self.args.dice_alpha,
                                square_denominator=self.args.dice_square,
                                reduction="mean")
            loss = loss_fct(logits.view(-1, self.num_classes), labels)
        else:
            raise ValueError(f"Unsupported loss type: {self.loss_type}")

        return loss
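FocalLoss and DiceLoss above come from the surrounding project; only CrossEntropyLoss is a stock PyTorch module. As a rough reference, a focal loss can be sketched on top of an unreduced cross entropy (an illustrative sketch, not the library's implementation):

import torch
from torch.nn import functional as F

def focal_loss(logits, labels, gamma=2.0):
    # down-weight easy examples by (1 - p_t) ** gamma before averaging
    ce = F.cross_entropy(logits, labels, reduction="none")
    p_t = torch.exp(-ce)  # model probability of the true class
    return ((1.0 - p_t) ** gamma * ce).mean()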
Example #8
 def __init__(self):
     super(ClassLoss, self).__init__()
     self.cross_entropy = CrossEntropyLoss(ignore_index=-100)
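ignore_index=-100 matches PyTorch's conventional padding label: targets equal to -100 contribute neither to the loss nor to the gradient. A tiny illustration with made-up tensors:

import torch
from torch.nn import CrossEntropyLoss

loss_fn = CrossEntropyLoss(ignore_index=-100)
logits = torch.randn(4, 3)              # 4 tokens, 3 classes
labels = torch.tensor([0, 2, -100, 1])  # the third token is padding and is skipped
print(loss_fn(logits, labels))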
Example #9
def main():
    args = get_args()

    log.info(f'Parsed arguments: \n{pformat(args.__dict__)}')
    assert args.cond_type.lower() in ['none', 'platanios', 'oestling']

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    log.info('Using device {}.'.format(device))

    use_apex = False
    if torch.cuda.is_available() and args.fp16:
        log.info('Loading Nvidia Apex and using AMP')
        from apex import amp, optimizers
        use_apex = True
    else:
        log.info('Using FP32')
        amp = None

    log.info(f'Using time stamp {timestamp} to save models and logs.')

    if not args.no_seed:
        log.info(f'Setting random seed to {args.seed} for reproducibility.')
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    data = Corpus(args.datadir)

    data_splits = [
        {
            'split': 'train',
            'languages': args.dev_langs + args.target_langs,
            'invert_include': True,
        },
        {
            'split': 'valid',
            'languages': args.dev_langs,
        },
        {
            'split': 'test',
            'languages': args.target_langs,
        },
    ]

    if args.refine:
        data_splits.append({
            'split': 'train_100',
            'languages': args.target_langs,
            'ignore_missing': True,
        })

    data_splits = data.make_datasets(data_splits, force_rebuild=args.rebuild)
    train_set, val_set, test_set = data_splits['train'], data_splits[
        'valid'], data_splits['test']
    dictionary = data_splits['dictionary']

    train_language_distr = get_sampling_probabilities(train_set, 1.0)
    train_set = Dataset(train_set,
                        batchsize=args.batchsize,
                        bptt=args.bptt,
                        reset_on_iter=True,
                        language_probabilities=train_language_distr)
    val_set = Dataset(val_set,
                      make_config=True,
                      batchsize=args.valid_batchsize,
                      bptt=args.bptt,
                      eval=True)
    test_set = Dataset(test_set,
                       make_config=True,
                       batchsize=args.test_batchsize,
                       bptt=args.bptt,
                       eval=True)

    train_loader = DataLoader(train_set, num_workers=args.workers)
    val_loader = DataLoader(val_set, num_workers=args.workers)
    test_loader = DataLoader(test_set, num_workers=args.workers)

    if args.refine:
        refine_set = dict()
        for lang, lang_d in data_splits['train_100'].items():
            refine_set[lang] = Dataset({lang: lang_d},
                                       batchsize=args.valid_batchsize,
                                       bptt=args.bptt,
                                       make_config=True)

    n_token = len(dictionary.idx2tkn)

    # Load and preprocess matrix of typological features
    # TODO: implement this, the OEST
    # prior_matrix = load_prior(args.prior, corpus.dictionary.lang2idx)
    # n_components = min(50, *prior_matrix.shape)
    # pca = PCA(n_components=n_components, whiten=True)
    # prior_matrix = pca.fit_transform(prior_matrix)
    prior = None

    model = RNN(args.cond_type,
                prior,
                n_token,
                n_input=args.emsize,
                n_hidden=args.nhidden,
                n_layers=args.nlayers,
                dropout=args.dropouto,
                dropoute=args.dropoute,
                dropouth=args.dropouth,
                dropouti=args.dropouti,
                wdrop=args.wdrop,
                wdrop_layers=[0, 1, 2],
                tie_weights=True).to(device)

    if args.opt_level != 'O2':
        loss_function = SplitCrossEntropyLoss(args.emsize,
                                              splits=[]).to(device)
    else:
        loss_function = CrossEntropyLoss().to(
            device)  # fine to use with a vocabulary this small

    if use_apex:
        optimizer = optimizers.FusedAdam(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.wdecay)
    else:
        params = list(filter(lambda p: p.requires_grad,
                             model.parameters())) + list(
                                 loss_function.parameters())
        optimizer = Adam(params, lr=args.lr, weight_decay=args.wdecay)

    if use_apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)

    parameters = {
        'model': model,
        'optimizer': optimizer,
        'loss_function': loss_function,
        'use_apex': use_apex,
        'amp': amp if use_apex else None,
        'clip': args.clip,
        'alpha': args.alpha,
        'beta': args.beta,
        'bptt': args.bptt,
        'device': device,
        'prior': args.prior,
    }

    # Add backward hook for gradient clipping
    if args.clip:
        if use_apex:
            for p in amp.master_params(optimizer):
                p.register_hook(
                    lambda grad: torch.clamp(grad, -args.clip, args.clip))
        else:
            for p in model.parameters():
                p.register_hook(
                    lambda grad: torch.clamp(grad, -args.clip, args.clip))

    if args.prior == 'vi':
        prior = VIPrior(model, device=device)
        parameters['prior'] = prior

        def sample_weights(module: torch.nn.Module, input: torch.Tensor):
            prior.sample_weights(module)

        sample_weights_hook = model.register_forward_pre_hook(sample_weights)

    # Load model checkpoint if available
    start_epoch = 1
    if args.resume:
        if args.checkpoint is None:
            log.error(
                'No checkpoint passed. Specify it using the --checkpoint flag')
            checkpoint = None
        else:
            log.info('Loading the checkpoint at {}'.format(args.checkpoint))
            checkpoint = load_model(args.checkpoint, **parameters)

            start_epoch = checkpoint['epoch']

        if args.wdrop:
            for rnn in model.rnns:
                if isinstance(rnn, WeightDrop):
                    rnn.dropout = args.wdrop
                elif rnn.zoneout > 0:
                    rnn.zoneout = args.wdrop

    saved_models = list()

    result_str = '| Language {} | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'

    def test():
        log.info('=' * 89)
        log.info('Running test set (zero-shot results)...')
        test_loss, avg_loss = evaluate(test_loader, **parameters)
        log.info('Test set finished | test loss {} | test bpc {}'.format(
            test_loss, test_loss / math.log(2)))

        for lang, avg_l_loss in avg_loss.items():
            langstr = dictionary.idx2lang[lang]
            log.info(
                result_str.format(langstr, avg_l_loss, math.exp(avg_l_loss),
                                  avg_l_loss / math.log(2)))

        log.info('=' * 89)

    if args.train:
        f = 1.
        stored_loss = 1e32
        epochs_no_improve = 0

        val_losses = list()

        # calculate specific language lr
        data_spec_count = sum([len(ds) for l, ds in train_set.data.items()])
        data_spec_avg = data_spec_count / len(train_set.data.items())
        data_spec_lrweights = dict([(l, data_spec_avg / len(ds))
                                    for l, ds in train_set.data.items()])

        # estimate total number of steps
        total_steps = sum(
            [len(ds) // args.bptt
             for l, ds in train_set.data.items()]) * args.no_epochs
        steps = 0

        try:
            pbar = tqdm.trange(start_epoch,
                               args.no_epochs + 1,
                               position=1,
                               dynamic_ncols=True)
            for epoch in pbar:

                steps = train(train_loader,
                              lr_weights=data_spec_lrweights,
                              **parameters,
                              total_steps=total_steps,
                              steps=steps,
                              scaling=args.scaling,
                              n_samples=args.n_samples,
                              tb_writer=tb_writer)

                val_loss, _ = evaluate(val_loader, **parameters)
                pbar.set_description('Epoch {} | Val loss {}'.format(
                    epoch, val_loss))

                # Save model
                if args.prior == 'vi':
                    sample_weights_hook.remove()

                filename = path.join(
                    args.checkpoint_dir, '{}_epoch{}{}_{}.pth'.format(
                        timestamp, epoch, '_with_apex' if use_apex else '',
                        args.prior))
                torch.save(make_checkpoint(epoch + 1, **parameters), filename)
                saved_models.append(filename)

                if args.prior == 'vi':
                    sample_weights_hook = model.register_forward_pre_hook(
                        sample_weights)

                # Early stopping
                if val_loss < stored_loss:
                    epochs_no_improve = 0
                    stored_loss = val_loss
                else:
                    epochs_no_improve += 1

                if epochs_no_improve == args.patience:
                    log.info('Early stopping at epoch {}'.format(epoch))
                    break

                val_losses.append(val_loss)

                # Reduce lr every 1/3 total epochs
                if epoch - 1 > f / 3 * args.no_epochs:
                    log.info('Epoch {}/{}. Dividing LR by 10'.format(
                        epoch, args.no_epochs))
                    for g in optimizer.param_groups:
                        g['lr'] = g['lr'] / 10

                    f += 1.
            test()
        except KeyboardInterrupt:
            log.info('Registered KeyboardInterrupt. Stopping training.')
            log.info('Saving last model to disk')

            if args.prior == 'vi':
                sample_weights_hook.remove()

            torch.save(
                make_checkpoint(epoch, **parameters),
                path.join(
                    args.checkpoint_dir, '{}_epoch{}{}_{}.pth'.format(
                        timestamp, epoch, '_with_apex' if use_apex else '',
                        args.prior)))
            return
    elif args.test:
        test()

    # Only test on existing languages if there are no held out languages
    if not args.target_langs:
        exit(0)

    importance = 1e-5

    # If use UNIV, calculate informed prior, else use boring prior
    if args.prior == 'laplace':
        if not isinstance(
                prior,
                LaplacePrior):  # only calculate matrix if it is not supplied.
            log.info('Creating laplace approximation dataset')
            laplace_set = Dataset(data_splits['train'],
                                  batchsize=args.batchsize,
                                  bptt=100,
                                  reset_on_iter=True)
            laplace_loader = DataLoader(laplace_set, num_workers=args.workers)
            log.info('Creating Laplacian prior')
            prior = LaplacePrior(model,
                                 loss_function,
                                 laplace_loader,
                                 use_apex=use_apex,
                                 amp=amp,
                                 device=device)
            parameters['prior'] = prior

            torch.save(
                make_checkpoint('fisher_matrix', **parameters),
                path.join(
                    args.checkpoint_dir, '{}_fishers_matrix{}_{}.pth'.format(
                        timestamp, '_with_apex' if use_apex else '',
                        args.prior)))
        importance = 1e5

    elif args.prior == 'ninf':
        log.info('Creating non-informative Gaussian prior')
        parameters['prior'] = GaussianPrior()
    elif args.prior == 'vi':
        importance = 1e-5
    elif args.prior == 'hmc':
        raise NotImplementedError
    else:
        raise ValueError(
            f'Passed prior {args.prior} is not an implemented inference technique.'
        )

    best_model = saved_models[-1] if saved_models else args.checkpoint

    # Remove sampling hook from model
    if args.prior == 'vi':
        sample_weights_hook.remove()

    # Refine on 100 samples on each target
    if args.refine:
        # reset learning rate
        optimizer.param_groups[0]['lr'] = args.lr
        loss = 0

        results = dict()

        # Create individual tests sets
        test_sets = dict()
        for lang, lang_d in data_splits['test'].items():
            test_sets[lang] = DataLoader(Dataset({lang: lang_d},
                                                 make_config=True,
                                                 batchsize=args.test_batchsize,
                                                 bptt=args.bptt,
                                                 eval=True),
                                         num_workers=args.workers)

        for lang, lang_data in tqdm.tqdm(refine_set.items()):
            final_loss = False
            refine_dataloader = DataLoader(lang_data, num_workers=args.workers)
            load_model(best_model, **parameters)

            log.info(f'Refining for language {dictionary.idx2lang[lang]}')
            for epoch in range(1, args.refine_epochs + 1):
                refine(refine_dataloader, **parameters, importance=importance)
                if epoch % 5 == 0:
                    final_loss = True
                    loss, avg_loss = evaluate(test_sets[lang],
                                              model,
                                              loss_function,
                                              only_l=lang,
                                              report_all=True,
                                              device=device)

                    for lang, avg_l_loss in avg_loss.items():
                        langstr = dictionary.idx2lang[lang]
                        log.debug(
                            result_str.format(langstr, avg_l_loss,
                                              math.exp(avg_l_loss),
                                              avg_l_loss / math.log(2)))

            if not final_loss:
                loss, avg_loss = evaluate(test_sets[lang],
                                          model,
                                          loss_function,
                                          only_l=lang,
                                          report_all=True,
                                          device=device)

            for lang, avg_l_loss in avg_loss.items():
                langstr = dictionary.idx2lang[lang]
                log.info(
                    result_str.format(langstr, avg_l_loss,
                                      math.exp(avg_l_loss),
                                      avg_l_loss / math.log(2)))
                results[lang] = avg_l_loss

        log.info('=' * 89)
        log.info('FINAL FEW SHOT RESULTS: ')
        log.info('=' * 89)
        for lang, avg_l_loss in results.items():
            langstr = dictionary.idx2lang[lang]
            log.info(
                result_str.format(langstr, avg_l_loss, math.exp(avg_l_loss),
                                  avg_l_loss / math.log(2)))
        log.info('=' * 89)
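The register_hook calls in this example clamp every gradient element to [-clip, clip] as it is computed. A common alternative (shown here as a sketch with a placeholder model and clip value, not part of the script above) is global norm clipping after backward():

import torch

model = torch.nn.Linear(4, 2)   # placeholder model for illustration
clip_value = 0.25

# element-wise clamping via backward hooks, as in the example above
for p in model.parameters():
    p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))

# alternative: clip the global gradient norm between backward() and step()
loss = model(torch.randn(3, 4)).sum()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_value)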
Example #10
def train(model, SRC, TRG, MODEL_PATH, FORCE_MAX_LEN=50):
    model.train()
    optimizer = Adam(model.parameters(), lr=hp.LR, betas=(0.9, 0.98), eps=1e-9)
    criterion = CrossEntropyLoss(ignore_index=TRG.vocab.stoi["<pad>"])

    for epoch in tqdm(range(hp.EPOCHS)):

        for step, batch in enumerate(train_iter):
            global_step = epoch * len(train_iter) + step

            model.train()
            optimizer.zero_grad()
            optimizer = custom_lr_optimizer(optimizer, global_step)

            src = batch.src.T
            trg = batch.trg.T

            trg_input = trg[:, :-1]

            preds, _, _, _ = model(src, trg_input)
            ys = trg[:, 1:]

            loss = criterion(preds.permute(0, 2, 1), ys)
            loss.mean().backward()
            optimizer.step()

            if global_step % 50 == 0:
                print("#" * 90)

                rand_index = random.randrange(hp.BATCH_SIZE)

                model.eval()

                v = next(iter(val_iter))
                v_src, v_trg = v.src.T, v.trg.T

                v_trg_inp = v_trg[:, :-1].detach()
                v_trg_real = v_trg[:, 1:].detach()

                v_predictions, _, _, _ = model(v_src, v_trg_inp)
                max_args = v_predictions[rand_index].argmax(-1)
                print("For random element in VALIDATION batch (real/pred)...")
                print([
                    TRG.vocab.itos[word_idx]
                    for word_idx in v_trg_real[rand_index, :]
                ])
                print([TRG.vocab.itos[word_idx] for word_idx in max_args])

                print("Length til first <PAD> (real -> pred)...")
                try:
                    pred_PAD_idx = max_args.tolist().index(3)
                except ValueError:  # no <pad> token predicted
                    pred_PAD_idx = None

                print(v_trg_real[rand_index, :].tolist().index(3), "  --->  ",
                      pred_PAD_idx)

                val_loss = criterion(v_predictions.permute(0, 2, 1),
                                     v_trg_real)
                print("TRAINING LOSS:", loss.mean().item())
                print("VALIDATION LOSS:", val_loss.mean().item())

                print("#" * 90)

                writer.add_scalar("Training Loss",
                                  loss.mean().detach().item(), global_step)
                writer.add_scalar("Validation Loss",
                                  val_loss.mean().detach().item(), global_step)
        torch.save(model, MODEL_PATH)
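The permute(0, 2, 1) above is needed because CrossEntropyLoss expects the class dimension second: logits of shape (batch, vocab, seq_len) against targets of shape (batch, seq_len). A minimal shape check with illustrative sizes:

import torch
from torch.nn import CrossEntropyLoss

criterion = CrossEntropyLoss()
preds = torch.randn(8, 20, 1000)              # (batch, seq_len, vocab) as the model returns them
ys = torch.randint(0, 1000, (8, 20))          # (batch, seq_len)
loss = criterion(preds.permute(0, 2, 1), ys)  # move the class dim to position 1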
Example #11
 def __init__(self):
     super(Model, self).__init__(optimizer_class=optim.SGD,
                                 optimizer_kwargs=dict(lr=0.01,
                                                       momentum=0.5),
                                 criterion=CrossEntropyLoss())
Example #12
      min_freq=2,
     init_token='<sos>',
     eos_token='<eos>')
 train_iter, valid_iter, test_iter = iterator_construction(
     train=train,
     valid=valid,
     test=test,
     batch_sizes=(hp.BATCH_SIZE, hp.BATCH_SIZE, hp.BATCH_SIZE),
     device=device)
 transformer_model = Transformer(len(src_field.vocab),
                                 len(trg_field.vocab),
                                 D=hp.D_MODEL,
                                 num_heads=hp.HEADS,
                                 D_ff=hp.D_FF,
                                 dropout=hp.P_DROP)
 criterion = CrossEntropyLoss(ignore_index=trg_field.vocab.stoi['<pad>'])
 # ignore_index specifies a target value that is ignored and does not contribute to the input gradient.
 training(
     transformer_model,
     epochs=hp.EPOCHS,
     train_iterator=train_iter,
     valid_iterator=valid_iter,
     optimizer=Adam(transformer_model.parameters(),
                    lr=hp.LR,
                    betas=(0.9, 0.98),
                    eps=1e-9),  # NEED TO CHANGE
     loss_fn=criterion,
     device=device,
     log_interval=3,
     save_model=False,
     model_path=
Example #13
    # train_loader = DataLoader(dataset=abideData_train, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    # test_loader = DataLoader(dataset=abideData_test, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    train_loader = DataLoader(dataset=abideData_train,
                              batch_size=batch_size,
                              shuffle=True)
    test_loader = DataLoader(dataset=abideData_test,
                             batch_size=batch_size,
                             shuffle=True)

    # Build the LSTM model
    model = RNNModel(train_x[0].shape[1],
                     lstm_hidden_num,
                     lstm_output_num,
                     lstm_layers_num,
                     bidirectional=bidirectional).to(device)
    criterion = CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Start training
    model.train()
    total_step = len(train_loader)
    # Initialize the hidden and cell states
    (hidden, cell) = model.init_hidden_cell(batch_size)
    for epoch in range(EPOCHS):
        for i, (data_x, data_y) in enumerate(train_loader):
            if data_x.shape[0] != batch_size:
                continue

            # Move tensors to the GPU
            data_x = data_x.requires_grad_().to(device)
            data_y = data_y.to(device)
Example #14
test_df = pd.read_csv('../data/test_dataset.csv')

train_dataset = MyDataset(train_df,
                          image_folder_path='../data/images',
                          augmentations=get_training_augmentation())

test_dataset = MyDataset(test_df,
                         image_folder_path='../data/images',
                         augmentations=get_validation_augmentation())

train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=1)

optimizer = Adam(model.parameters(), lr=config.learning_rate)

loss_function = CrossEntropyLoss()

for epoch in range(config.epochs):

    print(f'EPOCH {epoch}')
    print(f'-------------')
    """ TRAIN """
    print(f'TRAINING LOOP')

    losses = []
    model.train()

    for batch in tqdm(train_dataloader):

        images = batch['image'].to(device)
        targets = batch['class_id'].type(torch.int64).to(device)