Code example #1
    def save_model(self, error, name):
        is_best = error < self.best_error

        if is_best:
            self.best_error = error

        models = {'epoch': self.i_epoch,
                  'state_dict': self.model.module.state_dict()}

        save_checkpoint(self.save_root, models, name, is_best)
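The save_checkpoint helper called above is not shown in this snippet. Below is a minimal sketch of what a helper with the signature save_checkpoint(save_root, state, name, is_best) typically does, assuming the usual torch.save plus shutil.copyfile pattern; the file-naming scheme is an assumption, not taken from the original project:

import os
import shutil
import torch

def save_checkpoint(save_root, state, name, is_best):
    # Latest state goes to <save_root>/<name>_ckpt.pth.tar (assumed naming).
    ckpt_path = os.path.join(save_root, f'{name}_ckpt.pth.tar')
    torch.save(state, ckpt_path)
    if is_best:
        # Keep a separate copy of the best checkpoint seen so far.
        shutil.copyfile(ckpt_path, os.path.join(save_root, f'{name}_best.pth.tar'))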
Code example #2
def train(total_epochs=1, interval=100, resume=False, ckpt_path=''):
    print("Loading training dataset...")
    train_dset = OpenImagesDataset(root='./data/train',
                                   list_file='./data/tmp/train_images_bbox.csv',
                                   transform=transform, train=True, input_size=600)

    train_loader = data.DataLoader(train_dset, batch_size=4, shuffle=True, num_workers=4, collate_fn=train_dset.collate_fn)
    
    print("Loading completed.")

    #val_dset = OpenImagesDataset(root='./data/train',
    #                  list_file='./data/tmp/train_images_bbox.csv', train=False, transform=transform, input_size=600)
    #val_loader = torch.utils.data.DataLoader(val_dset, batch_size=1, shuffle=False, num_workers=4, collate_fn=val_dset.collate_fn)

    net = RetinaNet()
    net.load_state_dict(torch.load('./model/net.pth'))

    criterion = FocalLoss()
    
    net.cuda()
    criterion.cuda()
    optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
    best_val_loss = 1000

    start_epoch = 0

    if resume:
        if os.path.isfile(ckpt_path):
            print(f'Loading from the checkpoint {ckpt_path}')
            checkpoint = torch.load(ckpt_path)
            start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint['best_val_loss']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(f'Loaded checkpoint {ckpt_path}, epoch : {start_epoch}')
        else:
            print(f'No checkpoint found at the path {ckpt_path}')

    

    for epoch in range(start_epoch, total_epochs):
        train_one_epoch(train_loader, net, criterion, optimizer, epoch, interval)
        val_loss = 0  # the validation pass above is commented out, so val_loss stays at 0
        #val_loss = validate(val_loader, net, criterion, interval)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint({
                'epoch': epoch+1,
                'state_dict': net.state_dict(),
                'best_val_loss': best_val_loss,
                'optimizer' : optimizer.state_dict()
            }, is_best=True)
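This example (like #5 below) calls save_checkpoint(state, is_best=True, ...) with a single state dict. A minimal sketch of a helper matching that call, again assuming the conventional torch.save plus shutil.copyfile pattern; the default file names are assumptions:

import shutil
import torch

def save_checkpoint(state, is_best, fname='checkpoint.pth.tar'):
    # Persist the full training state (epoch, weights, best loss, optimizer).
    torch.save(state, fname)
    if is_best:
        # Duplicate the checkpoint under a fixed "best" name (assumed).
        shutil.copyfile(fname, 'model_best.pth.tar')

The resume branch at the top of train() then restores the same keys via torch.load, net.load_state_dict and optimizer.load_state_dict.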
Code example #3
    def save_model(self, error, name):
        # Only the worker with id 0 writes checkpoints; any other id returns immediately.
        if self.id > 0:
            return

        is_best = error < self.best_error

        if is_best:
            self.best_error = error

        models = {'epoch': self.i_epoch,
                  'state_dict': self.model.module.state_dict()}

        self._log.info(self.id, "=> Saving Model..")

        save_checkpoint(self.save_root, models, name, is_best)
Code example #4
    def save_model(self, errors, names):
        is_best_depth = errors[0] < self.best_errors[0]
        is_best_flow = errors[1] < self.best_errors[1]

        if is_best_depth:
            self.best_errors[0] = errors[0]

        # Depth checkpoint: self.model[1] is saved under names[0].
        model = {
            'epoch': self.i_epoch,
            'state_dict': self.model[1].module.state_dict()
        }
        save_checkpoint(self.save_root, model, names[0], is_best_depth)

        if is_best_flow:
            self.best_errors[1] = errors[1]
        # Flow checkpoint: self.model[0] is saved under names[1].
        model = {
            'epoch': self.i_epoch,
            'state_dict': self.model[0].module.state_dict()
        }
        save_checkpoint(self.save_root, model, names[1], is_best_flow)
Code example #5
def train_one_epoch(train_loader, model, loss_fn, opt, epoch, interval):
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()
    train_loss = 0
    no_of_batches = int(train_loader.dataset.num_samples/train_loader.batch_size) + 1

    end = time.time()

    for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(train_loader):

        data_time.update(time.time() - end)

        inputs = inputs.cuda()
        loc_targets = loc_targets.cuda()
        cls_targets = cls_targets.cuda()

        opt.zero_grad()
        loc_preds, cls_preds = model(inputs)
        loss = loss_fn(loc_preds, loc_targets, cls_preds, cls_targets)
        loss.backward()
        opt.step()

        batch_time.update(time.time() - end)
        end = time.time()

        train_loss += loss.item()  # .item() replaces the deprecated loss.data[0] indexing
        if batch_idx % interval == 0:
            print(f'Train -> Batch: [{batch_idx}/{no_of_batches}] | Batch avg time: {batch_time.avg}'
                  f' | Data avg time: {data_time.avg} | avg loss: {train_loss / (batch_idx + 1)}')
        
        if batch_idx % 5000 == 0:
            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),  # the network is passed in as `model`
                'best_val_loss': train_loss / (batch_idx + 1),
                'optimizer': opt.state_dict()  # and the optimizer as `opt`
            }, is_best=True, fname=f'checkpoint_{epoch}_{batch_idx}.pth.tar')
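AverageMeter is used above but not defined in the snippet. A minimal sketch of the running-average helper commonly paired with this style of training loop; this particular implementation is an assumption, not code from the original project:

class AverageMeter:
    """Tracks the most recent value and a running average."""
    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # `n` is the number of samples the value was averaged over.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count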
Code example #6
File: model.py  Project: kiminh/contener
    def train_loop(self,
                   iterators,
                   optimizer,
                   run_dir,
                   epochs=100,
                   min_epochs=0,
                   patience=5,
                   epoch_start=0,
                   best_f1=None,
                   epochs_no_improv=None,
                   grad_clipping=0,
                   overlap=None,
                   train_entities=None,
                   train_key="train",
                   dev_key="dev",
                   eval_on_train=False,
                   gradient_accumulation=1,
                   **kwargs):

        logging.info(
            "Starting train loop: {} epochs; {} min; {} patience".format(
                epochs, min_epochs, patience))

        if best_f1 is None:
            best_f1 = 0

        if epochs_no_improv is None:
            epochs_no_improv = 0

        if train_key != "train":
            patience = 0

        if patience and epoch_start > min_epochs and epochs_no_improv >= patience:
            logging.info(
                "Early stopping after {} epochs without improvement.".format(
                    patience))

        else:
            writer = SummaryWriter(run_dir)
            for epoch in range(epoch_start, epochs):
                logging.info("Epoch {}/{} :".format(epoch + 1, epochs))
                train_losses = self.run_epoch(
                    iterators,
                    epoch,
                    optimizer,
                    writer,
                    grad_clipping=grad_clipping,
                    train_key=train_key,
                    gradient_accumulation=gradient_accumulation)
                n_iter = (epoch + 1) * len(list(train_losses))

                if eval_on_train:
                    logging.info("Train eval")
                    self.evaluate(iterators["ner"][train_key])

                _, ner_loss, ner_scores = self.evaluate(
                    iterators[dev_key],
                    overlap=overlap,
                    train_entities=train_entities)

                logging.info("Train NER Loss : {}".format(
                    np.mean(train_losses)))
                logging.info("Dev NER Loss : {}".format(ner_loss))

                if overlap is None:
                    if "ner" in iterators.keys():
                        add_score(writer, ner_scores, n_iter)
                else:
                    if "ner" in iterators.keys():
                        add_score_overlap(writer,
                                          ner_scores,
                                          n_iter,
                                          task="ner")

                f1 = ner_scores["ALL"]["f1"]

                if f1 > best_f1:
                    logging.info(f"New best NER F1 score on dev : {f1}")
                    logging.info("Saving model...")
                    best_f1 = f1
                    epochs_no_improv = 0
                    is_best = True

                else:
                    epochs_no_improv += 1
                    is_best = False

                state = {
                    'epoch': epoch + 1,
                    'epochs_no_improv': epochs_no_improv,
                    'model': self.state_dict(),
                    'scores': ner_scores,
                    'optimizer': optimizer.state_dict()
                }
                save_checkpoint(state,
                                is_best,
                                checkpoint=run_dir + 'ner_checkpoint.pth.tar',
                                best=run_dir + 'ner_best.pth.tar')

                writer.add_scalars("ner_loss", {"dev": ner_loss}, n_iter)

                if patience and epoch > min_epochs and epochs_no_improv >= patience:
                    logging.info(
                        f"Early stopping after {patience} epochs without improvement on NER."
                    )
                    break

            writer.export_scalars_to_json(run_dir + "all_scalars.json")
            writer.close()
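Example #6 uses yet another save_checkpoint signature, with explicit checkpoint and best paths. A minimal sketch under that assumption, followed by an illustrative resume pattern that mirrors the keys stored in state (the load lines are hypothetical, not from the project):

import shutil
import torch

def save_checkpoint(state, is_best, checkpoint='checkpoint.pth.tar', best='best.pth.tar'):
    # Always write the latest training state; copy it to the "best" path on improvement.
    torch.save(state, checkpoint)
    if is_best:
        shutil.copyfile(checkpoint, best)

# Illustrative resume, matching the keys written in train_loop:
# ckpt = torch.load(run_dir + 'ner_best.pth.tar')
# model.load_state_dict(ckpt['model'])
# optimizer.load_state_dict(ckpt['optimizer'])
# epoch_start = ckpt['epoch']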