Example #1
    # dataset
    train_dataset, train_count = generate_dataset()

    net = YOLOV3(out_channels=3 * (CATEGORY_NUM + 5))
    print_model_summary(network=net)

    if load_weights_before_training:
        net.load_weights(filepath=save_model_dir +
                         "epoch-{}".format(load_weights_from_epoch))
        print("Successfully load weights!")
    else:
        load_weights_from_epoch = -1

    # loss and optimizer
    yolo_loss = YoloLoss()
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=0.001,
        decay_steps=5000,
        decay_rate=0.96,
        staircase=True)
    optimizer = tf.optimizers.SGD(learning_rate=lr_schedule, momentum=0.8)

    # metrics
    loss_metric = tf.metrics.Mean()

    def train_step(image_batch, label_batch):
        with tf.GradientTape() as tape:
            yolo_output = net(image_batch, training=True)
            loss = yolo_loss(y_true=label_batch, y_pred=yolo_output)
        gradients = tape.gradient(loss, net.trainable_variables)
        # apply the gradients and track the running mean of the loss
        optimizer.apply_gradients(
            grads_and_vars=zip(gradients, net.trainable_variables))
        loss_metric.update_state(values=loss)
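
The example ends before the loop that drives train_step. Below is a minimal sketch of such an epoch loop; it assumes an EPOCHS constant defined alongside CATEGORY_NUM and that train_dataset yields (image_batch, label_batch) pairs:

    # Sketch of a driver loop for train_step above; EPOCHS and the dataset's
    # batch layout are assumptions, not part of the original example.
    for epoch in range(load_weights_from_epoch + 1, EPOCHS):
        loss_metric.reset_states()
        for image_batch, label_batch in train_dataset:
            train_step(image_batch, label_batch)
        print("epoch: {}, loss: {:.5f}".format(epoch, loss_metric.result()))
        net.save_weights(filepath=save_model_dir + "epoch-{}".format(epoch),
                         save_format="tf")
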
Example #2

    check_point_encoder = torch.load(
        './models/yolo_lstm_flow_832_encoder_pretrained.tar',
        map_location=opt.device())
    check_point_decoder = torch.load(
        './models/yolo_lstm_flow_832_decoder_pretrained.tar',
        map_location=opt.device())

    opt.model.load_state_dict(check_point_encoder['model_state_dict'],
                              strict=True)
    opt.decoder.load_state_dict(check_point_decoder['model_state_dict'],
                                strict=True)

    # move all parameters to the device; the hidden states are not parameters
    opt.model.to(opt.device())
    opt.decoder.to(opt.device())
    opt.yolo_loss = YoloLoss(opt.model.anchors, filter_fkt=filter_non_zero_gt)
    loss_params = {
        "grid_shape": (opt.encoding_size, opt.encoding_size),
        "image_shape": (opt.image_size, opt.image_size),
        "path_anchors": "dataset_utils/anchors/anchors5.txt"
    }
    opt.pred_loss = NaiveLoss(loss_params)
    opt.pred_loss.anchors *= 2.0

    opt.optimizer_encoder = torch.optim.Adam(opt.model.parameters(),
                                             lr=opt.learning_rate * 0.1,
                                             betas=(opt.momentum, 0.999),
                                             weight_decay=opt.decay)
    opt.optimizer_decoder = torch.optim.Adam(opt.decoder.parameters(),
                                             lr=1e-4,
                                             betas=(opt.momentum, 0.999),
                                             weight_decay=opt.decay)  # assumed, mirroring the encoder optimizer
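
The checkpoints loaded above are plain dictionaries keyed by 'model_state_dict'. A sketch of the save side that would produce files in that format; the output paths are assumptions:

    # Save counterpart to the torch.load calls above (paths are assumed).
    torch.save({'model_state_dict': opt.model.state_dict()},
               './models/yolo_lstm_flow_832_encoder.tar')
    torch.save({'model_state_dict': opt.decoder.state_dict()},
               './models/yolo_lstm_flow_832_decoder.tar')
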
Example #3
def train(model,
          dataset,
          model_dir,
          summary_writer,
          epochs,
          lr,
          conf_thres,
          nms_thres,
          iou_thres,
          lambda_coord=5,
          lambda_no_obj=0.5,
          gradient_accumulations=2,
          clip_gradients=False,
          limit=None,
          debug=False,
          print_every=10,
          save_every=None,
          log_to_neptune=False):
    if log_to_neptune:
        env_path = Path(os.environ['HOME'], 'workspace/setup-box/neptune.env')
        load_dotenv(dotenv_path=env_path)

        neptune.init('petersiemen/sandbox',
                     api_token=os.getenv("NEPTUNE_API_TOKEN"))

    total = limit if limit is not None else len(dataset)

    logger.info(
        f'Start training on {total} images. Using lr: {lr}, '
        f'lambda_coord: {lambda_coord}, lambda_no_obj: {lambda_no_obj}, '
        f'conf_thres: {conf_thres}, nms_thres: {nms_thres}, '
        f'iou_thres: {iou_thres}, '
        f'gradient_accumulations: {gradient_accumulations}, '
        f'clip_gradients: {clip_gradients}')
    metrics = Metrics()

    model.to(DEVICE)
    model.train()

    optimizer = torch.optim.Adam(model.get_trainable_parameters(), lr=lr)
    grid_sizes = model.grid_sizes

    data_loader = DataLoader(dataset,
                             batch_size=dataset.batch_size,
                             shuffle=True,
                             collate_fn=dataset.collate_fn)
    class_names = model.class_names

    for epoch in range(1, epochs + 1):
        for batch_i, (images, ground_truth_boxes,
                      image_paths) in tqdm(enumerate(data_loader),
                                           total=total):
            if len(images) != dataset.batch_size:
                logger.warning(
                    f"Skipping batch {batch_i} because it does not have the expected size ({dataset.batch_size})"
                )
                continue

            images = images.to(DEVICE)

            coordinates, class_scores, confidence = model(images)

            obj_mask, noobj_mask, cls_mask, target_coordinates, target_confidence, target_class_scores = build_targets(
                coordinates, class_scores, ground_truth_boxes, grid_sizes)
            yolo_loss = YoloLoss(coordinates,
                                 confidence,
                                 class_scores,
                                 obj_mask,
                                 noobj_mask,
                                 cls_mask,
                                 target_coordinates,
                                 target_confidence,
                                 target_class_scores,
                                 lambda_coord=lambda_coord,
                                 lambda_no_obj=lambda_no_obj)

            class_scores = torch.sigmoid(class_scores)
            prediction = torch.cat(
                (coordinates, confidence.unsqueeze(-1), class_scores), -1)

            detections = non_max_suppression(prediction=prediction,
                                             conf_thres=conf_thres,
                                             nms_thres=nms_thres)

            ground_truth_map_objects = list(
                GroundTruth.from_ground_truths(image_paths,
                                               ground_truth_boxes))
            detection_map_objects = list(
                Detection.from_detections(image_paths, detections))

            metrics.add_detections_for_batch(detection_map_objects,
                                             ground_truth_map_objects,
                                             iou_thres=iou_thres)

            if debug:
                plot_batch(detections, ground_truth_boxes, images, class_names)

            loss = yolo_loss.get()
            # backward pass to calculate the weight gradients
            loss.backward()

            if clip_gradients:
                logger.debug("Clipping gradients with max_norm = 1")
                clip_grad_norm_(model.parameters(), max_norm=1)

            if batch_i % print_every == 0:  # log every print_every batches
                yolo_loss.capture(summary_writer, batch_i, during='train')
                #plot_weights_and_gradients(model, summary_writer, epoch * batch_i)
                log_performance(epoch, epochs, batch_i, total, yolo_loss,
                                metrics, class_names, summary_writer,
                                log_to_neptune)

            # step the optimizer only every `gradient_accumulations` batches
            if batch_i % gradient_accumulations == 0:
                logger.debug(
                    f"Updating weights for batch {batch_i} (gradient_accumulations :{gradient_accumulations})"
                )
                # update the weights
                optimizer.step()
                # zero the parameter (weight) gradients
                optimizer.zero_grad()

            del images
            del ground_truth_boxes

            if limit is not None and batch_i + 1 >= limit:
                logger.info(
                    f'Stopping after {batch_i + 1} batches (limit: {limit})')
                log_performance(epoch, epochs, batch_i, total, yolo_loss,
                                metrics, class_names, summary_writer,
                                log_to_neptune)
                save_model(model_dir, model, epoch, batch_i)
                return

            if save_every is not None and batch_i % save_every == 0:
                save_model(model_dir, model, epoch, batch_i)

        # save model after every epoch
        save_model(model_dir, model, epoch, None)
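
A usage sketch for train() above. Here `model` and `dataset` are assumed to come from the surrounding project (a YOLO network exposing grid_sizes, class_names, and get_trainable_parameters, and a dataset exposing batch_size and collate_fn), and the writer is assumed to be a TensorBoard SummaryWriter; only the keyword arguments are taken from the signature:

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('runs/yolo')  # assumed log directory
    train(model,
          dataset,
          model_dir='checkpoints',
          summary_writer=writer,
          epochs=5,
          lr=1e-4,
          conf_thres=0.5,
          nms_thres=0.4,
          iou_thres=0.5,
          save_every=1000)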