import os

import cv2
import numpy as np
import tensorflow as tf
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local modules (import paths assumed; adjust to the repo layout):
# from dataset import Data
# from model import CNN_tf, Loss
# from eval import val


def train(path=None, log_path=None):
    """Train the CNN model (TensorFlow version).

    Args:
        path (str): checkpoint file path (currently unused in this version).
        log_path (str): log directory. Defaults to './log/<model_name>'.
    """
    """ ===== Constant var. start =====y"""
    train_comment = ""
    num_workers = 7
    batch_size = 128
    lr = 0.001
    lr_decay = 0.9
    lr_decay_step = 500
    max_epoch = 500
    stat_freq = 10
    model_name = "0304_tf"
    """ ===== Constant var. end ====="""

    # step0: init. log and checkpoint dir.
    checkpoints_dir = "./checkpoints/" + model_name
    if len(train_comment) > 0:
        checkpoints_dir = checkpoints_dir + "_" + train_comment
    if not os.path.isdir("./checkpoints"):
        os.mkdir("./checkpoints")
    if not os.path.isdir(checkpoints_dir):
        os.mkdir(checkpoints_dir)
    if log_path is None:
        if not os.path.isdir("./log"):
            os.mkdir("./log")
        if not os.path.exists("./log/" + model_name):
            os.makedirs("./log/" + model_name)
        log_path = "./log/{}".format(model_name)

    # step1: dataset
    val_data = Data(train=False, format="NHWC")
    val_dataloader = DataLoader(val_data, 100, num_workers=num_workers)
    train_data = Data(train=True, format="NHWC")
    train_dataloader = DataLoader(train_data,
                                  batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=False)

    writer = tf.summary.create_file_writer(log_path)
    best_acc_img = 0

    # step2: instance and load model
    inputs = tf.keras.Input(shape=(128, 128, 1))
    outputs = CNN_tf()(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.summary()

    # step3: loss function and optimizer
    criterion = Loss()
    scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
        lr, lr_decay_step, lr_decay)
    optimizer = tf.keras.optimizers.Adam(scheduler)

    global_step = tf.Variable(1)
    checkpoint = tf.train.Checkpoint(
        optimizer=optimizer,
        model=model,
        step=global_step,
    )
    ckpt_manager = tf.train.CheckpointManager(checkpoint,
                                              checkpoints_dir,
                                              max_to_keep=10)
    previous_loss = 1e100

    @tf.function
    def train_step(images, labels):
        # forward pass and loss under a gradient tape, then one optimizer step
        with tf.GradientTape() as tape:
            pred = model(images, training=True)
            loss = criterion(pred, labels)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    # epoch loop
    for epoch in range(max_epoch):
        running_loss = 0.0
        total_loss = []

        # batch loop
        pbar = tqdm(enumerate(train_dataloader))
        for i, (data, label) in pbar:
            images = data.numpy()
            target = label.numpy()

            loss = train_step(images, target)

            running_loss += loss
            total_loss.append(loss)
            if (i + 1) % stat_freq == 0:
                pbar.set_description(
                    "[%d, %5d] loss: %.3f" %
                    (epoch + 1, i + 1, running_loss / stat_freq))
                with writer.as_default():
                    tf.summary.scalar(
                        "train/loss",
                        running_loss / stat_freq,
                        step=epoch * len(train_dataloader) + i,
                    )
                    # tf.summary.scalar("train/lr", )
                running_loss = 0.0

        epoch_loss = np.mean(total_loss)

        acc_img, acc_digit, im_show = val(model, val_dataloader)
        im_show[0] = cv2.putText(im_show[0],
                                 "".join(Data.decode(im_show[1])), (20, 20), 2,
                                 1, (255, 0, 255))
        with writer.as_default():
            tf.summary.scalar("eval/acc_img", acc_img,
                              step=epoch * len(train_dataloader))
            tf.summary.scalar("eval/acc_digit", acc_digit,
                              step=epoch * len(train_dataloader))
            tf.summary.image("img", im_show[0][np.newaxis, :, :, :],
                             step=epoch * len(train_dataloader))

        if acc_img > best_acc_img:
            best_acc_img = acc_img
            # ckpt_manager.save()
            model.save(checkpoints_dir, save_format="tf")

        print("acc_img : {}, acc_digit : {}, loss : {}".format(
            acc_img, acc_digit, epoch_loss))

        # The ExponentialDecay schedule already lowers the learning rate every
        # lr_decay_step optimizer steps, so no manual update is needed here;
        # just report when the epoch loss stops improving.
        if epoch_loss > previous_loss:
            print("epoch loss did not improve ({:.4f} -> {:.4f})".format(
                previous_loss, epoch_loss))
        previous_loss = epoch_loss
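# --- Hedged usage sketch (added for illustration, not part of the original
# trainer): the best model above is exported via
# model.save(checkpoints_dir, save_format="tf"), i.e. as a SavedModel
# directory, so it can be reloaded for a quick inference smoke test. The
# default checkpoint path and the (128, 128, 1) grayscale NHWC input shape
# mirror the constants used in train(); adjust them if the project differs.
def load_and_predict(checkpoints_dir="./checkpoints/0304_tf"):
    restored = tf.keras.models.load_model(checkpoints_dir)
    # single dummy NHWC grayscale image, same shape as tf.keras.Input above
    dummy = np.zeros((1, 128, 128, 1), dtype=np.float32)
    return restored(dummy, training=False)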
import os

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # or tensorboardX, depending on the project
from tqdm import tqdm

# Project-local modules (import paths assumed; adjust to the repo layout):
# from dataset import Data
# from model import CNN, loss_
# from eval import val


def train(path=None, log_path=None):
    """Train the CNN model.

    Args:
        path (str): checkpoint file path to resume from.
        log_path (str): log directory. Defaults to './log/<model_name>'.
    """
    """ ===== Constant var. start ====="""
    train_comment = ""
    use_gpu = True
    num_workers = 7
    batch_size = 128
    lr = 0.001
    lr_decay = 0.9
    max_epoch = 500
    stat_freq = 10
    model_name = "test"
    """ ===== Constant var. end ====="""

    # step0: init. log and checkpoint dir.
    checkpoints_dir = "./checkpoints/" + model_name
    if len(train_comment) > 0:
        checkpoints_dir = checkpoints_dir + "_" + train_comment
    if not os.path.isdir("./checkpoints"):
        os.mkdir("./checkpoints")
    if not os.path.isdir(checkpoints_dir):
        os.mkdir(checkpoints_dir)
    if log_path is None:
        if not os.path.isdir("./log"):
            os.mkdir("./log")
        if not os.path.exists("./log/" + model_name):
            os.makedirs("./log/" + model_name)
        log_path = "./log/{}".format(model_name)

    # step1: dataset
    val_data = Data(train=False)
    val_dataloader = DataLoader(val_data, 100, num_workers=num_workers)
    train_data = Data(train=True)
    train_dataloader = DataLoader(train_data,
                                  batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=True)

    writer = SummaryWriter(log_path)
    best_acc_img = 0

    with open(log_path + "/log.txt", "w") as log_file:
        # step2: instance and load model
        model = CNN()
        if path is not None:
            print('using model "{}"'.format(path))
            print('using model "{}"'.format(path), file=log_file, flush=True)
            model.load(path)
        else:
            print("init model by orthogonal_", file=log_file, flush=True)
            for name, param in model.named_parameters():
                if len(param.shape) > 1:
                    torch.nn.init.orthogonal_(param)
        if use_gpu:
            model.cuda()
        # summary(model, (1, 128, 128))

        # step3: loss function and optimizer
        criterion = loss_
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        previous_loss = 1e100

        # epoch loop
        for epoch in range(max_epoch):
            running_loss = 0.0
            total_loss = []

            # batch loop
            pbar = tqdm(enumerate(train_dataloader))
            for i, (data, label) in pbar:
                inputs = data
                target = label
                if use_gpu:
                    inputs = inputs.cuda()
                    target = target.cuda()

                optimizer.zero_grad()
                score = model(inputs)
                loss = criterion(score, target)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                total_loss.append(loss.item())
                if (i + 1) % stat_freq == 0:
                    pbar.set_description(
                        "[%d, %5d] loss: %.3f" %
                        (epoch + 1, i + 1, running_loss / stat_freq))
                    writer.add_scalar("train/loss", running_loss / stat_freq,
                                      epoch * len(train_dataloader) + i)
                    running_loss = 0.0

            epoch_loss = np.mean(total_loss)

            acc_img, acc_digit, im_show = val(model, val_dataloader, use_gpu)
            writer.add_scalar("eval/acc_img", acc_img,
                              epoch * len(train_dataloader))
            writer.add_scalar("eval/acc_digit", acc_digit,
                              epoch * len(train_dataloader))
            im_show[0] = cv2.putText(im_show[0],
                                     "".join(Data.decode(im_show[1])),
                                     (20, 20), 2, 1, (255, 0, 255))
            writer.add_image("img",
                             torch.tensor(np.transpose(im_show[0], (2, 0, 1))),
                             epoch * len(train_dataloader))

            if acc_img > best_acc_img:
                best_acc_img = acc_img
                save_path = "{}/model.pth".format(checkpoints_dir)
                model.save(save_path)

            print("acc_img : {}, acc_digit : {}, loss : {}".format(
                acc_img, acc_digit, epoch_loss))

            # decay the learning rate when the epoch loss stops improving
            if epoch_loss > previous_loss:
                lr = lr * lr_decay
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr
                print("loss did not improve, reduce learning rate to {}".format(lr))
            previous_loss = epoch_loss
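# --- Hedged usage sketch (added for illustration, not part of the original
# trainer): rebuilding the model from the "<checkpoints_dir>/model.pth" file
# written above. It reuses the project's CNN() and its load() helper exactly
# as train() does; the default path below assumes model_name = "test" and is
# only an example.
def load_for_eval(path="./checkpoints/test/model.pth", use_gpu=True):
    model = CNN()
    model.load(path)  # same project helper that train() uses to resume
    if use_gpu:
        model.cuda()
    model.eval()  # switch off dropout / batch-norm updates for evaluation
    return model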