class Classifier:
    """Driver that wires a VOC-style data loader to a YOLO model and runs
    single-image/video detection plus the train / eval / test loops, as
    selected by flags in ``args``."""

    def __init__(self, args):
        """Echo the configuration, copy it, and build the data loader,
        class/color tables, output directories, and the YOLO model.

        Args:
            args: dict-like configuration. Keys read here: "dataset_path",
                "preload", "class_path", "color_path", "model_save_dir",
                "graph_save_dir" (plus more in :meth:`run`).
        """
        # Dump the full configuration for reproducibility of the run.
        print("args = {")
        for k in args:
            print("\t{} = {}".format(k, args[k]))
        print("}")
        # Defensive copy so later mutation of the caller's dict has no effect.
        self.args = args.copy()
        # Data loader is only built when a dataset path is configured;
        # detection-only runs (image/video paths) do not need one.
        if self.args["dataset_path"] != "":
            self.data_loader = VOCDataLoader(self.args["dataset_path"],
                                             num_processes=4,
                                             preload=self.args["preload"])
        self.classes = read_classes(self.args["class_path"])
        self.color_dict = get_color_dict(self.classes, self.args["color_path"])
        # Create output directories up front so saves later cannot fail on
        # a missing path.
        if self.args["model_save_dir"] != "" and not os.path.exists(
                self.args["model_save_dir"]):
            os.makedirs(self.args["model_save_dir"])
        if self.args["graph_save_dir"] != "" and not os.path.exists(
                self.args["graph_save_dir"]):
            os.makedirs(self.args["graph_save_dir"])
        self.yolo = YOLO()

    def run(self):
        """Execute the configured actions in order: one-off image/video
        detection, then (optionally) epoch-based training, evaluation, and
        testing. Returns None when none of the train/eval/test flags is set.
        """
        # One-shot detections, independent of the dataset loader.
        if self.args["image_detect_path"] != "":
            self.yolo.detect_image_and_show(self.args["image_detect_path"],
                                            self.color_dict, 0)
        if self.args["video_detect_path"] != "":
            self.yolo.detect_video_and_show(
                self.args["video_detect_path"],
                self.color_dict,
            )
        # Nothing else to do unless at least one phase was requested.
        if not any([
                self.args["do_train"], self.args["do_eval"],
                self.args["do_test"]
        ]):
            return None
        print('-' * 20 + 'Reading data' + '-' * 20, flush=True)
        # Only load the splits that will actually be used.
        data_train = self.data_loader.get_data_train(
        ) if self.args["do_train"] else []
        data_eval = self.data_loader.get_data_eval(
        ) if self.args["do_eval"] else []
        data_test = self.data_loader.get_data_test(
        ) if self.args["do_test"] else []
        print('-' * 20 + 'Preprocessing data' + '-' * 20, flush=True)
        # Replace each sample's image (element 0) with its preprocessed form;
        # preprocess_image returns a tuple, [0] keeps the image only.
        for data_id in range(len(data_train)):
            data_train[data_id][0] = self.yolo.preprocess_image(
                *(data_train[data_id]), cvt_RGB=True)[0]
        for data_id in range(len(data_eval)):
            data_eval[data_id][0] = self.yolo.preprocess_image(
                *(data_eval[data_id]), cvt_RGB=True)[0]
        for data_id in range(len(data_test)):
            data_test[data_id][0] = self.yolo.preprocess_image(
                *(data_test[data_id]), cvt_RGB=True)[0]
        if self.args["graph_save_dir"] != "":
            self.yolo.save_graph(self.args["graph_save_dir"])
        for epoch in range(self.args["num_epochs"]):
            if self.args["do_train"]:
                """Train"""
                print('-' * 20 + 'Training epoch %d' % epoch + '-' * 20,
                      flush=True)
                # Short pause so the banner prints before tqdm's bar starts.
                time.sleep(0.5)
                random.shuffle(data_train)  # shuffle the training data
                for start in tqdm(range(0, len(data_train),
                                        self.args["train_batch_size"]),
                                  desc='Training batch: '):
                    # Clamp the final batch to the end of the dataset.
                    end = min(start + self.args["train_batch_size"],
                              len(data_train))
                    loss = self.yolo.train(data_train[start:end])
                    print(loss)
                """Save current model"""
                # Checkpoint name: timestamp + epoch index, e.g.
                # 2020-01-01_12-00-00_3.pth
                if self.args["model_save_dir"] != "":
                    self.yolo.save(
                        os.path.join(
                            self.args["model_save_dir"],
                            time.strftime("%Y-%m-%d_%H-%M-%S",
                                          time.localtime()) + "_" +
                            str(epoch) + ".pth"))
            if self.args["do_eval"]:
                """Evaluate"""
                print('-' * 20 + 'Evaluating epoch %d' % epoch + '-' * 20,
                      flush=True)
                time.sleep(0.5)
                pred_results = []
                for start in tqdm(range(0, len(data_eval),
                                        self.args["eval_batch_size"]),
                                  desc='Evaluating batch: '):
                    end = min(start + self.args["eval_batch_size"],
                              len(data_eval))
                    # Predict on images only (element 0 of each sample).
                    pred_results += self.yolo.predict(
                        [data[0] for data in data_eval[start:end]],
                        num_processes=0)
                mmAP = self.yolo.get_mmAP(data_eval, pred_results)
                print("mmAP =", mmAP)
                # Eval-only run: one pass is enough, no further epochs.
                if not self.args["do_train"]:
                    break
            if self.args["do_test"]:
                pass
                # """Test"""
                # print('-' * 20 + 'Testing epoch %d' % epoch + '-' * 20, flush=True)
                # time.sleep(0.1)
                # m = metrics.Metrics(self.labels)
                # for start in tqdm(range(0, len(self.data_test), self.args.test_batch_size),
                #                   desc='Testing batch: '):
                #     images = [d[0] for d in self.data_test[start:start + self.args.test_batch_size]]
                #     actual_labels = [d[1] for d in self.data_test[start:start + self.args.test_batch_size]]
                #     """forward"""
                #     batch_images = torch.tensor(images, dtype=torch.float32)
                #     outputs = self.model(batch_images)
                #     """update confusion matrix"""
                #     pred_labels = outputs.softmax(1).argmax(1).tolist()
                #     m.update(actual_labels, pred_labels)
                # """testing"""
                # print(m.get_accuracy())
                # Test-only run: stop after the single pass.
                if not self.args["do_train"]:
                    break
            print()
def _main_(args):
    """Entry point: read the JSON config named by ``args.conf``, build a YOLO
    model, optionally load pretrained weights, and launch training.

    Args:
        args: parsed command-line namespace; only ``args.conf`` (path to a
            JSON configuration file) is read here.

    Raises:
        OSError: if the config file cannot be opened.
        json.JSONDecodeError: if the config file is not valid JSON.
    """
    config_path = args.conf

    with open(config_path) as config_buffer:
        config = json.load(config_buffer)

    ###############################
    #   Parse the annotations
    ###############################

    # parse annotations of the training set
    train_imgs, train_labels = parse_annotation(
        config['train']['train_annot_folder'],
        config['train']['train_image_folder'], config['model']['labels'])

    # parse annotations of the validation set, if any, otherwise split the training set
    if os.path.exists(config['valid']['valid_annot_folder']):
        valid_imgs, valid_labels = parse_annotation(
            config['valid']['valid_annot_folder'],
            config['valid']['valid_image_folder'], config['model']['labels'])
    else:
        # NOTE(review): this branch never assigns ``valid_labels``; it is not
        # used below, but any future use after this branch would raise
        # NameError — confirm before relying on it.
        train_valid_split = int(0.8 * len(train_imgs))
        # Shuffle before splitting so the 80/20 split is random.
        np.random.shuffle(train_imgs)

        valid_imgs = train_imgs[train_valid_split:]
        train_imgs = train_imgs[:train_valid_split]

    if len(set(config['model']['labels']).intersection(train_labels)) == 0:
        # BUG FIX: this was a Python 2 ``print`` statement, which is a
        # SyntaxError on Python 3 (the rest of the file uses print() calls).
        print("Labels to be detected are not present in the dataset! "
              "Please revise the list of labels in the config.json file!")
        return

    ###############################
    #   Construct the model
    ###############################

    yolo = YOLO(architecture=config['model']['architecture'],
                input_size=config['model']['input_size'],
                labels=config['model']['labels'],
                max_box_per_image=config['model']['max_box_per_image'],
                anchors=config['model']['anchors'])

    ###############################
    #   Load the pretrained weights (if any)
    ###############################

    if os.path.exists(config['train']['pretrained_weights']):
        yolo.load_weights(config['train']['pretrained_weights'])

    ###############################
    #   Start the training process
    ###############################

    yolo.train(train_imgs=train_imgs,
               valid_imgs=valid_imgs,
               train_times=config['train']['train_times'],
               valid_times=config['valid']['valid_times'],
               nb_epoch=config['train']['nb_epoch'],
               learning_rate=config['train']['learning_rate'],
               batch_size=config['train']['batch_size'],
               warmup_bs=config['train']['warmup_batches'],
               object_scale=config['train']['object_scale'],
               no_object_scale=config['train']['no_object_scale'],
               coord_scale=config['train']['coord_scale'],
               class_scale=config['train']['class_scale'],
               debug=config['train']['debug'])
# --- Training-loop script fragment ---------------------------------------
# NOTE(review): ``epochs``, ``lr``, ``optimizer``, ``model``, ``device``,
# ``loss_func``, ``train_generator`` and ``change_lr`` are defined elsewhere
# in this file (not visible here) — confirm before refactoring. The inner
# batch loop continues beyond this fragment (backward/step presumably follow
# ``optimizer.zero_grad()``).
start_time = time.ctime()
print("### start_time:", start_time)
cur_epoch = 0
for epoch in range(cur_epoch, epochs):
    # Step the learning rate down by 10x at these hard-coded epochs.
    if epoch == 74 or epoch == 104:
        print("Changing learning from {} to {}...".format(lr, lr * 0.1))
        lr = lr * 0.1
        change_lr(optimizer, lr)
    # Per-epoch bookkeeping: collected batch losses and wall-clock start.
    accumulated_train_loss = []
    start_timestamp = time.time()
    # Train
    model.train()  # switch to training mode (enables dropout/batch-norm updates)
    iteration = 0
    for batch_x, batch_y in tqdm(train_generator):
        # print("batch shape:", batch_x.shape)
        # Move the batch to the target device (CPU/GPU).
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        # Forward
        preds = model(batch_x)
        # compute loss
        loss = loss_func(preds, batch_y)
        # .item() detaches the scalar so the graph is not kept alive.
        accumulated_train_loss.append(loss.item())
        # zero gradients
        optimizer.zero_grad()