Example #1
    def train(self):

        train_loader, valid_loader = self.dataset.get_data_loaders()

        model = ResNetSimCLR(**self.config["model"]).to(self.device)
        model = self._load_pre_trained_weights(model)

        if self.augmentor_type == "cnn":
            if self.config["normalization_type"] == "original":
                augmentor = LpAugmentor(
                    clip=self.config["augmentor_clip_output"])
                augmentor.to(self.device)
            elif self.config["normalization_type"] == "spectral":
                augmentor = LpAugmentorSpecNorm(
                    clip=self.config["augmentor_clip_output"])
                augmentor.to(self.device)
            else:
                raise ValueError("Unregonized normalization type: {}".format(
                    self.config["normalization_type"]))
        elif self.augmentor_type == "style_transfer":
            augmentor = LpAugmentorStyleTransfer(
                clip=self.config["augmentor_clip_output"])
            augmentor.to(self.device)
        elif self.augmentor_type == "transformer":
            augmentor = LpAugmentorTransformer(
                clip=self.config["augmentor_clip_output"])
            augmentor.to(self.device)
        else:
            raise ValueError("Unrecognized augmentor type: {}".format(
                self.augmentor_type))

        augmentor_optimizer = torch.optim.Adam(augmentor.parameters(), 3e-4)
        augmentor_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            augmentor_optimizer,
            T_max=len(train_loader),
            eta_min=0,
            last_epoch=-1)

        optimizer = torch.optim.Adam(
            list(model.parameters()),
            3e-4,
            weight_decay=eval(self.config["weight_decay"]),
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)

        if apex_support and self.config["fp16_precision"]:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              keep_batchnorm_fp32=True)

        model_checkpoints_folder = os.path.join(self.writer.log_dir,
                                                "checkpoints")

        # save config file
        _save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_loss = np.inf

        for epoch_counter in range(self.config["epochs"]):
            print("====== Epoch {} =======".format(epoch_counter))
            for (xis, xjs), _ in train_loader:
                optimizer.zero_grad()

                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                loss = self._adv_step(model, augmentor, xis, xjs, n_iter)

                if n_iter % self.config["log_every_n_steps"] == 0:
                    self.writer.add_scalar("train_loss",
                                           loss,
                                           global_step=n_iter)

                if apex_support and self.config["fp16_precision"]:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # for p in augmentor.parameters():
                #     # print(p.name)
                #     p.grad *= -1.0
                optimizer.step()

                # Update augmentor
                augmentor_optimizer.zero_grad()
                loss = self._adv_step(model, augmentor, xis, xjs, n_iter)
                if self.augmentor_loss_type == "hinge":
                    loss = torch.clamp(loss, 0.0, 5.4)
                loss *= -1.0
                loss.backward()
                augmentor_optimizer.step()

                n_iter += 1

            # validate the model if requested
            if epoch_counter % self.config["eval_every_n_epochs"] == 0:
                valid_loss = self._validate(model, augmentor, valid_loader)
                if valid_loss < best_valid_loss:
                    # save the model weights
                    best_valid_loss = valid_loss
                    torch.save(
                        model.state_dict(),
                        os.path.join(model_checkpoints_folder, "model.pth"),
                    )
                print("validation loss: ", valid_loss)
                self.writer.add_scalar("validation_loss",
                                       valid_loss,
                                       global_step=valid_n_iter)
                valid_n_iter += 1

            # warmup for the first 10 epochs
            if epoch_counter >= 10:
                scheduler.step()
                augmentor_scheduler.step()

            self.writer.add_scalar("cosine_lr_decay",
                                   scheduler.get_lr()[0],
                                   global_step=n_iter)
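Example #1 alternates two optimizers over the same batch: the encoder is updated to minimize the contrastive loss produced by `_adv_step`, and the augmentor is then updated on the negated loss (the `loss *= -1.0` line), so it learns augmentations that are adversarially hard for the encoder. A minimal, self-contained sketch of that min-max pattern, using toy `nn.Linear` modules and a placeholder loss rather than this project's classes:

import torch
import torch.nn as nn

# Toy stand-ins: nn.Linear layers play the roles of ResNetSimCLR and LpAugmentor.
encoder = nn.Linear(32, 16)
augmentor = nn.Linear(32, 32)
enc_opt = torch.optim.Adam(encoder.parameters(), 3e-4)
aug_opt = torch.optim.Adam(augmentor.parameters(), 3e-4)

def contrastive_loss(a, b):
    # placeholder for the NT-Xent loss that _adv_step would compute
    return ((a - b) ** 2).mean()

x = torch.randn(8, 32)

# 1) encoder update: minimize the loss on the augmented view
enc_opt.zero_grad()
loss = contrastive_loss(encoder(augmentor(x)), encoder(x))
loss.backward()
enc_opt.step()

# 2) augmentor update: maximize the same loss by negating it,
#    which is what the `loss *= -1.0` line above achieves
aug_opt.zero_grad()
loss = -contrastive_loss(encoder(augmentor(x)), encoder(x))
loss.backward()
aug_opt.step()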
Example #2
File: simclr.py  Project: dy0607/SimCLR
    def train(self):

        train_loader, valid_loader = self.dataset.get_data_loaders()

        model = ResNetSimCLR(**self.config["model"]).to(self.device)
        model = self._load_pre_trained_weights(model)

        optimizer = torch.optim.Adam(model.parameters(),
                                     3e-4,
                                     weight_decay=eval(
                                         self.config['weight_decay']))

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.config['epochs'], eta_min=0, last_epoch=-1)

        if apex_support and self.config['fp16_precision']:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level='O2',
                                              keep_batchnorm_fp32=True)

        model_checkpoints_folder = os.path.join(self.writer.log_dir,
                                                'checkpoints')

        # save config file
        _save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_loss = np.inf

        for epoch_counter in range(self.config['epochs']):
            for (xis, xjs), _ in train_loader:
                optimizer.zero_grad()

                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                loss = self._step(model, xis, xjs, n_iter)

                if n_iter % self.config['log_every_n_steps'] == 0:
                    self.writer.add_scalar('train_loss',
                                           loss,
                                           global_step=n_iter)

                if apex_support and self.config['fp16_precision']:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()
                n_iter += 1

            # validate the model if requested
            if epoch_counter % self.config['eval_every_n_epochs'] == 0:
                valid_loss = self._validate(model, valid_loader)
                if valid_loss < best_valid_loss:
                    # save the model weights
                    best_valid_loss = valid_loss
                    torch.save(
                        model.state_dict(),
                        os.path.join(model_checkpoints_folder, 'model.pth'))

                self.writer.add_scalar('validation_loss',
                                       valid_loss,
                                       global_step=valid_n_iter)
                valid_n_iter += 1

            # warmup for the first 10 epochs
            if epoch_counter >= 10:
                scheduler.step()
            self.writer.add_scalar('cosine_lr_decay',
                                   scheduler.get_lr()[0],
                                   global_step=n_iter)
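The contrastive loss itself lives in `self._step`, which is not shown here. In SimCLR-style trainers it typically encodes both augmented views, L2-normalizes the projection-head outputs, and applies the NT-Xent criterion. A plausible sketch under those assumptions (the function name `simclr_step` and the `nt_xent_criterion` object are illustrative, not this project's exact API):

import torch.nn.functional as F

def simclr_step(model, nt_xent_criterion, xis, xjs):
    # assume the model returns (representation, projection) for each view,
    # as most ResNetSimCLR implementations do
    _, zis = model(xis)   # [N, C]
    _, zjs = model(xjs)   # [N, C]

    # normalize the projections before the contrastive loss
    zis = F.normalize(zis, dim=1)
    zjs = F.normalize(zjs, dim=1)

    # nt_xent_criterion is assumed to implement the NT-Xent (InfoNCE) loss
    return nt_xent_criterion(zis, zjs)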
Example #3
    def train(self):

        train_loader, valid_loader = self.dataset.get_data_loaders()

        model = ResNetSimCLR(**self.config["model"]).to(self.device)
        model = self._load_pre_trained_weights(model)

        criterion = nn.CrossEntropyLoss()  # loss function

        optimizer = torch.optim.Adam(model.parameters(),
                                     3e-4,
                                     weight_decay=eval(
                                         self.config['weight_decay']))

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.config['epochs'], eta_min=0, last_epoch=-1)

        if apex_support and self.config['fp16_precision']:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level='O2',
                                              keep_batchnorm_fp32=True)

        model_checkpoints_folder = os.path.join(
            '/home/zhangchunhui/MedicalAI/USCL/checkpoints_multi_aug',
            'checkpoint_' + str(self.Checkpoint_Num))

        # save config file
        _save_config_file(model_checkpoints_folder)

        start_time = time.time()
        end_time = time.time()
        valid_n_iter = 0
        best_valid_loss = np.inf

        for epoch in range(self.config['epochs']):
            for i, data in enumerate(train_loader, 1):
                # forward
                # mixupimg1, label1, mixupimg2, label2, original img1, original img2
                xis, labelis, xjs, labeljs, imgis, imgjs = data  # N samples of left branch, N samples of right branch

                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                ####### 1-Semi-supervised
                hi, xi, outputis = model(xis)
                hj, xj, outputjs = model(xjs)
                labelindexi, labelindexj = FindNotX(
                    labelis.tolist(), 9999), FindNotX(labeljs.tolist(),
                                                      9999)  # X=9999=no label

                lossi = criterion(outputis[labelindexi],
                                  labelis.to(self.device)[labelindexi])
                lossj = criterion(outputjs[labelindexj],
                                  labeljs.to(self.device)[labelindexj])

                # lumbda1=lumbda2   # small value is better
                lumbda1, lumbda2 = self.lumbda1, self.lumbda2  # small value is better
                loss = self._step(model, xis,
                                  xjs) + lumbda1 * lossi + lumbda2 * lossj
                ########################################################################################################

                ####### 2-Self-supervised
                # loss = self._step(model, xis, xjs)
                ########################################################################################################

                # backward
                optimizer.zero_grad()
                if apex_support and self.config['fp16_precision']:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # update weights
                optimizer.step()

                if i % self.config['log_every_n_steps'] == 0:
                    # self.writer.add_scalar('train_loss', loss, global_step=i)
                    start_time, end_time = end_time, time.time()
                    print(
                        "\nTraining:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Time: {:.2f}s"
                        .format(epoch + 1, self.config['epochs'], i,
                                len(train_loader), loss,
                                end_time - start_time))

            # validate the model if requested
            if epoch % self.config['eval_every_n_epochs'] == 0:
                start_time = time.time()
                valid_loss = self._validate(model, valid_loader)
                end_time = time.time()
                if valid_loss < best_valid_loss:
                    # save the model weights
                    best_valid_loss = valid_loss
                    torch.save(
                        model.state_dict(),
                        os.path.join(model_checkpoints_folder,
                                     'best_model.pth'))

                print(
                    "Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Time: {:.2f}s"
                    .format(epoch + 1, self.config['epochs'],
                            len(valid_loader), len(valid_loader), valid_loss,
                            end_time - start_time))
                # self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
                valid_n_iter += 1

            print('Learning rate this epoch:',
                  scheduler.get_last_lr()[0])  # python >=3.7
            # print('Learning rate this epoch:', scheduler.base_lrs[0])   # python 3.6

            # warmup for the first 10 epochs
            if epoch >= 10:
                scheduler.step()
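`FindNotX` is defined elsewhere in that project; from its use above (with 9999 as the "no label" placeholder), it evidently returns the indices of samples that do carry a label, so the cross-entropy terms are computed only on labeled entries of each batch. A hedged sketch of such a helper (the lowercase name is illustrative):

def find_not_x(labels, x):
    # indices of entries whose label differs from the placeholder value x
    return [i for i, label in enumerate(labels) if label != x]

# usage on a mixed batch: only positions 0 and 2 carry real labels
assert find_not_x([3, 9999, 1, 9999], 9999) == [0, 2]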
Example #4
        # assert l_pos.shape == (batch_size, 1), "l_pos shape not valid" + str(l_pos.shape)  # [N,1]

        negatives = torch.cat([zjs, zis], dim=0)

        loss = 0

        for positives in [zis, zjs]:
            l_neg = sim_func_dim2(positives, negatives)

            labels = torch.zeros(batch_size, dtype=torch.long)
            if train_gpu:
                labels = labels.cuda()

            l_neg = l_neg[negative_mask].view(l_neg.shape[0], -1)
            l_neg /= temperature

            # assert l_neg.shape == (batch_size, 2 * (batch_size - 1)), "Shape of negatives not expected." + str(
            #     l_neg.shape)
            logits = torch.cat([l_pos, l_neg], dim=1)  # [N,K+1]
            loss += criterion(logits, labels)

        loss = loss / (2 * batch_size)
        train_writer.add_scalar('loss', loss, global_step=n_iter)

        loss.backward()
        optimizer.step()
        n_iter += 1
        # print("Step {}, Loss {}".format(step, loss))

torch.save(model.state_dict(), './checkpoints/checkpoint.pth')
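`negative_mask` is built outside this snippet. Given the commented-out shape assertion, it must drop, for each anchor, the similarity with itself and with its positive counterpart in the concatenated `[zjs, zis]` batch, leaving 2·(batch_size − 1) negatives per row. One common way to construct such a mask (a sketch, not necessarily this project's exact code):

import torch

def get_negative_mask(batch_size):
    # start with everything allowed, then zero out each anchor's own entry
    # and its positive counterpart in the 2N-long concatenated batch
    mask = torch.ones((batch_size, 2 * batch_size), dtype=torch.bool)
    for i in range(batch_size):
        mask[i, i] = False
        mask[i, i + batch_size] = False
    return mask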
Example #5
    if epoch_counter % config['eval_every_n_epochs'] == 0:

        # validation steps
        with torch.no_grad():
            model.eval()

            valid_loss = 0.0
            for counter, ((xis, xjs), _) in enumerate(valid_loader):

                if train_gpu:
                    xis = xis.cuda()
                    xjs = xjs.cuda()
                loss = (step(xis, xjs))
                valid_loss += loss.item()

            valid_loss /= (counter + 1)  # counter is zero-based; average over all validation batches

            if valid_loss < best_valid_loss:
                # save the model weights
                best_valid_loss = valid_loss
                torch.save(model.state_dict(),
                           os.path.join(model_checkpoints_folder, 'model.pth'))

            train_writer.add_scalar('validation_loss',
                                    valid_loss,
                                    global_step=valid_n_iter)
            valid_n_iter += 1

        model.train()
Example #6
    def train(self):

        train_loader, valid_loader = self.dataset.get_data_loaders()

        model = ResNetSimCLR(**self.config["model"]).to(self.device)  # just a ResNet backbone

        model = self._load_pre_trained_weights(model)  # load checkpoints (open question: convert the TF checkpoint to PyTorch, or train from scratch given the large dataset?)

        optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=eval(self.config['weight_decay']))

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
                                                               last_epoch=-1)  # learning rate scheduler (use as is)

        if apex_support and self.config['fp16_precision']:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level='O2',
                                              keep_batchnorm_fp32=True)

        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

        # save config file
        _save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_loss = np.inf


        for epoch_counter in range(self.config['epochs']):  # start training
            for (x, y) in train_loader:  # each batch yields the two augmented views used by SimCLR
                optimizer.zero_grad()

                x = x.to(self.device)
                y = y.to(self.device)

                loss = self._step(model, x, y)

                if n_iter % self.config['log_every_n_steps'] == 0:
                    self.writer.add_scalar('train_loss', loss, global_step=n_iter)

                if apex_support and self.config['fp16_precision']:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()
                n_iter += 1

            # validate the model if requested
            if epoch_counter % self.config['eval_every_n_epochs'] == 0:
                valid_loss = self._validate(model, valid_loader)
                print('Epoch:',epoch_counter,' ---',' validation_loss:',valid_loss)
                if valid_loss < best_valid_loss:
                    # save the model weights
                    best_valid_loss = valid_loss
                    torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))

                self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
                valid_n_iter += 1

            # warmup for the first 10 epochs
            if epoch_counter >= 10:
                scheduler.step()
            self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter)
Example #7
    def train(self):
        #Data
        train_loader, valid_loader = self.dataset.get_data_loaders()

        #Model
        model = ResNetSimCLR(**self.config["model"])
        if self.device == 'cuda':
            model = nn.DataParallel(model, device_ids=list(range(self.config['gpu']['gpunum'])))
        #model = model.to(self.device)
        model = model.cuda()
        print(model)
        model = self._load_pre_trained_weights(model)
        
        each_epoch_steps = len(train_loader)
        total_steps = each_epoch_steps * self.config['train']['epochs'] 
        warmup_steps = each_epoch_steps * self.config['train']['warmup_epochs']
        scaled_lr = eval(self.config['train']['lr']) * self.batch_size / 256.

        optimizer = torch.optim.Adam(
                     model.parameters(), 
                     scaled_lr, 
                     weight_decay=eval(self.config['train']['weight_decay']))
       
        '''
        optimizer = LARS(params=model.parameters(),
                     lr=eval(self.config['train']['lr']),
                     momentum=self.config['train']['momentum'],
                     weight_decay=eval(self.config['train']['weight_decay']),
                     eta=0.001,
                     max_epoch=self.config['train']['epochs'])
        '''

        # scheduler during warmup stage
        lambda1 = lambda epoch: epoch * 1.0 / int(warmup_steps)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

        if apex_support and self.config['train']['fp16_precision']:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level='O2',
                                              keep_batchnorm_fp32=True)

        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

        # save config file
        _save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_loss = np.inf
        lr = eval(self.config['train']['lr']) 

        end = time.time()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        
        for epoch_counter in range(self.config['train']['epochs']):
            model.train()
            for i, ((xis, xjs), _) in enumerate(train_loader):
                data_time.update(time.time() - end)
                optimizer.zero_grad()

                xis = xis.cuda()
                xjs = xjs.cuda()

                loss = self._step(model, xis, xjs, n_iter)

                #print("Loss: ",loss.data.cpu())
                losses.update(loss.item(), 2 * xis.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                print('Epoch: [{epoch}][{step}/{each_epoch_steps}] '
                      'Loss {loss.val:.4f} Avg Loss {loss.avg:.4f} '
                      'DataTime {datatime.val:.4f} BatchTime {batchtime.val:.4f} '
                      'LR {lr}'.format(epoch=epoch_counter,
                                       step=i,
                                       each_epoch_steps=each_epoch_steps,
                                       loss=losses,
                                       datatime=data_time,
                                       batchtime=batch_time,
                                       lr=lr))

                if n_iter % self.config['train']['log_every_n_steps'] == 0:
                    self.writer.add_scalar('train_loss', loss, global_step=n_iter)

                if apex_support and self.config['train']['fp16_precision']:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()
                n_iter += 1

                #adjust lr
                if n_iter == warmup_steps:
                    # scheduler after warmup stage
                    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps-warmup_steps, eta_min=0, last_epoch=-1)
                scheduler.step()
                lr = scheduler.get_lr()[0]
                self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter)
                sys.stdout.flush()

            # validate the model if requested
            if epoch_counter % self.config['train']['eval_every_n_epochs'] == 0:
                valid_loss = self._validate(model, valid_loader)
                if valid_loss < best_valid_loss:
                    # save the model weights
                    best_valid_loss = valid_loss
                    torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))

                self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
                valid_n_iter += 1
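`AverageMeter` is a small utility imported from elsewhere in that project; the logging line only relies on its `.val` and `.avg` fields. A minimal sketch of the conventional implementation (assumed, not copied from the project):

class AverageMeter:
    """Tracks the latest value and a running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count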
Example #8
File: simclr.py  Project: webMan1/SimCLR
    def train(self, callback=lambda m, e, l: None):

        train_loader, valid_loader = self.dataset.get_data_loaders()

        model = ResNetSimCLR(**self.config["model"]).to(self.device)
        model = self._load_pre_trained_weights(model)

        optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=eval(self.config['weight_decay']))

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
                                                               last_epoch=-1)

        if apex_support and self.config['fp16_precision']:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level='O2',
                                              keep_batchnorm_fp32=True)
        else:
            print("No apex_support or config not fp16 precision")

        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

        # save config file
        _save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_loss = np.inf

        eval_freq = self.config['eval_every_n_epochs']
        num_epochs = self.config["epochs"]
        
        train_len = len(train_loader)
        valid_len = len(valid_loader)

        loop = tqdm(total=num_epochs * train_len, position=0)

        for epoch_counter in range(num_epochs):
            for it, ((xis, xjs), _) in enumerate(train_loader):
                optimizer.zero_grad()

                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                loss = self._step(model, xis, xjs, n_iter)

                if n_iter % self.config['log_every_n_steps'] == 0:
                    self.writer.add_scalar('train_loss', loss, global_step=n_iter)

                if apex_support and self.config['fp16_precision']:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()
                n_iter += 1

                loop.update(1)
                loop.set_description(f"E {epoch_counter}/{num_epochs}, it: {it}/{train_len}, Loss: {loss.item()}")

            # validate the model if requested
            if epoch_counter % self.config['eval_every_n_epochs'] == 0:
                valid_loss = self._validate(model, valid_loader)
                callback(model, epoch_counter, valid_loss)
                if valid_loss < best_valid_loss:
                    # save the model weights
                    best_valid_loss = valid_loss
                    torch.save(model.state_dict(),
                               os.path.join(model_checkpoints_folder,
                                            f'{self.dataset.name}-model-{epoch_counter}.pth'))

                self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
                valid_n_iter += 1

            # warmup for the first 10 epochs
            if epoch_counter >= 10:
                scheduler.step()
            self.writer.add_scalar('cosine_lr_decay', scheduler.get_lr()[0], global_step=n_iter)
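This variant adds a `callback` hook that fires once per evaluated epoch with the model, the epoch index, and the validation loss, which makes it easy to attach external checkpointing or early-stopping logic. A hypothetical usage sketch (the `SimCLR(dataset, config)` constructor shown here is an assumption for illustration):

def on_eval(model, epoch, valid_loss):
    # called after each validation pass
    print(f"epoch {epoch}: validation loss {valid_loss:.4f}")

simclr = SimCLR(dataset, config)  # assumed constructor signature
simclr.train(callback=on_eval)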
Example #9
    def train(self):

        train_loader, valid_loader = self.dataset.get_data_loaders()
        print(
            f'The current dataset has {self.dataset.get_train_length()} items')

        model = ResNetSimCLR(**self.config["model"]).to(self.device)
        model = self._load_pre_trained_weights(model)

        if self.device == self.cuda_name and self.config['allow_multiple_gpu']:
            gpu_count = torch.cuda.device_count()
            if gpu_count > 1:
                print(
                    f'There are {gpu_count} GPUs with the current setup, so we will run on all the GPUs'
                )
                model = torch.nn.DataParallel(model)

        optimizer = torch.optim.Adam(model.parameters(),
                                     3e-4,
                                     weight_decay=eval(
                                         self.config['weight_decay']))

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)

        if apex_support and self.config['fp16_precision']:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level='O2',
                                              keep_batchnorm_fp32=True)

        model_checkpoints_folder = os.path.join(self.writer.log_dir,
                                                'checkpoints')

        # save config file
        _save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_loss = np.inf

        for epoch_counter in range(self.config['epochs']):
            t1 = time.time()
            for (xis, xjs), _ in train_loader:
                optimizer.zero_grad()

                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                loss = self._step(model, xis, xjs, n_iter)

                if n_iter % self.config['log_every_n_steps'] == 0:
                    print(
                        f"Epoch {epoch_counter}. Loss = {loss}. Time: {time.strftime('%c', time.localtime())}."
                    )
                    self.writer.add_scalar('train_loss',
                                           loss,
                                           global_step=n_iter)

                if apex_support and self.config['fp16_precision']:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()
                n_iter += 1

            # validate the model if requested
            if epoch_counter % self.config['eval_every_n_epochs'] == 0:
                valid_loss = self._validate(model, valid_loader)
                if valid_loss < best_valid_loss:
                    # save the model weights
                    best_valid_loss = valid_loss
                    torch.save(
                        model.state_dict(),
                        os.path.join(model_checkpoints_folder, 'model.pth'))
                    time_for_epoch = int(time.time() - t1)
                    print(f"===\n \
                            Epoch {epoch_counter}. Time for previous epoch: {time_for_epoch} seconds. Time to go: {((self.config['epochs'] - epoch_counter)*time_for_epoch)/60} minutes. Validation loss: {valid_loss}. Best valid loss: {best_valid_loss}\
                          \n===")

                self.writer.add_scalar('validation_loss',
                                       valid_loss,
                                       global_step=valid_n_iter)
                valid_n_iter += 1

            # warmup for the first 10 epochs
            if epoch_counter >= 10:
                scheduler.step()
            self.writer.add_scalar('cosine_lr_decay',
                                   scheduler.get_lr()[0],
                                   global_step=n_iter)
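One caveat when `torch.nn.DataParallel` is used, as in Examples #7 and #9: the saved `state_dict` keys are prefixed with `module.`, so the checkpoint will not load directly into an unwrapped model. A hedged sketch of the usual fix (the file path and `config` are placeholders from the examples above):

import torch

# keys saved from a DataParallel-wrapped model look like "module.features.0..."
state_dict = torch.load('checkpoints/model.pth', map_location='cpu')
state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}

model = ResNetSimCLR(**config["model"])  # same constructor as in the examples
model.load_state_dict(state_dict)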