def train(self, epochs):
    other_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'ner']
    mb = master_bar(range(epochs))
    mb.write(['epoch', 'loss', 'f1', 'precision', 'recall'], table=True)
    with self.nlp.disable_pipes(*other_pipes):  # only train NER
        for itn in mb:
            random.shuffle(self.data)
            losses = {}
            batches = minibatch(self.data,
                                size=compounding(4., 32., 1.001))
            for batch in progress_bar(list(batches), parent=mb):
                texts, annotations = zip(*batch)
                # Update the weights
                self.nlp.update(texts, annotations, sgd=self.optimizer,
                                drop=0.35, losses=losses)
            loss = losses.get('ner')
            metrics = self.nlp.evaluate(self.data[:len(self.data) // 5])
            f1, precision, recall = metrics.ents_f, metrics.ents_p, metrics.ents_r
            line = [str(itn), str(round(loss, 2)), str(round(f1, 2)),
                    str(round(precision, 2)), str(round(recall, 2))]
            mb.write(line, table=True)
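All of the examples on this page follow the same pattern: a master_bar over epochs, a progress_bar over batches attached to it via parent=, and mb.write(..., table=True) to append rows to a results table. A minimal, self-contained sketch of that pattern (the dummy work loop is purely illustrative):

import time
from fastprogress.fastprogress import master_bar, progress_bar

mb = master_bar(range(3))                      # outer bar: one tick per epoch
mb.write(['epoch', 'loss'], table=True)        # header row of the results table
for epoch in mb:
    for _ in progress_bar(range(100), parent=mb):  # inner bar, nested under mb
        time.sleep(0.01)                       # stand-in for one training step
    mb.write([str(epoch), f'{1.0 / (epoch + 1):.4f}'], table=True)  # one row per epoch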
Example #2
    def fit(self, epochs: int, lr: float):
        "Main training loop"
        mb = master_bar(range(epochs))
        self.optimizer = self.prepare_optimizer(epochs, lr)
        # self.optimizer = self.opt_fn(lr=lr)
        # Loop over epochs
        mb.write(self.log_keys, table=True)
        exception = False
        st_time = time.time()
        try:
            for epoch in mb:
                train_loss, train_acc = self.train_epoch(mb)
                valid_loss, valid_acc = self.validate(mb)
                # print(f'{val_loss: 0.4f}', f'{eval_metric: 0.4f}')
                to_write = [epoch, train_loss,
                            train_acc, valid_loss, valid_acc]
                mb.write([str(stat) if isinstance(stat, int)
                          else f'{stat:.4f}' for stat in to_write], table=True)
                self.update_log_file(
                    good_format_stats(self.log_keys, to_write))
                if self.best_met < valid_acc:
                    self.best_met = valid_acc
                    self.save_model_dict()
        except Exception as e:
            exception = e
            raise e
        finally:
            end_time = time.time()
            self.update_log_file(
                f'epochs done {epoch}. Exited due to exception {exception}. Total time taken {end_time - st_time: 0.4f}')

            if self.best_met < valid_acc:
                self.save_model_dict()
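good_format_stats and update_log_file are not shown in this snippet. A plausible stand-in for good_format_stats, assuming it simply joins the logged keys and values into one line (an assumption, not the original implementation):

# Hypothetical helper: format one epoch's stats as a single log line.
# Ints are printed as-is, floats with four decimals (assumed behaviour).
def good_format_stats(keys, stats):
    formatted = [str(s) if isinstance(s, int) else f'{s:.4f}' for s in stats]
    return ', '.join(f'{k}: {v}' for k, v in zip(keys, formatted))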
Example #3
    def fit(self, train_dl, valid_dl, epochs, lr, metrics=None, optimizer=None, scheduler=None):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(device)
        optimizer = optimizer or Adam(self.model.parameters(), lr)
        if scheduler is not False:
            scheduler = scheduler or OneCycleLR(optimizer, lr, epochs * len(train_dl))
        else:
            scheduler = None
        self.train_stats = TrainTracker(metrics, validate=(valid_dl is not None))
        bar = master_bar(range(epochs))
        bar.write(self.train_stats.metrics_names, table=True)

        for epoch in bar:
            self.model.train()
            for batch in progress_bar(train_dl, parent=bar):
                batch = batch_to_device(batch, device)
                loss = self._train_batch(batch, optimizer, scheduler)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if scheduler:
                    scheduler.step()
                self.train_stats.update_train_loss(loss)

            valid_outputs = []
            if valid_dl:
                self.model.eval()
                for batch in progress_bar(valid_dl, parent=bar):
                    batch = batch_to_device(batch, device)
                    output = self._valid_batch(batch)
                    valid_outputs.append(output)

            self.train_stats.log_epoch_results(valid_outputs)
            bar.write(self.train_stats.get_metrics_values(), table=True)
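The batch_to_device helper used above is not defined in this snippet. A hedged sketch of what such a helper typically looks like, moving every tensor in a possibly nested batch to the target device (an illustration, not the original code):

import torch

def batch_to_device(batch, device):
    # Recursively move tensors (and containers of tensors) to the device.
    if torch.is_tensor(batch):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return type(batch)(batch_to_device(b, device) for b in batch)
    if isinstance(batch, dict):
        return {k: batch_to_device(v, device) for k, v in batch.items()}
    return batch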
Example #4
    def _download_images(cat_list, path_images, max_images, remove_crowded):
        cat_ids = CocoData.coco.getCatIds(catNms=cat_list)
        idx2cat = {e['id']:e['name'] for e in CocoData.coco.loadCats(CocoData.coco.getCatIds())}
        img_id2fn = {}
        print(f"Found {len(cat_ids)} valid categories.")
        print([idx2cat[e] for e in cat_ids])
        print("Starting download.")
        mb = master_bar(range(len(cat_ids)))
        for i in mb:
            c_id = cat_ids[i]
            print(f"Downloading images of category {idx2cat[c_id]}")
            img_ids = CocoData.coco.getImgIds(catIds=c_id)
            # small function to filter images with crowded objects
            def _f(iid):
                annos = CocoData.coco.loadAnns(CocoData.coco.getAnnIds(imgIds=iid))
                annos = [a for a in annos if idx2cat[a["category_id"]] in cat_list]
                is_crowd = [a["iscrowd"] for a in annos]
                return 1 in is_crowd
            if remove_crowded:
                img_ids = [i for i in img_ids if not _f(i)]
            if max_images is not None:
                img_ids = img_ids[:max_images]
            for img_id in img_ids:
                img_id2fn[img_id] = path_images / (str(img_id).zfill(12) + ".jpg")
            for j in progress_bar(range(len(img_ids)), parent=mb):
                with contextlib.redirect_stdout(io.StringIO()):
                    CocoData.coco.download(path_images, [img_ids[j]])

        print(len([fn for fn in path_images.ls()]), "images downloaded.")
Example #5
    def fit(self, epochs):
        self.logger.log_info(epochs, self.lr)
        mb = master_bar(range(epochs))
        for epoch in mb:
            self.model.train()
            for xb, yb in progress_bar(self.train_dl, parent=mb):
                loss = self.loss_func(self.model(xb), yb)
                loss.backward()
                self.opt.step()
                self.opt.zero_grad()

            self.model.eval()
            with torch.no_grad():
                tot_loss, tot_acc = 0., 0.
                for xb, yb in progress_bar(self.valid_dl, parent=mb):
                    pred = self.model(xb)
                    temp = self.loss_func(pred, yb)
                    tot_loss += temp
                    tot_acc += self.metric(pred,
                                           yb) if self.metric else 1 - temp
            nv = len(self.valid_dl)
            val_loss = tot_loss / nv
            acc = (tot_acc / nv) * 100.
            mb.write('Epoch: {:3}, train loss: {: .4f}, val loss: {: .4f}, '
                     'Acc: {: .4f}%'.format(epoch + 1, loss, val_loss, acc))
            self.logger.log([loss.cpu(), val_loss.cpu(), acc.cpu()])

        self.logger.done()
        io.save(self.model, self.logger.full_path)
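self.metric above is any callable with a (pred, target) -> scalar signature. A simple accuracy metric that would fit that slot (an example, not the one used by the author):

import torch

def accuracy(pred, yb):
    # Fraction of samples where the arg-max class matches the target.
    return (pred.argmax(dim=1) == yb).float().mean()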
Example #6
    def fit_supervised(self, epochs):
        mb = master_bar(range(epochs))
        for epoch in mb:
            self.model.train()
            for xb, yb in progress_bar(self.train_dl, parent=mb):
                mb.child.comment = 'Train loop'
                loss = self.loss_func(self.model(xb), yb)
                loss.backward()
                self.opt.step()
                self.opt.zero_grad()

            self.model.eval()
            with torch.no_grad():
                tot_loss, tot_acc = 0., 0.
                for xb, yb in progress_bar(self.valid_dl, parent=mb):
                    mb.child.comment = 'Valid loop'
                    pred = self.model(xb)
                    temp = self.loss_func(pred, yb)
                    tot_loss += temp
                    tot_acc += self.metric(pred,
                                           yb) if self.metric else 1 - temp
            nv = len(self.valid_dl)
            val_loss = tot_loss / nv
            acc = (tot_acc / nv) * 100.
            mb.write('Epoch: {:3}, train loss: {: .4f}, val loss: {: .4f}, '
                     'Acc: {: .4f}%'.format(epoch + 1, loss, val_loss, acc))
            self.logger.log([loss.cpu(), val_loss.cpu(), acc.cpu()])
Example #7
def dtlz2_test():
    # Run the DTLZ2 benchmark
    errors = 0
    num_inputs = 6
    num_objectives = 2
    lab = DTLZ2(num_inputs=num_inputs, num_objectives=num_objectives)
    models = {
        f'y_{i}': GPyModel(Exponential(input_dim=num_inputs, ARD=True))
        for i in range(num_objectives)
    }

    warnings.filterwarnings("ignore", category=RuntimeWarning)
    tsemo = TSEMO(lab.domain, models=models, random_rate=0.00)
    experiments = tsemo.suggest_experiments(5 * num_inputs)

    mb = master_bar(range(1))
    for j in mb:
        mb.main_bar.comment = 'Repeats'
        for i in progress_bar(range(100), parent=mb):
            mb.child.comment = 'Iteration'
            # Run experiments
            experiments = lab.run_experiments(experiments)

            # Get suggestions
            try:
                experiments = tsemo.suggest_experiments(
                    1, experiments, **tsemo_options)
            except Exception as e:
                print(e)
                errors += 1

        tsemo.save(f'new_tsemo_params_{j}.json')
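The example above relies on the per-bar comment attributes used by fastprogress: mb.main_bar.comment annotates the outer bar and mb.child.comment annotates the currently running inner bar. A minimal sketch of just that feature (dummy loops only):

from fastprogress.fastprogress import master_bar, progress_bar

mb = master_bar(range(2))
for j in mb:
    mb.main_bar.comment = f'repeat {j}'            # text next to the outer bar
    for i in progress_bar(range(50), parent=mb):
        mb.child.comment = f'iteration {i}'        # text next to the inner bar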
Example #8
    def fit_siamese(self, epochs):
        mb = master_bar(range(epochs))
        for epoch in mb:
            self.model.train()
            for x1b, x2b, rdm in progress_bar(self.train_dl, parent=mb):
                mb.child.comment = 'Train loop'
                out1 = self.model(x1b)
                out2 = self.model(x2b)
                loss = self.loss_func(out1, out2, rdm)
                loss.backward()
                self.opt.step()
                self.opt.zero_grad()

            self.model.eval()
            with torch.no_grad():
                tot_loss = 0.
                for x1b, x2b, rdm in progress_bar(self.valid_dl, parent=mb):
                    out1 = self.model(x1b)
                    out2 = self.model(x2b)
                    temp = self.loss_func(out1, out2, rdm)
                    tot_loss += temp
            nv = len(self.valid_dl)
            val_loss = tot_loss / nv
            mb.write(
                'Epoch: {}, train loss: {: .6f}, val loss: {: .6f}'.format(
                    epoch + 1, loss, val_loss))
            self.logger.log([loss.cpu(), val_loss.cpu()])
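loss_func here takes two embeddings and a relatedness target. A plausible contrastive loss with that (out1, out2, target) signature, shown only as an assumption about its general form, not the example's own loss:

import torch
import torch.nn.functional as F

def contrastive_loss(out1, out2, target, margin=1.0):
    # target == 1 pulls a pair together; target == 0 pushes it apart by at least `margin`.
    dist = F.pairwise_distance(out1, out2)
    return torch.mean(target * dist.pow(2) +
                      (1 - target) * F.relu(margin - dist).pow(2))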
Example #9
    def fit(self, epochs, lr, validate=True, schedule_type="warmup_linear"):
        
        num_train_steps = int(len(self.data.train_dl) / self.grad_accumulation_steps * epochs)
        if self.optimizer is None:
            self.optimizer, self.schedule = self.get_optimizer(lr, num_train_steps)

        t_total = num_train_steps
        if not self.multi_gpu:
            t_total = t_total  # // torch.distributed.get_world_size()
            
        global_step = 0
        
        pbar = master_bar(range(epochs))
        
        for epoch in pbar:
            self.model.train()
  
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            
            for step, batch in enumerate(progress_bar(self.data.train_dl, parent=pbar)):
                batch = tuple(t.to(self.device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                
                if self.is_fp16 and self.multi_label:
                    label_ids = label_ids.half()
                
                loss = self.model(input_ids, segment_ids, input_mask, label_ids)
                if self.multi_gpu:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if self.grad_accumulation_steps > 1:
                    loss = loss / self.grad_accumulation_steps
                
                if self.is_fp16:
                    self.optimizer.backward(loss)
                else:
                    loss.backward()
                
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                
                if (step + 1) % self.grad_accumulation_steps == 0:
                    if self.is_fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = lr * self.schedule.get_lr(global_step)
                    
                        for param_group in self.optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    global_step += 1
                
            self.logger.info('Loss after epoch {} - {}'.format(epoch, tr_loss / nb_tr_steps))
#             logger.info('Eval after epoch  {}'.format(epoch))
        
            if validate:
                self.validate()
Example #10
    def fit(self, num_epochs, args, device='cuda:0'):
        """
        Fit the PyTorch model
        :param num_epochs: number of epochs to train (int)
        :param args: arguments used to build the optimizer and scheduler
        :param device: str (defaults to 'cuda:0')
        """
        optimizer, scheduler, step_scheduler_on_batch = self.optimizer(args)
        self.model = self.model.to(device)
        pbar = master_bar(range(num_epochs))
        headers = [
            'Train_Loss', 'Val_Loss', 'F1-Macro', 'F1-Micro', 'JS', 'Time'
        ]
        pbar.write(headers, table=True)
        for epoch in pbar:
            epoch += 1
            start_time = time.time()
            self.model.train()
            overall_training_loss = 0.0
            for step, batch in enumerate(
                    progress_bar(self.train_data_loader, parent=pbar)):
                loss, num_rows, _, _ = self.model(batch, device)
                overall_training_loss += loss.item() * num_rows

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               max_norm=1.0)
                optimizer.step()
                if step_scheduler_on_batch:
                    scheduler.step()
                optimizer.zero_grad()

            if not step_scheduler_on_batch:
                scheduler.step()

            overall_training_loss = overall_training_loss / len(
                self.train_data_loader.dataset)
            overall_val_loss, pred_dict = self.predict(device, pbar)
            y_true, y_pred = pred_dict['y_true'], pred_dict['y_pred']

            str_stats = []
            stats = [
                overall_training_loss, overall_val_loss,
                f1_score(y_true, y_pred, average="macro"),
                f1_score(y_true, y_pred, average="micro"),
                jaccard_score(y_true, y_pred, average="samples")
            ]

            for stat in stats:
                str_stats.append('NA' if stat is None else str(stat)
                                 if isinstance(stat, int) else f'{stat:.4f}')
            str_stats.append(format_time(time.time() - start_time))
            print('epoch#: ', epoch)
            pbar.write(str_stats, table=True)
            self.early_stop(overall_val_loss, self.model)
            if self.early_stop.early_stop:
                print("Early stopping")
                break
Example #11
def grid_search(agent_type, param_grid, env):
    """grid search for hyperparameter optimization"""
    param_keys, values = zip(*param_grid.items())

    param_combos = [dict(zip(param_keys, combo)) for combo in product(*values)]

    mb = master_bar(param_combos)

    for i, hyper_params in enumerate(mb):
        print(hyper_params)
        create_job(agent_type, hyper_params, env)
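A hypothetical call to grid_search; the agent type, grid values, and environment name below are illustrative only:

param_grid = {
    'learning_rate': [1e-3, 1e-4],
    'gamma': [0.95, 0.99],
}
grid_search('dqn', param_grid, env='CartPole-v1')   # launches one job per combination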
Example #12
    def fit(self,
            epochs,
            return_metric=False,
            monitor='epoch train_loss valid_loss time',
            model_path=os.path.join('..', 'weights', 'model.pth'),
            show_graph=True):
        self.model_path = model_path
        self.log(f'{time.ctime()}')
        self.log(f'Using device: {self.device}')
        mb = master_bar(range(1, epochs + 1))  #MAJOR
        mb.write(monitor.split(), table=True)

        model = self.model.to(self.device)
        optimizer = self.optimizer
        scheduler = self.scheduler
        best_metric = -np.inf
        train_loss_list, valid_loss_list = [], []

        for i_, epoch in enumerate(mb):
            epoch_start = timeit.default_timer()
            start = time.time()
            self.log('-' * 50)
            self.log(f'Running Epoch #{epoch} {"🔥"*epoch}')
            self.log(f'{"-"*50} \n')

            self.log('TRAINING...')
            train_loss = self.train(mb, model, optimizer, self.device,
                                    scheduler)
            train_loss_list.append(train_loss)  #for graph
            self.log(f'Training time: {round(time.time()-start, 2)} secs \n')

            start = time.time()
            self.log('EVALUATING...')
            valid_loss = self.validate(mb, model, self.device)
            valid_loss_list.append(valid_loss)  #for graph

            if show_graph:
                self.plot_loss_update(epoch, epochs, mb, train_loss_list,
                                      valid_loss_list)  # for graph

            epoch_end = timeit.default_timer()
            total_time = epoch_end - epoch_start
            mins, secs = divmod(total_time, 60)
            hours, mins = divmod(mins, 60)
            ret_time = f'{int(hours)}:{int(mins)}:{int(secs)}'
            mb.write([
                epoch, f'{train_loss:.6f}', f'{valid_loss:.6f}', f'{ret_time}'
            ],
                     table=True)
            self.log(f'Evaluation time: {ret_time}\n')
#             break

        if return_metric: return best_metric
Example #13
def get_vocabulary(smiles, augmentation=0, exclusive_tokens=False):
    """Read text and return dictionary that encodes vocabulary
    """
    print('Counting SMILES...')
    vocab = Counter()

    for i, smi in enumerate(smiles):
        vocab[smi] += 1

    print(f'{len(vocab)} unique Canonical SMILES')

    if augmentation > 0:
        print(f'Augmenting SMILES...({augmentation} times)')
        mb = master_bar(range(augmentation))
        for i in mb:
            for smi in progress_bar(smiles, parent=mb):
                randomized_smi = randomize_smiles(smi)
                vocab[randomized_smi] += 1

        print(f'{len(vocab)} unique SMILES (Canonical + Augmented)')
    return dict([(tuple(atomwise_tokenizer(x)), y)
                 for (x, y) in vocab.items()])
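A hypothetical call to get_vocabulary, assuming randomize_smiles and atomwise_tokenizer are available; the SMILES strings are illustrative only:

smiles = ['CCO', 'c1ccccc1', 'CC(=O)O']
vocab = get_vocabulary(smiles, augmentation=2)
print(len(vocab), 'unique tokenized SMILES')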
Example #14
def download(dl):
    mb = master_bar(range(10))
    download_file_path = download_file(dl["uri"], pbar=mb)

    if not os.path.isfile(download_file_path):
        logging.error("No file found at {}".format(download_file_path))
        shutil.rmtree(os.path.dirname(download_file_path))
        sys.exit(-1)

    if os.path.isdir(dl["path"]):
        shutil.rmtree(dl["path"])
        logging.info(f"Removed existing directory at {dl['path']}")

    with zipfile.ZipFile(download_file_path) as zf:
        zf.extractall(dl["path"])
        logging.info(f"Extracted zip to {dl['path']}")

    mi = has_mapinfo(dl["path"])
    if mi:
        to_shapefile(mi)

    shutil.rmtree(os.path.dirname(download_file_path))
    logging.info("Extracted to {}".format(dl["path"]))
Example #15
    def fit_one_cycle(self, dataloaders, model, optimizer, criterion, epochs, max_lr, device="cpu", path=None):
        """
        Implements OneCyclePolicy
        arguments:
            1. dataloaders : train and validation dataloaders as dict
                            {"train":train_dataloader,"validation":val_dataloader}
            2. model : user defined model
            3. optimizer : user defined optimizer
            4. criterion : user defined loss function
            5. epochs : no. of epochs to train for
            6. max_lr : maximum learning rate the scheduler
                        can reach during one cycle training
            7. device : device on which to train [one of ("cpu" or "cuda:0")]
                        [default: "cpu"]
        """
        best_loss = np.Inf
        train_dl, val_dl = dataloaders["train"], dataloaders["validation"]
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=max_lr, steps_per_epoch=len(train_dl), epochs=epochs)
        mb = master_bar(range(epochs))
        for n in mb:
            # Training Step
            train_loss = train(model, train_dl, optimizer, criterion,
                               scheduler, device, mb, recorder=self.recorder)
            # Validation Step
            val_loss, val_acc, error_rate = validate(
                model, val_dl, criterion, device, mb)
            # Update Recorder
            self.recorder.update_loss_metrics(
                train_loss, val_loss, val_acc, error_rate)
            if path is not None:
                best_loss = save_model(model, best_loss, val_loss, path)
            mb.write(TEMPLATE.format(n, train_loss,
                                     val_loss, val_acc*100., error_rate))

        # Load best model (only if a checkpoint path was provided)
        if path is not None:
            model.load_state_dict(torch.load(path))
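TEMPLATE is not defined in this snippet; a plausible format string with the five fields passed to mb.write above (an assumption, not the original constant):

TEMPLATE = ('Epoch: {:3d} | train loss: {:.4f} | val loss: {:.4f} | '
            'acc: {:.2f}% | error rate: {:.4f}')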
Example #16
def get_dataframe(nc4_list, master_progress_bar=None):
    global columns
    month_data = np.empty((0, len(columns)))
    # Loop over the files
    if master_progress_bar is None:
        master_progress_bar = master_bar([0])
        for _ in master_progress_bar:
            pass  # iterate the dummy bar once so that it renders

    for one_file in progress_bar(nc4_list, parent=master_progress_bar):
        np_table = get_np_table(one_file)
        month_data = np.concatenate((month_data, np_table), axis=0)

    if (month_data.size == 0):
        return pd.DataFrame(columns=columns)
    df = pd.DataFrame(month_data, columns=columns)
    # using dictionary to convert specific columns (https://www.geeksforgeeks.org/change-data-type-for-one-or-more-columns-in-pandas-dataframe/)
    convert_dict = {'sounding_id': int, 'orbit': int}
    df = df.astype(convert_dict)
    # Remove bad quality
    df = df[df['flag'] == 0]
    # Remove flag
    df.drop(['flag'], axis=1, inplace=True)
    return df
Example #17
def process_files(input_dir, output_dir, patterns):
    '''
    Process all NC4 files corresponding to the patterns list.
    '''
    if len(patterns) < 1:
        raise Exception("ERROR: you must give a non-empty list of patterns!")
    master_progress_bar = master_bar(patterns)
    for pattern in master_progress_bar:
        # Get the file list in directory
        nc4_list = get_file_list(input_dir,
                                 pattern='oco2_LtCO2_' + pattern + "*.nc4")
        master_progress_bar.write(
            f'Files to process for {pattern} : {len(nc4_list)}')
        if len(nc4_list) > 0:
            #master_progress_bar.write(f'Loading {pattern}')
            df = get_dataframe(nc4_list, master_progress_bar)
            master_progress_bar.write(f'Saving {pattern} to disk...')
            df.to_csv(output_dir + 'oco2_' + pattern + '.csv.bz2',
                      sep=';',
                      index=False,
                      compression='bz2')
            del (df)
        else:
            master_progress_bar.write(f'WARNING : No file for {pattern}')
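A hypothetical invocation of process_files; the directories and the month patterns are illustrative only:

process_files(input_dir='data/nc4/', output_dir='data/csv/',
              patterns=['1901', '1902', '1903'])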
Example #18
def corpus_augment(infile, outdir, cycles):
    '''
    infile: line separated SMILES file
    outdir: directory to save the augmented SMILES files.
        Each round of augmentation is saved as a separate file, named `infile_Ri`.
    cycles: number of rounds for SMILES augmentation
    '''
    if cycles <= 0:
        raise ValueError("Invalid option,  cycle should be larger than 0")

    with open(infile, "r") as ins:
        can_smiles = []
        for line in ins:
            can_smiles.append(line.split('\n')[0])

    fname = os.path.basename(infile).split('.')[0]
    ftype = os.path.basename(infile).split('.')[1]

    mb = master_bar(range(cycles))
    for i in mb:
        with open(f'{outdir}/{fname}_R{i}.{ftype}', 'a') as outfile:
            for smi in progress_bar(can_smiles, parent=mb):
                randomized_smi = randomize_smiles(smi)
                outfile.write(randomized_smi + '\n')
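A hypothetical call, writing three augmented copies of an input SMILES file (paths are illustrative only):

corpus_augment('data/chembl.smi', 'data/augmented', cycles=3)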
Example #19
def main(args):

    print(args)

    if args.push_to_hub:
        login_to_hub()

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    torch.backends.cudnn.benchmark = True

    vocab = VOCABS[args.vocab]
    fonts = args.font.split(",")

    # Load val data generator
    st = time.time()
    if isinstance(args.val_path, str):
        with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
            val_hash = hashlib.sha256(f.read()).hexdigest()

        val_set = RecognitionDataset(
            img_folder=os.path.join(args.val_path, "images"),
            labels_path=os.path.join(args.val_path, "labels.json"),
            img_transforms=T.Resize((args.input_size, 4 * args.input_size),
                                    preserve_aspect_ratio=True),
        )
    else:
        val_hash = None
        # Load synthetic data generator
        val_set = WordGenerator(
            vocab=vocab,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            num_samples=args.val_samples * len(vocab),
            font_family=fonts,
            img_transforms=Compose([
                T.Resize((args.input_size, 4 * args.input_size),
                         preserve_aspect_ratio=True),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
            ]),
        )

    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        sampler=SequentialSampler(val_set),
        pin_memory=torch.cuda.is_available(),
        collate_fn=val_set.collate_fn,
    )
    print(
        f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
        f"{len(val_loader)} batches)")

    batch_transforms = Normalize(mean=(0.694, 0.695, 0.693),
                                 std=(0.299, 0.296, 0.301))

    # Load doctr model
    model = recognition.__dict__[args.arch](pretrained=args.pretrained,
                                            vocab=vocab)

    # Resume weights
    if isinstance(args.resume, str):
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(checkpoint)

    # GPU
    if isinstance(args.device, int):
        if not torch.cuda.is_available():
            raise AssertionError(
                "PyTorch cannot access your GPU. Please investigate!")
        if args.device >= torch.cuda.device_count():
            raise ValueError("Invalid device index")
    # Silent default switch to GPU if available
    elif torch.cuda.is_available():
        args.device = 0
    else:
        logging.warning("No accessible GPU, targe device set to CPU.")
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
        model = model.cuda()

    # Metrics
    val_metric = TextMatch()

    if args.test_only:
        print("Running evaluation")
        val_loss, exact_match, partial_match = evaluate(model,
                                                        val_loader,
                                                        batch_transforms,
                                                        val_metric,
                                                        amp=args.amp)
        print(
            f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})"
        )
        return

    st = time.time()

    if isinstance(args.train_path, str):
        # Load train data generator
        base_path = Path(args.train_path)
        parts = ([base_path]
                 if base_path.joinpath("labels.json").is_file() else
                 [base_path.joinpath(sub) for sub in os.listdir(base_path)])
        with open(parts[0].joinpath("labels.json"), "rb") as f:
            train_hash = hashlib.sha256(f.read()).hexdigest()

        train_set = RecognitionDataset(
            parts[0].joinpath("images"),
            parts[0].joinpath("labels.json"),
            img_transforms=Compose([
                T.Resize((args.input_size, 4 * args.input_size),
                         preserve_aspect_ratio=True),
                # Augmentations
                T.RandomApply(T.ColorInversion(), 0.1),
                ColorJitter(brightness=0.3,
                            contrast=0.3,
                            saturation=0.3,
                            hue=0.02),
            ]),
        )
        if len(parts) > 1:
            for subfolder in parts[1:]:
                train_set.merge_dataset(
                    RecognitionDataset(subfolder.joinpath("images"),
                                       subfolder.joinpath("labels.json")))
    else:
        train_hash = None
        # Load synthetic data generator
        train_set = WordGenerator(
            vocab=vocab,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            num_samples=args.train_samples * len(vocab),
            font_family=fonts,
            img_transforms=Compose([
                T.Resize((args.input_size, 4 * args.input_size),
                         preserve_aspect_ratio=True),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
                ColorJitter(brightness=0.3,
                            contrast=0.3,
                            saturation=0.3,
                            hue=0.02),
            ]),
        )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        drop_last=True,
        num_workers=args.workers,
        sampler=RandomSampler(train_set),
        pin_memory=torch.cuda.is_available(),
        collate_fn=train_set.collate_fn,
    )
    print(
        f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
        f"{len(train_loader)} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    # Optimizer
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        args.lr,
        betas=(0.95, 0.99),
        eps=1e-6,
        weight_decay=args.weight_decay,
    )
    # LR Finder
    if args.find_lr:
        lrs, losses = record_lr(model,
                                train_loader,
                                batch_transforms,
                                optimizer,
                                amp=args.amp)
        plot_recorder(lrs, losses)
        return
    # Scheduler
    if args.sched == "cosine":
        scheduler = CosineAnnealingLR(optimizer,
                                      args.epochs * len(train_loader),
                                      eta_min=args.lr / 25e4)
    elif args.sched == "onecycle":
        scheduler = OneCycleLR(optimizer, args.lr,
                               args.epochs * len(train_loader))

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(
            name=exp_name,
            project="text-recognition",
            config={
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "weight_decay": args.weight_decay,
                "batch_size": args.batch_size,
                "architecture": args.arch,
                "input_size": args.input_size,
                "optimizer": "adam",
                "framework": "pytorch",
                "scheduler": args.sched,
                "vocab": args.vocab,
                "train_hash": train_hash,
                "val_hash": val_hash,
                "pretrained": args.pretrained,
            },
        )

    # Create loss queue
    min_loss = np.inf
    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model,
                      train_loader,
                      batch_transforms,
                      optimizer,
                      scheduler,
                      mb,
                      amp=args.amp)

        # Validation loop at the end of each epoch
        val_loss, exact_match, partial_match = evaluate(model,
                                                        val_loader,
                                                        batch_transforms,
                                                        val_metric,
                                                        amp=args.amp)
        if val_loss < min_loss:
            print(
                f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state..."
            )
            torch.save(model.state_dict(), f"./{exp_name}.pt")
            min_loss = val_loss
        mb.write(
            f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
            f"(Exact: {exact_match:.2%} | Partial: {partial_match:.2%})")
        # W&B
        if args.wb:
            wandb.log({
                "val_loss": val_loss,
                "exact_match": exact_match,
                "partial_match": partial_match,
            })

    if args.wb:
        run.finish()

    if args.push_to_hub:
        push_to_hf_hub(model, exp_name, task="recognition", run_config=args)
Example #20
 def begin_fit(self):
     self.mbar = master_bar(range(self.epochs))
     self.mbar.on_iter_begin()
     self.trainer.logger = partial(self.mbar.write, table=True)
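partial(self.mbar.write, table=True) turns the master bar into a one-argument logger that appends rows to its table. The same trick in isolation (values are illustrative):

from functools import partial
from fastprogress.fastprogress import master_bar

mb = master_bar(range(1))
log = partial(mb.write, table=True)
log(['epoch', 'loss'])    # header row
log(['0', '0.1234'])      # one data row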
Example #21
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Train batch size per GPU = %d", args.train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0

    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    mb = master_bar(range(int(args.num_train_epochs)))
    # Added here for reproducibility
    set_seed(args)

    for epoch in mb:
        epoch_iterator = progress_bar(train_dataloader, parent=mb)
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if args.model_type in [
                    "xlm", "roberta", "distilbert", "distilkobert",
                    "xlm-roberta"
            ]:
                del inputs["token_type_ids"]

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                if hasattr(model, "config") and hasattr(
                        model.config, "lang2id"):
                    inputs.update({
                        "langs":
                        (torch.ones(batch[0].shape, dtype=torch.int64) *
                         args.lang_id).to(args.device)
                    })
            # In the case of reforbert
            if args.model_type in ["reforbert"]:
                del inputs["attention_mask"]

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.evaluate_during_training:
                        logger.info("***** Eval results *****")
                        results = evaluate(args,
                                           model,
                                           tokenizer,
                                           global_step=global_step)
                        for key in sorted(results.keys()):
                            logger.info("  %s = %s", key, str(results[key]))

                    logging_loss = tr_loss

                # Save model checkpoint
                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    if args.save_optimizer:
                        torch.save(optimizer.state_dict(),
                                   os.path.join(output_dir, "optimizer.pt"))
                        torch.save(scheduler.state_dict(),
                                   os.path.join(output_dir, "scheduler.pt"))
                        logger.info(
                            "Saving optimizer and scheduler states to %s",
                            output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                break

        mb.write("Epoch {} done".format(epoch + 1))

        if args.max_steps > 0 and global_step > args.max_steps:
            break

    return global_step, tr_loss / global_step
Example #22
def main():

    logging.basicConfig(format='%(asctime)s - %(levelname)s -   %(message)s',
                        datefmt='%m/%d/%Y ',
                        level=logging.INFO)
    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--data",
                        default=None,
                        type=str,
                        required=True,
                        help="Directory which has the data files for the task")
    parser.add_argument(
        "--output",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument("--overwrite",
                        default=False,
                        type=bool,
                        help="Set it to True to overwrite output directory")

    args = parser.parse_args()

    if os.path.exists(args.output) and os.listdir(
            args.output) and not args.overwrite:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Set the overwrite flag to overwrite"
            .format(args.output))
    if not os.path.exists(args.output):
        os.makedirs(args.output)

    train_batch_size = 32
    valid_batch_size = 64
    test_batch_size = 64

    # padding sentences and labels to max_length of 128
    max_seq_len = 128
    EMBEDDING_DIM = 100
    epochs = 10

    split_train = split_text_label(os.path.join(args.data, "train.txt"))
    split_valid = split_text_label(os.path.join(args.data, "valid.txt"))
    split_test = split_text_label(os.path.join(args.data, "test.txt"))

    labelSet = set()
    wordSet = set()
    # words and labels
    for data in [split_train, split_valid, split_test]:
        for labeled_text in data:
            for word, label in labeled_text:
                labelSet.add(label)
                wordSet.add(word.lower())

    # Sort the set to ensure '0' is assigned to 0
    sorted_labels = sorted(list(labelSet), key=len)

    # Create mapping for labels
    label2Idx = {}
    for label in sorted_labels:
        label2Idx[label] = len(label2Idx)

    num_labels = len(label2Idx)
    idx2Label = {v: k for k, v in label2Idx.items()}

    pickle.dump(idx2Label,
                open(os.path.join(args.output, "idx2Label.pkl"), 'wb'))
    logger.info("Saved idx2Label pickle file")

    # Create mapping for words
    word2Idx = {}
    if len(word2Idx) == 0:
        word2Idx["PADDING_TOKEN"] = len(word2Idx)
        word2Idx["UNKNOWN_TOKEN"] = len(word2Idx)
    for word in wordSet:
        word2Idx[word] = len(word2Idx)
    logger.info("Total number of words is : %d ", len(word2Idx))

    pickle.dump(word2Idx, open(os.path.join(args.output, "word2Idx.pkl"),
                               'wb'))
    logger.info("Saved word2Idx pickle file")

    # Loading glove embeddings
    embeddings_index = {}
    f = open('embeddings/glove.6B.100d.txt', encoding="utf-8")
    for line in f:
        values = line.strip().split(' ')
        word = values[0]  # the first entry is the word
        coefs = np.asarray(
            values[1:], dtype='float32')  #100d vectors representing the word
        embeddings_index[word] = coefs
    f.close()
    logger.info("Glove data loaded")

    #print(str(dict(itertools.islice(embeddings_index.items(), 2))))

    embedding_matrix = np.zeros((len(word2Idx), EMBEDDING_DIM))

    # Word embeddings for the tokens
    for word, i in word2Idx.items():
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            embedding_matrix[i] = embedding_vector

    pickle.dump(embedding_matrix,
                open(os.path.join(args.output, "embedding.pkl"), 'wb'))
    logger.info("Saved Embedding matrix pickle")

    # Interesting - to check how many words were not there in Glove Embedding
    # indices = np.where(np.all(np.isclose(embedding_matrix, 0), axis=1))
    # print(len(indices[0]))

    train_sentences, train_labels = createMatrices(split_train, word2Idx,
                                                   label2Idx)
    valid_sentences, valid_labels = createMatrices(split_valid, word2Idx,
                                                   label2Idx)
    test_sentences, test_labels = createMatrices(split_test, word2Idx,
                                                 label2Idx)

    train_features, train_labels = padding(train_sentences,
                                           train_labels,
                                           max_seq_len,
                                           padding='post')
    valid_features, valid_labels = padding(valid_sentences,
                                           valid_labels,
                                           max_seq_len,
                                           padding='post')
    test_features, test_labels = padding(test_sentences,
                                         test_labels,
                                         max_seq_len,
                                         padding='post')

    logger.info(
        f"Train features shape is {train_features.shape} and labels shape is {train_labels.shape}"
    )
    logger.info(
        f"Valid features shape is {valid_features.shape} and labels shape is {valid_labels.shape}"
    )
    logger.info(
        f"Test features shape is {test_features.shape} and labels shape is {test_labels.shape}"
    )

    train_dataset = tf.data.Dataset.from_tensor_slices(
        (train_features, train_labels))
    valid_dataset = tf.data.Dataset.from_tensor_slices(
        (valid_features, valid_labels))
    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_features, test_labels))

    shuffled_train_dataset = train_dataset.shuffle(
        buffer_size=train_features.shape[0], reshuffle_each_iteration=True)

    batched_train_dataset = shuffled_train_dataset.batch(train_batch_size,
                                                         drop_remainder=True)
    batched_valid_dataset = valid_dataset.batch(valid_batch_size,
                                                drop_remainder=True)
    batched_test_dataset = test_dataset.batch(test_batch_size,
                                              drop_remainder=True)

    epoch_bar = master_bar(range(epochs))
    train_pb_max_len = math.ceil(
        float(len(train_features)) / float(train_batch_size))
    valid_pb_max_len = math.ceil(
        float(len(valid_features)) / float(valid_batch_size))
    test_pb_max_len = math.ceil(
        float(len(test_features)) / float(test_batch_size))

    model = TFNer(max_seq_len=max_seq_len,
                  embed_input_dim=len(word2Idx),
                  embed_output_dim=EMBEDDING_DIM,
                  weights=[embedding_matrix],
                  num_labels=num_labels)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
    scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    train_log_dir = f"{args.output}/logs/train"
    valid_log_dir = f"{args.output}/logs/valid"
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    valid_summary_writer = tf.summary.create_file_writer(valid_log_dir)

    train_loss_metric = tf.keras.metrics.Mean('training_loss',
                                              dtype=tf.float32)
    valid_loss_metric = tf.keras.metrics.Mean('valid_loss', dtype=tf.float32)

    def train_step_fn(sentences_batch, labels_batch):
        with tf.GradientTape() as tape:
            logits = model(
                sentences_batch)  # batchsize, max_seq_len, num_labels
            loss = scce(labels_batch, logits)  #batchsize,max_seq_len
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
        return loss, logits

    def valid_step_fn(sentences_batch, labels_batch):
        logits = model(sentences_batch)
        loss = scce(labels_batch, logits)
        return loss, logits

    for epoch in epoch_bar:
        with train_summary_writer.as_default():
            for sentences_batch, labels_batch in progress_bar(
                    batched_train_dataset,
                    total=train_pb_max_len,
                    parent=epoch_bar):

                loss, logits = train_step_fn(sentences_batch, labels_batch)
                train_loss_metric(loss)
                epoch_bar.child.comment = f'training loss : {train_loss_metric.result()}'
            tf.summary.scalar('training loss',
                              train_loss_metric.result(),
                              step=epoch)
            train_loss_metric.reset_states()

        with valid_summary_writer.as_default():
            for sentences_batch, labels_batch in progress_bar(
                    batched_valid_dataset,
                    total=valid_pb_max_len,
                    parent=epoch_bar):
                loss, logits = valid_step_fn(sentences_batch, labels_batch)
                valid_loss_metric.update_state(loss)

                epoch_bar.child.comment = f'validation loss : {valid_loss_metric.result()}'

            # Logging after each Epoch !
            tf.summary.scalar('valid loss',
                              valid_loss_metric.result(),
                              step=epoch)
            valid_loss_metric.reset_states()

    model.save_weights(f"{args.output}/model_weights", save_format='tf')
    logger.info(f"Model weights saved")

    #Evaluating on test dataset

    test_model = TFNer(max_seq_len=max_seq_len,
                       embed_input_dim=len(word2Idx),
                       embed_output_dim=EMBEDDING_DIM,
                       weights=[embedding_matrix],
                       num_labels=num_labels)
    test_model.load_weights(f"{args.output}/model_weights")
    logger.info(f"Model weights restored")

    true_labels = []
    pred_labels = []

    for sentences_batch, labels_batch in progress_bar(batched_test_dataset,
                                                      total=test_pb_max_len):

        logits = test_model(sentences_batch)
        temp1 = tf.nn.softmax(logits)
        preds = tf.argmax(temp1, axis=2)
        true_labels.append(np.asarray(labels_batch))
        pred_labels.append(np.asarray(preds))

    label_correct, label_pred = idx_to_label(pred_labels, true_labels,
                                             idx2Label)
    report = classification_report(label_correct, label_pred, digits=4)
    logger.info(f"Results for the test dataset")
    logger.info(f"\n{report}")
Example #23
    def fit(self,
            epochs,
            lr,
            validate=True,
            schedule_type="warmup_cosine",
            optimizer_type='lamb'):

        tensorboard_dir = self.output_dir / 'tensorboard'
        tensorboard_dir.mkdir(exist_ok=True)
        print(tensorboard_dir)

        # Train the model
        tb_writer = SummaryWriter(tensorboard_dir)

        train_dataloader = self.data.train_dl
        if self.max_steps > 0:
            t_total = self.max_steps
            self.epochs = self.max_steps // len(
                train_dataloader) // self.grad_accumulation_steps + 1
        else:
            t_total = len(
                train_dataloader) // self.grad_accumulation_steps * epochs

        # Prepare optimiser and schedule
        optimizer, _ = self.get_optimizer(lr,
                                          t_total,
                                          schedule_type=schedule_type,
                                          optimizer_type=optimizer_type)

        # get the base model if its already wrapped around DataParallel
        if hasattr(self.model, 'module'):
            self.model = self.model.module

        if self.is_fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError('Please install apex to use fp16 training')
            self.model, optimizer = amp.initialize(
                self.model, optimizer, opt_level=self.fp16_opt_level)

        schedule_class = SCHEDULES[schedule_type]

        scheduler = schedule_class(optimizer,
                                   warmup_steps=self.warmup_steps,
                                   t_total=t_total)

        # Parallelize the model architecture
        if self.multi_gpu:
            self.model = torch.nn.DataParallel(self.model)

        # Start Training
        self.logger.info("***** Running training *****")
        self.logger.info("  Num examples = %d", len(train_dataloader.dataset))
        self.logger.info("  Num Epochs = %d", epochs)
        self.logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.data.train_batch_size * self.grad_accumulation_steps)
        self.logger.info("  Gradient Accumulation steps = %d",
                         self.grad_accumulation_steps)
        self.logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epoch_step = 0
        tr_loss, logging_loss, epoch_loss = 0.0, 0.0, 0.0
        self.model.zero_grad()
        pbar = master_bar(range(epochs))

        for epoch in pbar:
            epoch_step = 0
            epoch_loss = 0.0
            for step, batch in enumerate(
                    progress_bar(train_dataloader, parent=pbar)):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3]
                }

                if self.model_type in ['bert', 'xlnet']:
                    inputs['token_type_ids'] = batch[2]

                outputs = self.model(**inputs)
                loss = outputs[
                    0]  # model outputs are always tuple in pytorch-transformers (see doc)

                if self.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if self.grad_accumulation_steps > 1:
                    loss = loss / self.grad_accumulation_steps

                if self.is_fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), self.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.max_grad_norm)

                tr_loss += loss.item()
                epoch_loss += loss.item()
                if (step + 1) % self.grad_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()

                    self.model.zero_grad()
                    global_step += 1
                    epoch_step += 1

                    if self.logging_steps > 0 and global_step % self.logging_steps == 0:
                        if validate:
                            # evaluate model
                            results = self.validate()
                            for key, value in results.items():
                                tb_writer.add_scalar('eval_{}'.format(key),
                                                     value, global_step)
                                self.logger.info(
                                    "eval_{} after step {}: {}".format(
                                        key, global_step, value))

                        # Log metrics
                        self.logger.info("lr after step {}: {}".format(
                            global_step,
                            scheduler.get_lr()[0]))
                        self.logger.info("train_loss after step {}: {}".format(
                            global_step,
                            (tr_loss - logging_loss) / self.logging_steps))
                        tb_writer.add_scalar('lr',
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                             self.logging_steps, global_step)

                        logging_loss = tr_loss

            # Evaluate the model after every epoch
            if validate:
                results = self.validate()
                for key, value in results.items():
                    self.logger.info("eval_{} after epoch {}: {}".format(
                        key, (epoch + 1), value))

            # Log metrics
            self.logger.info("lr after epoch {}: {}".format(
                (epoch + 1),
                scheduler.get_lr()[0]))
            self.logger.info("train_loss after epoch {}: {}".format(
                (epoch + 1), epoch_loss / epoch_step))
            self.logger.info("\n")

        tb_writer.close()
        return global_step, tr_loss / global_step
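A quick worked example of the step-count arithmetic used above, with hypothetical numbers: 1,000 batches per epoch, 4 gradient-accumulation steps and 3 epochs size the scheduler for 750 optimizer updates, while a positive max_steps would override the epoch count instead.

# Hypothetical numbers illustrating the t_total / epochs arithmetic above
batches_per_epoch = 1000          # len(train_dataloader)
grad_accumulation_steps = 4
epochs = 3
max_steps = 0                     # <= 0 means "derive t_total from epochs"

if max_steps > 0:
    t_total = max_steps
    epochs = max_steps // batches_per_epoch // grad_accumulation_steps + 1
else:
    t_total = batches_per_epoch // grad_accumulation_steps * epochs

print(t_total)  # 750 optimizer steps drive the warmup/decay schedule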
Example #24
    def fit_old(self,
                epochs,
                lr,
                validate=True,
                schedule_type="warmup_linear"):

        if self.is_fp16:
            self.model = self.model.half()

        # Parallelize the model architecture
        if not self.multi_gpu:
            try:
                from apex.parallel import DistributedDataParallel as DDP
            except ImportError:
                raise ImportError(
                    "Please install apex to use distributed and fp16 training.")

            self.model = DDP(self.model)
        else:
            self.model = torch.nn.DataParallel(self.model)

        num_train_steps = int(
            len(self.data.train_dl) / self.grad_accumulation_steps * epochs)
        if self.optimizer is None:
            self.optimizer, self.schedule = self.get_optimizer_old(
                lr, num_train_steps)

        t_total = num_train_steps
        if not self.multi_gpu:
            t_total = t_total // torch.distributed.get_world_size()

        global_step = 0

        pbar = master_bar(range(epochs))
        tb_writer = SummaryWriter()

        logging_loss = 0.0
        tr_loss = 0.0

        for epoch in pbar:
            self.model.train()

            epoch_tr_loss = 0.0
            nb_tr_examples, nb_tr_steps = 0, 0

            for step, batch in enumerate(
                    progress_bar(self.data.train_dl, parent=pbar)):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                if self.is_fp16 and self.multi_label:
                    label_ids = label_ids.half()

                outputs = self.model(input_ids, segment_ids, input_mask,
                                     label_ids)
                loss = outputs[0]

                if self.multi_gpu:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if self.grad_accumulation_steps > 1:
                    loss = loss / self.grad_accumulation_steps

                if self.is_fp16:
                    self.optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                epoch_tr_loss += loss.item()

                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                if (step + 1) % self.grad_accumulation_steps == 0:
                    lr_this_step = lr * self.schedule.get_lr(global_step)
                    if self.is_fp16:
                        # modify learning rate with the special warm-up BERT uses;
                        # if args.fp16 is False, BertAdam is used, which handles this automatically

                        for param_group in self.optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    global_step += 1

                    if self.logging_steps > 0 and (global_step %
                                                   self.logging_steps == 0):
                        self.logger.info(
                            'Loss after global step {} - {}'.format(
                                global_step,
                                (tr_loss - logging_loss) / self.logging_steps))
                        self.logger.info('LR after global step {} - {}'.format(
                            global_step, lr_this_step))

                        tb_writer.add_scalar('lr', lr_this_step, global_step)
                        tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                             self.logging_steps, global_step)
                        logging_loss = tr_loss

            self.logger.info('Loss after epoch {} - {}'.format(
                (epoch + 1), epoch_tr_loss / nb_tr_steps))
            #             logger.info('Eval after epoch  {}'.format(epoch))

            if validate:
                self.validate()

        tb_writer.close()
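For context on the manual fp16 branch in fit_old above, self.schedule.get_lr(global_step) is expected to return a multiplier in [0, 1] that scales the base learning rate each update step. A minimal sketch of such a warmup-linear multiplier, written as an assumption for illustration rather than the library's exact implementation:

# Sketch of a warmup-linear lr multiplier: linear warmup to the base lr,
# then linear decay back to zero over the remaining steps.
def warmup_linear(step, t_total, warmup=0.1):
    x = step / max(1, t_total)
    if x < warmup:
        return x / warmup
    return max(0.0, (1.0 - x) / (1.0 - warmup))

# Applied per update step, e.g. lr_this_step = lr * warmup_linear(global_step, t_total)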
Example #25
    def fit(self, epochs, lr, validate=True, schedule_type="warmup_linear"):

        if not self.use_amp_optimizer:
            self.fit_old(epochs,
                         lr,
                         validate=validate,
                         schedule_type=schedule_type)
            return

        num_train_steps = int(
            (len(self.data.train_dl) / self.grad_accumulation_steps) * epochs)

        if self.optimizer is None:
            self.optimizer, self.schedule = self.get_optimizer(
                lr, num_train_steps)

        if self.is_fp16:
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=self.fp16_opt_level)

        # Parallelize the model architecture
        if self.multi_gpu:
            self.model = torch.nn.DataParallel(self.model)

        self.logger.info("***** Running training *****")
        self.logger.info("  Num examples = %d",
                         len(self.data.train_dl.dataset))
        self.logger.info("  Num Epochs = %d", epochs)

        t_total = num_train_steps
        if not self.multi_gpu:
            t_total = t_total // torch.distributed.get_world_size()

        self.logger.info("  Gradient Accumulation steps = %d",
                         self.grad_accumulation_steps)
        self.logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        self.model.zero_grad()

        pbar = master_bar(range(epochs))
        tb_writer = SummaryWriter()

        for epoch in pbar:

            nb_tr_examples, nb_tr_steps = 0, 0
            epoch_tr_loss = 0.0

            for step, batch in enumerate(
                    progress_bar(self.data.train_dl, parent=pbar)):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': batch[2],
                    'labels': batch[3]
                }

                outputs = self.model(**inputs)
                loss = outputs[
                    0]  # model outputs are always tuple in pytorch-transformers (see doc)

                if self.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training

                if self.grad_accumulation_steps > 1:
                    loss = loss / self.grad_accumulation_steps

                if self.is_fp16:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer), self.max_grad_norm)

                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.max_grad_norm)

                tr_loss += loss.item()
                epoch_tr_loss += loss.item()

                if (step + 1) % self.grad_accumulation_steps == 0:
                    self.optimizer.step()
                    self.schedule.step()  # update the learning rate schedule after the optimizer step
                    self.model.zero_grad()
                    global_step += 1
                    nb_tr_steps += 1

                    if self.logging_steps > 0 and (global_step %
                                                   self.logging_steps == 0):
                        self.logger.info(
                            'Loss after global step {} - {}'.format(
                                global_step,
                                (tr_loss - logging_loss) / self.logging_steps))
                        self.logger.info('LR after global step {} - {}'.format(
                            global_step,
                            self.schedule.get_lr()[0]))
                        tb_writer.add_scalar('lr',
                                             self.schedule.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                             self.logging_steps, global_step)
                        logging_loss = tr_loss

            self.logger.info('Loss after epoch {} - {}'.format(
                epoch, epoch_tr_loss / nb_tr_steps))

            if validate:
                self.validate()

        tb_writer.close()
Example #26
def main(args):

    print(args)

    if args.push_to_hub:
        login_to_hub()

    if not isinstance(args.workers, int):
        args.workers = min(16, mp.cpu_count())

    torch.backends.cudnn.benchmark = True

    vocab = VOCABS[args.vocab]

    fonts = args.font.split(",")

    # Load val data generator
    st = time.time()
    val_set = CharacterGenerator(
        vocab=vocab,
        num_samples=args.val_samples * len(vocab),
        cache_samples=True,
        img_transforms=Compose(
            [
                T.Resize((args.input_size, args.input_size)),
                # Ensure we have a 90% split of white-background images
                T.RandomApply(T.ColorInversion(), 0.9),
            ]
        ),
        font_family=fonts,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        sampler=SequentialSampler(val_set),
        pin_memory=torch.cuda.is_available(),
    )
    print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in " f"{len(val_loader)} batches)")

    batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301))

    # Load doctr model
    model = classification.__dict__[args.arch](pretrained=args.pretrained, num_classes=len(vocab), classes=list(vocab))

    # Resume weights
    if isinstance(args.resume, str):
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(checkpoint)

    # GPU
    if isinstance(args.device, int):
        if not torch.cuda.is_available():
            raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
        if args.device >= torch.cuda.device_count():
            raise ValueError("Invalid device index")
    # Silent default switch to GPU if available
    elif torch.cuda.is_available():
        args.device = 0
    else:
        logging.warning("No accessible GPU, target device set to CPU.")
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
        model = model.cuda()

    if args.test_only:
        print("Running evaluation")
        val_loss, acc = evaluate(model, val_loader, batch_transforms)
        print(f"Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
        return

    st = time.time()

    # Load train data generator
    train_set = CharacterGenerator(
        vocab=vocab,
        num_samples=args.train_samples * len(vocab),
        cache_samples=True,
        img_transforms=Compose(
            [
                T.Resize((args.input_size, args.input_size)),
                # Augmentations
                T.RandomApply(T.ColorInversion(), 0.9),
                T.RandomApply(Grayscale(3), 0.1),
                ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
                T.RandomApply(GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 3)), 0.3),
                RandomRotation(15, interpolation=InterpolationMode.BILINEAR),
            ]
        ),
        font_family=fonts,
    )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        drop_last=True,
        num_workers=args.workers,
        sampler=RandomSampler(train_set),
        pin_memory=torch.cuda.is_available(),
    )
    print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in " f"{len(train_loader)} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, list(map(vocab.__getitem__, target)))
        return

    # Optimizer
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        args.lr,
        betas=(0.95, 0.99),
        eps=1e-6,
        weight_decay=args.weight_decay,
    )

    # LR Finder
    if args.find_lr:
        lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp)
        plot_recorder(lrs, losses)
        return
    # Scheduler
    if args.sched == "cosine":
        scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
    elif args.sched == "onecycle":
        scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(
            name=exp_name,
            project="character-classification",
            config={
                "learning_rate": args.lr,
                "epochs": args.epochs,
                "weight_decay": args.weight_decay,
                "batch_size": args.batch_size,
                "architecture": args.arch,
                "input_size": args.input_size,
                "optimizer": "adam",
                "framework": "pytorch",
                "vocab": args.vocab,
                "scheduler": args.sched,
                "pretrained": args.pretrained,
            },
        )

    # Create loss queue
    min_loss = np.inf
    # Training loop
    mb = master_bar(range(args.epochs))
    for epoch in mb:
        fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb)

        # Validation loop at the end of each epoch
        val_loss, acc = evaluate(model, val_loader, batch_transforms)
        if val_loss < min_loss:
            print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
            torch.save(model.state_dict(), f"./{exp_name}.pt")
            min_loss = val_loss
        mb.write(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})")
        # W&B
        if args.wb:
            wandb.log(
                {
                    "val_loss": val_loss,
                    "acc": acc,
                }
            )

    if args.wb:
        run.finish()

    if args.push_to_hub:
        push_to_hf_hub(model, exp_name, task="classification", run_config=args)

    if args.export_onnx:
        print("Exporting model to ONNX...")
        dummy_batch = next(iter(val_loader))
        dummy_input = dummy_batch[0].cuda() if torch.cuda.is_available() else dummy_batch[0]
        model_path = export_model_to_onnx(model, exp_name, dummy_input)
        print(f"Exported model saved in {model_path}")
Example #27
    def fit(self, epochs: int, lr: float,
            params_opt_dict: Optional[Dict] = None):
        "Main training loop"
        # Print logger at the start of the training loop
        self.logger.info(self.cfg)
        # Initialize the progress_bar
        mb = master_bar(range(epochs))
        # Initialize optimizer
        # Prepare Optimizer may need to be re-written as per use
        self.optimizer = self.prepare_optimizer(params_opt_dict)
        # Initialize scheduler
        # Prepare scheduler may need to be re-written as per use
        self.lr_scheduler = self.prepare_scheduler(self.optimizer)

        # Write the top row display
        # mb.write(self.log_keys, table=True)
        self.master_bar_write(mb, line=self.log_keys, table=True)
        exception = False
        met_to_use = None
        # Keep record of time until exit
        st_time = time.time()
        try:
            # Loop over epochs
            for epoch in mb:
                self.num_epoch += 1
                train_loss, train_acc = self.train_epoch(mb)

                valid_loss, valid_acc, predictions = self.validate(
                    self.data.valid_dl, mb)

                valid_acc_to_use = valid_acc[self.met_keys[0]]
                # Depending on type
                self.scheduler_step(valid_acc_to_use)

                # Now only need main process
                # Decide to save or not
                met_to_use = valid_acc[self.met_keys[0]].cpu()
                if self.best_met < met_to_use:
                    self.best_met = met_to_use
                    self.save_model_dict()
                    self.update_prediction_file(
                        predictions,
                        self.predictions_dir / f'val_preds_{self.uid}.pkl')

                # Prepare what all to write
                to_write = self.prepare_to_write(
                    train_loss, train_acc,
                    valid_loss, valid_acc
                )

                # Display on terminal
                assert to_write is not None
                mb_write = [str(stat) if isinstance(stat, int)
                            else f'{stat:.4f}' for stat in to_write]
                self.master_bar_write(mb, line=mb_write, table=True)

                # for k, record in zip(self.log_keys, to_write):
                #     self.writer.add_scalar(
                #         tag=k, scalar_value=record, global_step=self.num_epoch)
                # Update in the log file
                self.update_log_file(
                    good_format_stats(self.log_keys, to_write))

        except Exception as e:
            exception = e
            raise e
        finally:
            end_time = time.time()
            self.update_log_file(
                f'epochs done {epoch}. Exited due to exception {exception}. '
                f'Total time taken {end_time - st_time: 0.4f}\n\n'
            )
            # Decide to save finally or not
            if met_to_use is not None:
                if self.best_met < met_to_use:
                    self.save_model_dict()
Example #28
def train(args, model, train_dataset, dev_dataset=None, test_dataset=None):
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(t_total * args.warmup_proportion),
        num_training_steps=t_total)

    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Total train batch size = %d", args.train_batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    logger.info("  Logging steps = %d", args.logging_steps)
    logger.info("  Save steps = %d", args.save_steps)

    global_step = 0
    tr_loss = 0.0

    model.zero_grad()
    mb = master_bar(range(int(args.num_train_epochs)))
    for epoch in mb:
        epoch_iterator = progress_bar(train_dataloader, parent=mb)
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type not in ["distilkobert", "xlm-roberta"]:
                inputs["token_type_ids"] = batch[
                    2]  # Distilkobert, XLM-Roberta don't use segment_ids
            outputs = model(**inputs)

            loss = outputs[0]

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    len(train_dataloader) <= args.gradient_accumulation_steps
                    and (step + 1) == len(train_dataloader)):
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    if args.evaluate_test_during_training:
                        evaluate(args, model, test_dataset, "test",
                                 global_step)
                    else:
                        evaluate(args, model, dev_dataset, "dev", global_step)

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (model.module
                                     if hasattr(model, "module") else model)
                    model_to_save.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info(
                        "Saving model checkpoint to {}".format(output_dir))

                    if args.save_optimizer:
                        torch.save(optimizer.state_dict(),
                                   os.path.join(output_dir, "optimizer.pt"))
                        torch.save(scheduler.state_dict(),
                                   os.path.join(output_dir, "scheduler.pt"))
                        logger.info(
                            "Saving optimizer and scheduler states to {}".
                            format(output_dir))

            if args.max_steps > 0 and global_step > args.max_steps:
                break

        mb.write("Epoch {} done".format(epoch + 1))

        if args.max_steps > 0 and global_step > args.max_steps:
            break

    return global_step, tr_loss / global_step
Example #29
    def fit(
        self,
        epochs,
        lr,
        validate=True,
        schedule_type="warmup_cosine",
        optimizer_type="lamb",
    ):

        tensorboard_dir = self.output_dir / "tensorboard"
        tensorboard_dir.mkdir(exist_ok=True)

        # Train the model
        tb_writer = SummaryWriter(tensorboard_dir)

        train_dataloader = self.data.train_dl
        if self.max_steps > 0:
            t_total = self.max_steps
            self.epochs = (self.max_steps // len(train_dataloader) //
                           self.grad_accumulation_steps + 1)
        else:
            t_total = len(
                train_dataloader) // self.grad_accumulation_steps * epochs

        # Prepare optimiser and schedule
        optimizer = self.get_optimizer(lr, optimizer_type=optimizer_type)

        # get the base model if it's already wrapped around DataParallel
        if hasattr(self.model, "module"):
            self.model = self.model.module

        if self.is_fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex to use fp16 training")
            self.model, optimizer = amp.initialize(
                self.model, optimizer, opt_level=self.fp16_opt_level)

        # Get scheduler
        scheduler = self.get_scheduler(optimizer,
                                       t_total=t_total,
                                       schedule_type=schedule_type)

        # Parallelize the model architecture
        if self.multi_gpu:
            self.model = torch.nn.DataParallel(self.model)

        # Start Training
        self.logger.info("***** Running training *****")
        self.logger.info("  Num examples = %d", len(train_dataloader.dataset))
        self.logger.info("  Num Epochs = %d", epochs)
        self.logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.data.train_batch_size * self.grad_accumulation_steps,
        )
        self.logger.info("  Gradient Accumulation steps = %d",
                         self.grad_accumulation_steps)
        self.logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epoch_step = 0
        tr_loss, logging_loss, epoch_loss = 0.0, 0.0, 0.0
        self.model.zero_grad()
        pbar = master_bar(range(epochs))

        for epoch in pbar:
            epoch_step = 0
            epoch_loss = 0.0
            for step, batch in enumerate(
                    progress_bar(train_dataloader, parent=pbar)):

                inputs, labels = self.data.mask_tokens(batch)
                cpu_device = torch.device("cpu")

                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                self.model.train()

                outputs = self.model(inputs, masked_lm_labels=labels)
                loss = outputs[
                    0]  # model outputs are always tuple in pytorch-transformers (see doc)

                if self.n_gpu > 1:
                    loss = (
                        loss.mean()
                    )  # mean() to average on multi-gpu parallel training
                if self.grad_accumulation_steps > 1:
                    loss = loss / self.grad_accumulation_steps

                if self.is_fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), self.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.max_grad_norm)

                tr_loss += loss.item()
                epoch_loss += loss.item()

                # Tensor.to() is not in-place; reassign so the GPU copies can be released
                batch = batch.to(cpu_device)
                inputs = inputs.to(cpu_device)
                labels = labels.to(cpu_device)
                torch.cuda.empty_cache()

                if (step + 1) % self.grad_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()

                    self.model.zero_grad()
                    global_step += 1
                    epoch_step += 1

                    if self.logging_steps > 0 and global_step % self.logging_steps == 0:
                        if validate:
                            # evaluate model
                            results = self.validate()
                            for key, value in results.items():
                                tb_writer.add_scalar("eval_{}".format(key),
                                                     value, global_step)
                                self.logger.info(
                                    "eval_{} after step {}: {}".format(
                                        key, global_step, value))

                        # Log metrics
                        self.logger.info("lr after step {}: {}".format(
                            global_step,
                            scheduler.get_lr()[0]))
                        self.logger.info("train_loss after step {}: {}".format(
                            global_step,
                            (tr_loss - logging_loss) / self.logging_steps,
                        ))
                        tb_writer.add_scalar("lr",
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar(
                            "loss",
                            (tr_loss - logging_loss) / self.logging_steps,
                            global_step,
                        )

                        logging_loss = tr_loss

            # Evaluate the model after every epoch
            if validate:
                results = self.validate()
                for key, value in results.items():
                    self.logger.info("eval_{} after epoch {}: {}".format(
                        key, (epoch + 1), value))

            # Log metrics
            self.logger.info("lr after epoch {}: {}".format(
                (epoch + 1),
                scheduler.get_lr()[0]))
            self.logger.info("train_loss after epoch {}: {}".format(
                (epoch + 1), epoch_loss / epoch_step))
            self.logger.info("\n")

        tb_writer.close()
        return global_step, tr_loss / global_step
    def train(self,
              model,
              epochs_num=1,
              train_dataset=None,
              validation_dataset=None,
              data_collator=None,
              parent_information=None,
              lr=0.01,
              batch_size=64,
              weight_decay=0.01,
              betas=(0.9, 0.999),
              evaluate_steps=40,
              has_parent=True,
              verbose=False):
        '''
        Train the given model on the dataset provided. Will run evaluation on the validation set every
        `evaluate_steps` training steps, and at the end of each epoch.

        Args:
          model: instantiated model to train
          epochs_num: Number of epochs to train
          train_dataset: Train dataset
          validation_dataset: Validation dataset
          data_collator: A data collator function that when called will collate the data, passed to Dataloader
          parent_information:
          lr: Learning rate to use in the Optimizer
          batch_size: Batch size to use
          weight_decay: Optimizer weight decay
          betas: Betas used in the Optimizer
          evaluate_steps: How many training steps to run between evaluations
          verbose: If true, the training loss and additional F1 scores will be printed every step
        Returns:
          f1: double, the resulting mean f1 score of all the labels (it will be a number between 0 and 1)
          precision: double, the resulting mean precision of all the labels (it will be a number between 0 and 1)
          recall: double, the resulting mean recall of all the labels (it will be a number between 0 and 1)
        '''

        self.model = model

        # Prints additional loss and metrics information during training if set to true
        self.verbose = verbose

        # Set timers
        start = time.time()
        remaining_time = 0

        # Get dataloader
        self.data_collator = data_collator
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      collate_fn=self.data_collator,
                                      shuffle=True)
        # Default optimizer
        optimizer = AdamW(model.parameters(),
                          lr=lr,
                          weight_decay=weight_decay,
                          betas=betas)

        mb = master_bar(range(epochs_num))
        for epoch in mb:
            # Re-create the progress bar each epoch so it tracks that epoch's batches
            pb = progress_bar(train_dataloader, parent=mb)
            for i_batch, sample_batched in enumerate(pb):
                self.model.train()

                # Get input
                x = sample_batched[0].to(self.device)

                #if i_batch == 0:
                #print()
                #print(x.size())

                # Get targets (labels)
                target = sample_batched[1].float().to(self.device)

                if has_parent and (len(sample_batched) == 3):
                    parent_labels = sample_batched[2].float().to(self.device)

                    if self.device == 'cuda':
                        self.model.cuda(0)
                        x = x.cuda(0)
                        parent_labels = parent_labels.cuda(0)
                        target = target.cuda(0)
                    else:
                        self.model.cpu()
                        x = x.cpu()
                        parent_labels = parent_labels.cpu()
                        target = target.cpu()

                    # Pass input to model
                    output = self.model(x, parent_labels)
                else:
                    if self.device == 'cuda':
                        self.model.cuda(0)
                        x = x.cuda(0)
                        target = target.cuda(0)
                    else:
                        self.model.cpu()
                        x = x.cpu()
                        target = target.cpu()

                    # Pass input to model
                    output = self.model(x)

                # Loss
                train_loss = self.criterion(output, target)

                if self.verbose:
                    print(f'train_loss: {train_loss}')

                # Do backward, do step and zero gradients
                train_loss.backward()
                optimizer.step()
                model.zero_grad()
                optimizer.zero_grad()

                # Evaluate
                if (i_batch > 0) and (i_batch % evaluate_steps) == 0:
                    #print('\nevaluating...')
                    _ = self.evaluate(self.model, validation_dataset)

                self.train_losses.append(train_loss.item())

            # Run evaluation at the end of each epoch and return validation outputs
            #print('\nEnd of epoch evaluation results:')
            validation_outputs = self.evaluate(model, validation_dataset)
            y_hat_validation, validation_labels_child, validation_labels_parent = validation_outputs

            # Print out progress stats
            end = time.time()
            remaining_time = remaining_time * 0.90 + (
                (end - start) * (epochs_num - epoch + 1) / (epoch + 1)) * 0.1
            remaining_time_corrected = remaining_time / (1 -
                                                         (0.9**(epoch + 1)))
            epoch_str = "last epoch finished: " + str(epoch + 1)
            progress_str = "progress: " + str(
                (epoch + 1) * 100 / epochs_num) + "%"
            time_str = "time: " + str(remaining_time_corrected / 60) + " mins"
            sys.stdout.write("\r" + epoch_str + " -- " + progress_str +
                             " -- " + time_str)
            sys.stdout.flush()

            self.epochs.append(epoch)

        print("\n" + "Training completed. Total training time: " +
              str(round((end - start) / 60, 2)) + " mins")
        return (y_hat_validation, validation_labels_child,
                validation_labels_parent, self.train_losses,
                self.validation_losses, self.f1_scores_validations,
                self.precisions_validations, self.recalls_validations)
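The progress reporting in the epoch loop above keeps an exponentially smoothed estimate of the remaining training time and divides by (1 - 0.9 ** (epoch + 1)) to correct the bias introduced by starting the running average at zero. A small standalone sketch of that calculation, with made-up timings:

import time

# Bias-corrected exponential moving average of the remaining-time estimate,
# mirroring the update used in the epoch loop above (timings are made up)
epochs_num = 5
start = time.time()
remaining_time = 0.0
for epoch in range(epochs_num):
    time.sleep(0.01)  # stand-in for one epoch of training
    end = time.time()
    raw_estimate = (end - start) * (epochs_num - epoch + 1) / (epoch + 1)
    remaining_time = remaining_time * 0.9 + raw_estimate * 0.1
    corrected = remaining_time / (1 - 0.9 ** (epoch + 1))
    print(f"epoch {epoch + 1}: ~{corrected / 60:.2f} mins remaining")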