def main():
    """Entry point for the FrontNet (Dronet) NeMO pipeline.

    Parses the command-line settings, configures logging (file + console),
    loads the JSON regime and the datasets, builds a Dronet model (optionally
    preloaded from ``args.load_model``), optionally quantizes it on the
    validation set, and optionally writes the resulting model to
    ``args.save_model``.

    Raises:
        ValueError: if no regime JSON path was supplied (``args.regime`` is None).
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch FrontNet')
    args = Parse(parser)
    torch.manual_seed(args.seed)

    # [NeMO] Setup of console logging: full timestamped records go to log.txt
    # (overwritten each run), bare messages are echoed to the console.
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s - %(levelname)s - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename="log.txt",
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('%(message)s'))
    logging.getLogger('').addHandler(console)

    # [NeMO] Loading of the JSON regime file. Validate it before loading the
    # datasets so a missing argument fails fast instead of after data loading.
    if args.regime is None:
        logging.error("ERROR!!! Missing regime JSON.")
        raise ValueError("Missing regime JSON (args.regime is None).")
    with open(args.regime, "r", encoding="utf-8") as f:
        rr = json.load(f)
    regime = {}
    for k, v in rr.items():
        # JSON object keys are always strings: keys that parse as integers are
        # stored as int, every other key is kept unchanged.
        try:
            regime[int(k)] = v
        except ValueError:
            regime[k] = v

    # Only the validation loader is used below; train/test loaders are ignored.
    _, validation_loader, _ = LoadData(args)

    # The third Dronet argument is True iff args.gray is set
    # (presumably a grayscale-input switch -- confirm against Dronet).
    model = Dronet(PreActBlock, [1, 1, 1], args.gray is not None)

    # [NeMO] Optionally preload the model with pretrained weights.
    if args.load_model is not None:
        ModelManager.Read(args.load_model, model)

    trainer = ModelTrainer(model, args, regime)
    if args.quantize:
        trainer.Quantize(validation_loader)

    if args.save_model is not None:
        # 100 is passed as the epoch argument of ModelManager.Write.
        ModelManager.Write(trainer.GetModel(), 100, args.save_model)
def main():
    """Entry point for the HannaNet NeMO quantization pipeline.

    Parses the command-line settings, configures logging (file + console),
    loads the JSON regime and the datasets, builds a HannaNet model, and then
    runs exactly one of two modes:

    * ``args.trainq``: load weights from ``args.load_model`` and run
      quantization-aware training via ``ModelTrainer.TrainQuantized``.
    * ``args.quantize`` (and not ``args.trainq``): convert the model with
      ``nemo.transform.quantize_pact``, load quantized weights/precisions from
      ``args.load_model`` and run ``ModelTrainer.Deploy``.

    If ``args.save_model`` is set and a mode ran, the trained/deployed model is
    written there. If no mode ran, nothing can be saved and a warning is logged;
    previously this path crashed with ``UnboundLocalError``. The final model is
    always printed.

    Raises:
        ValueError: if no regime JSON path was supplied (``args.regime`` is None).
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch FrontNet')
    args = Parse(parser)
    torch.manual_seed(args.seed)

    # [NeMO] Setup of console logging: full timestamped records go to log.txt
    # (overwritten each run), bare messages are echoed to the console.
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s - %(levelname)s - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename="log.txt",
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('%(message)s'))
    logging.getLogger('').addHandler(console)

    # [NeMO] Loading of the JSON regime file. Validate it before loading the
    # datasets so a missing argument fails fast instead of after data loading.
    if args.regime is None:
        logging.error("ERROR!!! Missing regime JSON.")
        raise ValueError("Missing regime JSON (args.regime is None).")
    with open(args.regime, "r", encoding="utf-8") as f:
        rr = json.load(f)
    regime = {}
    for k, v in rr.items():
        # JSON object keys are always strings: keys that parse as integers are
        # stored as int, every other key is kept unchanged.
        try:
            regime[int(k)] = v
        except ValueError:
            regime[k] = v

    train_loader, validation_loader = LoadData(args)

    model = HannaNet(ConvBlock, [1, 1, 1], True)
    # Input height/width; used for the (1, 1, h, w) dummy input below
    # (presumably N, C, H, W -- confirm against HannaNet).
    h = 96
    w = 160

    # Bound only when one of the two modes below runs.
    trainer = None
    if args.trainq:
        ModelManager.Read(args.load_model, model)
        trainer = ModelTrainer(model, args, regime, teacher=True)
        trainer.TrainQuantized(train_loader, validation_loader, h, w, args.epochs)
    elif args.quantize:
        model = nemo.transform.quantize_pact(model, dummy_input=torch.ones((1, 1, h, w)))
        logging.info("[ETHQ2] Model: %s", model)
        _, prec_dict = ModelManager.ReadQ(args.load_model, model)
        trainer = ModelTrainer(model, args, regime)
        trainer.Deploy(validation_loader, h, w, prec_dict)

    if args.save_model is not None:
        if trainer is None:
            logging.warning("Neither trainq nor quantize mode ran; "
                            "nothing to save to %s", args.save_model)
        else:
            # 100 is passed as the epoch argument of ModelManager.Write.
            ModelManager.Write(trainer.GetModel(), 100, args.save_model)
    print(model)
def save_checkpoint(self, val_loss, model, epoch, file_name):
    '''Write ``model`` to ``file_name`` and record ``val_loss`` as the new best.

    Intended to be called when the validation loss has improved. When
    ``self.verbose`` is truthy, the transition from the previous best loss
    (``self.val_loss_min``) to ``val_loss`` is logged first.

    Args:
        val_loss: the validation loss that triggered the checkpoint.
        model: the model handed to ``ModelManager.Write``.
        epoch: epoch number stored alongside the model.
        file_name: destination path for the checkpoint.
    '''
    if self.verbose:
        # str.format (not %-style) is kept on purpose: it formats values with
        # format(), which can differ from str() (e.g. 0-dim tensors).
        template = "[EarlyStopping] Validation loss decreased {old} --> {new}. Saving model as {path}"
        message = template.format(old=self.val_loss_min, new=val_loss, path=file_name)
        logging.info(message)

    ModelManager.Write(model, epoch, file_name)
    # Remember the improved loss so later comparisons use it as the baseline.
    self.val_loss_min = val_loss