Example 1
import argparse
import json
import logging

import torch

# Project-local helpers (Parse, LoadData, Dronet, PreActBlock, ModelManager,
# ModelTrainer) are assumed to come from the project's own modules.


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch FrontNet')
    args = Parse(parser)

    torch.manual_seed(args.seed)

    # [NeMO] Setup of console logging.
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s - %(levelname)s - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename="log.txt",
                        filemode='w')


    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

    train_loader, validation_loader, test_loader = LoadData(args)

    # [NeMO] Loading of the JSON regime file.
    regime = {}
    if args.regime is None:
        raise ValueError("Missing regime JSON.")
    with open(args.regime, "r") as f:
        rr = json.load(f)
    # Numeric keys are epoch indices and become ints; any other key is kept as a string.
    for k, v in rr.items():
        try:
            regime[int(k)] = v
        except ValueError:
            regime[k] = v

    # args.gray selects the single-channel (grayscale) variant of the model.
    model = Dronet(PreActBlock, [1, 1, 1], args.gray is not None)

    # [NeMO] This is used to preload the model with pretrained weights.
    if args.load_model is not None:
        ModelManager.Read(args.load_model, model)

    trainer = ModelTrainer(model, args, regime)
    if args.quantize:
        #logging.disable(logging.INFO)
        trainer.Quantize(validation_loader)


    #trainer.Train(train_loader, validation_loader)
    #trainer.Predict(test_loader)

    if args.save_model is not None:
        ModelManager.Write(trainer.GetModel(), 100, args.save_model)
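
Both entry points load a training regime from the JSON file given in args.regime and, in the loop above, turn numeric keys into int epoch indices while keeping every other key as a string. A minimal sketch of how such a file could be produced; the hyperparameter names below (lr, weight_decay) are illustrative assumptions, not the project's actual schema:

import json

# Hypothetical regime: numeric keys become int epoch indices in the loading
# loop above, every other key is kept under its string name.
regime = {
    "lr": 1e-3,              # illustrative global hyperparameters
    "weight_decay": 1e-5,
    "10": {"lr": 5e-4},      # settings applied at epoch 10
    "50": {"lr": 1e-4},      # settings applied at epoch 50
}

with open("regime.json", "w") as f:
    json.dump(regime, f, indent=4)
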
Example 2
import argparse
import json
import logging

import nemo
import torch

# Project-local helpers (Parse, LoadData, HannaNet, ConvBlock, ModelManager,
# ModelTrainer) are assumed to come from the project's own modules.


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch FrontNet')
    args = Parse(parser)

    torch.manual_seed(args.seed)

    # [NeMO] Setup of console logging.
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s - %(levelname)s - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename="log.txt",
                        filemode='w')

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

    train_loader, validation_loader = LoadData(args)

    # [NeMO] Loading of the JSON regime file.
    regime = {}
    if args.regime is None:
        raise ValueError("Missing regime JSON.")
    with open(args.regime, "r") as f:
        rr = json.load(f)
    for k, v in rr.items():
        try:
            regime[int(k)] = v
        except ValueError:
            regime[k] = v

    model = HannaNet(ConvBlock, [1, 1, 1], True)

    # Input frame size (height x width, single channel) used for quantization and deployment.
    h = 96
    w = 160

    if args.trainq:
        epoch = ModelManager.Read(args.load_model, model)
        trainer = ModelTrainer(model, args, regime, teacher=True)
        trainer.TrainQuantized(train_loader, validation_loader, h, w,
                               args.epochs)

    if args.quantize and not args.trainq:
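        # [NeMO] Transform the model into its PACT quantization-aware
        # counterpart; the dummy input supplies the (1, 1, h, w) shape used to
        # trace the network.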
        model = nemo.transform.quantize_pact(model,
                                             dummy_input=torch.ones(
                                                 (1, 1, h, w)))
        logging.info("[ETHQ2] Model: %s", model)
        epoch, prec_dict = ModelManager.ReadQ(args.load_model, model)
        trainer = ModelTrainer(model, args, regime)
        trainer.Deploy(validation_loader, h, w, prec_dict)

    if args.save_model is not None:
        # torch.save(trainer.model.state_dict(), args.save_model)
        ModelManager.Write(trainer.GetModel(), 100, args.save_model)

    print(model)
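
The Parse helper is not shown in either example; judging from the attributes used above (args.seed, args.regime, args.load_model, args.save_model, args.gray, args.quantize, args.trainq, args.epochs) it registers roughly the flags below. This is a hypothetical reconstruction: flag spellings, types and defaults are assumptions, not the project's actual interface.

import argparse


def Parse(parser):
    # Hypothetical sketch of the missing Parse() helper; flag names, types and
    # defaults are assumed from the attributes the two main() functions use.
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--regime', default=None,
                        help='path to the JSON training regime')
    parser.add_argument('--load-model', dest='load_model', default=None)
    parser.add_argument('--save-model', dest='save_model', default=None)
    parser.add_argument('--gray', default=None,
                        help='use single-channel (grayscale) input')
    parser.add_argument('--quantize', action='store_true')
    parser.add_argument('--trainq', action='store_true')
    return parser.parse_args()


# Usage, mirroring the two main() functions above:
# parser = argparse.ArgumentParser(description='PyTorch FrontNet')
# args = Parse(parser)
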
Example 3
    def save_checkpoint(self, val_loss, model, epoch, file_name):
        """Saves the model when the validation loss decreases."""
        if self.verbose:
            logging.info("[EarlyStopping] Validation loss decreased {} --> {}. Saving model as {}".format(
                self.val_loss_min, val_loss, file_name))
        ModelManager.Write(model, epoch, file_name)
        self.val_loss_min = val_loss
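
save_checkpoint reads like a method of an early-stopping helper: it uses self.verbose and self.val_loss_min and delegates checkpointing to ModelManager.Write. A minimal sketch of what the surrounding class could look like, assuming a simple patience counter; the class layout and parameter names are hypothetical, only save_checkpoint above is from the source.

class EarlyStopping:
    # Hypothetical wrapper around save_checkpoint(): flags early_stop once the
    # validation loss has not improved for `patience` consecutive epochs.
    def __init__(self, patience=10, verbose=True):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.early_stop = False
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model, epoch, file_name):
        if val_loss < self.val_loss_min:
            # Improvement: checkpoint via save_checkpoint() (Example 3) and reset the counter.
            self.save_checkpoint(val_loss, model, epoch, file_name)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

A training loop would call the instance once per validation pass and stop when early_stop becomes True.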