Example 1: training loop for the SSD_Flow model
# Requires torch, numpy, and waymo_open_dataset; the project-local names
# SSD_Flow, SSD_utils, Loss, Dataset, and collate_fn must also be in scope.
import os
from time import time

import numpy as np
import torch
from waymo_open_dataset import dataset_pb2

def train(train_path, model_save_path, num_epochs=3, model_load_path=None, patience=2, gamma=0.1, imp_thresh=0.001):
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  if model_load_path:
    model = SSD_Flow()
    state_dict = torch.load(model_load_path)
    model.load_state_dict(state_dict['model_state_dict'])
  else:
    model = SSD_Flow(initialize=True)
  model.freeze(freeze_add=False)
  model.train_setbn(bn_eval=False) # keep BN in train mode; set bn_eval=True when training with batch size 1
  model = model.to(device)

  trainable_params = [param for param in model.parameters() if param.requires_grad]
  optimizer = torch.optim.SGD(trainable_params, lr=0.001, momentum=0.9, weight_decay=0.0005)
  if model_load_path: optimizer.load_state_dict(state_dict['optimizer_state_dict'])
  
  start_epoch = state_dict['epoch']+1 if model_load_path else 0
  # prev_best_loss = state_dict['loss'] if model_load_path else float('inf')
  # num_epochs_no_imp = 0
  # imp_this_lr = False

  dboxes = SSD_utils.dboxes300_coco()
  loss_criterion = Loss(dboxes)

  scene_files = sorted(os.listdir(train_path))

  for epoch in range(start_epoch, num_epochs):
    
    epoch_loss = 0

    for scene_num, record_file in enumerate(scene_files):

      try:
        scene_data = Dataset(os.path.join(train_path, record_file), dataset_pb2.CameraName.FRONT, dboxes, augment=False, flow=True)
      except Exception: # skip unreadable/corrupted record files
        print(record_file, "Corrupted!")
        continue
      scene_dataloader = torch.utils.data.DataLoader(scene_data, batch_size=32, shuffle=False, collate_fn=collate_fn, drop_last=True)
      scene_loss = 0
      
      start = time()
      for frame_num, (_, _, images, bboxes, labels, flows) in enumerate(scene_dataloader):
        if images is None or len(labels) == 1: continue # batch size of 1 will cause error

        images = images.to(device)
        bboxes = bboxes.to(device)
        labels = labels.to(device)
        flows = flows.to(device)
        
        optimizer.zero_grad()
        pred_boxes, pred_scores = model(images, flows)
        loss = loss_criterion(pred_boxes, pred_scores, bboxes.transpose(1,2), labels)
        loss.backward()
        optimizer.step()
        scene_loss += loss.item()
      
      scene_loss /= len(scene_data)
      epoch_loss += scene_loss
      print("Scene", scene_num, "loss:", scene_loss)
      if (scene_num + 1) % 25 == 0: # periodic checkpoint every 25 scenes
        torch.save({'epoch': epoch, 'scene_num': scene_num, 'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(), 'loss': scene_loss},
                   os.path.join(model_save_path, f"ep_{epoch}_sc_{scene_num}"))

    epoch_loss /= len(scene_files)
    print("\nEpoch", epoch, "loss:", epoch_loss, "\n")
    torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(), 'loss': epoch_loss},
               os.path.join(model_save_path, f"ep_{epoch}"))
    
    np.random.shuffle(scene_files) # reshuffle scene order for the next epoch
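
A minimal usage sketch for this entry point; the paths below are placeholders rather than files from the original project:

if __name__ == '__main__':
  train(train_path='data/train', model_save_path='checkpoints', num_epochs=3)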
Example 2: Dataset excerpt and batch collation
                # (excerpt: the tail of Dataset.__getitem__)
                else:
                    im = self.transform(im)

                bboxes, labels = self.encoder.encode(bboxes, labels)

                if self.flows is not None:
                    return orig_size, index, im, bboxes, labels, self.flows[index]
                else:
                    return orig_size, index, im, bboxes, labels


# necessary to discard samples without any objects
def collate_fn(batch):
    batch = list(filter(lambda data: data is not None, batch))
    if len(batch) > 0:
        return torch.utils.data.dataloader.default_collate(batch)
    else:
        # every sample was None; callers unpack five values in the non-flow
        # pipeline (a flow pipeline unpacks six and would need (None,) * 6)
        return None, None, None, None, None
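
A quick check of the filtering behavior with stand-in samples (the shapes are arbitrary; 8732 is the SSD300 default-box count):

sample = (torch.zeros(2), 0, torch.zeros(3, 300, 300),
          torch.zeros(4, 8732), torch.zeros(8732))
out = collate_fn([sample, None])  # the None entry is dropped before collation
print(out[2].shape)               # torch.Size([1, 3, 300, 300])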


if __name__ == '__main__':
    tf.enable_eager_execution()  # TF 1.x API; newer TF versions are eager by default
    filename = 'data/train/segment-10206293520369375008_2796_800_2816_800_with_camera_labels.tfrecord'
    dboxes = SSD_utils.dboxes300_coco()
    dataset = Dataset(filename, open_dataset.CameraName.FRONT, dboxes)
    for i in range(190):
        if i != 150:
            continue  # inspect a single frame
        labels = dataset[i][4]  # labels sit at index 4 of the 5-tuple returned above
        # print(torch.sum(labels), labels.shape)
Example 3: inference with SSD_Ablation
def inference(model_path,
              eval_path,
              pred_save_path,
              view=dataset_pb2.CameraName.FRONT,
              metrics_type='native'):  # assumed parameter used below; 'lib' or 'native'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        model = SSD_Ablation()
        state_dict = torch.load(model_path)
        model.load_state_dict(state_dict['model_state_dict'])
        model.eval()
        model = model.to(device)

        dboxes = SSD_utils.dboxes300_coco()
        decoder = SSD_utils.Encoder(dboxes)
        scene_files = sorted(os.listdir(eval_path))

        pred_objects = metrics_pb2.Objects()

        for scene_num, record_file in enumerate(scene_files):

            scene_data = Dataset(os.path.join(eval_path, record_file),
                                 view,
                                 dboxes,
                                 val=True,
                                 flow=False)
            scene_dataloader = torch.utils.data.DataLoader(
                scene_data,
                batch_size=32,
                shuffle=False,
                collate_fn=collate_fn)

            for batch_num, (orig_size, idxs, images, bboxes,
                            labels) in enumerate(scene_dataloader):
                if images is None:
                    print("Empty batch; skipping")
                    continue
                images = images.to(device)
                bboxes = bboxes.to(device)
                labels = labels.to(device)
                #flows = flows.to(device)

                pred_boxes, pred_scores = model(images)
                preds = decoder.decode_batch(pred_boxes,
                                             pred_scores,
                                             criteria=0.5,
                                             max_output=100)

                if metrics_type == 'lib':
                    for i, pred in enumerate(preds):
                        create_prediction_file(pred, orig_size[i], record_file,
                                               idxs[i], pred_save_path)
                elif metrics_type == 'native':
                    for i, pred in enumerate(preds):
                        frame = scene_data.get_frame(idxs[i])
                        for j in range(len(pred[0])):
                            pred_objects.objects.append(
                                create_prediction_obj(frame, pred[0][j],
                                                      pred[2][j], pred[1][j],
                                                      view, orig_size[i]))
                    # rewrite the accumulated predictions after each batch
                    with open(pred_save_path, 'wb') as f:
                        f.write(pred_objects.SerializeToString())

            print("Finished scene", scene_num)
Example 4: SSD_Baseline training with LR stepping and early stopping
def train(train_path,
          model_save_path,
          num_epochs=3,
          model_load_path=None,
          patience=2,
          gamma=0.1,
          imp_thresh=0.001,
          set_size=25):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if model_load_path:
        model = SSD_Baseline()
        state_dict = torch.load(model_load_path)
        model.load_state_dict(state_dict['model_state_dict'])
    else:
        model = SSD_Baseline(initialize=True)
    model.freeze()
    model.train_setbn(
        bn_eval=False)  # keep BN in train mode; set bn_eval=True when training with batch size 1
    model = model.to(device)

    trainable_params = [
        param for param in model.parameters() if param.requires_grad
    ]
    optimizer = torch.optim.SGD(trainable_params,
                                lr=0.001,
                                momentum=0.9,
                                weight_decay=0.0005)
    if model_load_path:
        optimizer.load_state_dict(state_dict['optimizer_state_dict'])
    for g in optimizer.param_groups:
        print("LR:", g['lr'])  # log the (possibly resumed) learning rate

    start_epoch = state_dict['epoch'] + 1 if model_load_path else 0
    prev_best_loss = state_dict['loss'] if model_load_path else float('inf')
    num_sets_no_imp = 0
    imp_this_lr = False

    dboxes = SSD_utils.dboxes300_coco()
    loss_criterion = Loss(dboxes)

    scene_files = sorted(os.listdir(train_path))

    for epoch in range(start_epoch, num_epochs):

        np.random.shuffle(scene_files)
        epoch_loss = 0
        set_loss = 0

        for scene_num, record_file in enumerate(scene_files):
            print(record_file)
            try:
                scene_data = Dataset(os.path.join(train_path, record_file),
                                     dataset_pb2.CameraName.FRONT,
                                     dboxes,
                                     augment=False)
            except Exception:  # skip unreadable/corrupted record files
                print(record_file, "Corrupted!")
                continue
            scene_dataloader = torch.utils.data.DataLoader(
                scene_data,
                batch_size=32,
                shuffle=True,
                collate_fn=collate_fn,
                drop_last=True)
            scene_loss = 0

            start = time()
            for frame_num, (_, _, images, bboxes,
                            labels) in enumerate(scene_dataloader):
                if images is None or len(labels) == 1:
                    continue  # empty batch, or batch of 1 (breaks the loss/BN)

                images = images.to(device)
                bboxes = bboxes.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                pred_boxes, pred_scores = model(images)
                loss = loss_criterion(pred_boxes, pred_scores,
                                      bboxes.transpose(1, 2), labels)
                loss.backward()
                optimizer.step()
                scene_loss += loss.item()

            scene_loss /= len(scene_data)
            epoch_loss += scene_loss
            set_loss += scene_loss
            print("Scene", scene_num, "loss:", scene_loss)
            if (scene_num + 1) % set_size == 0:
                set_loss /= set_size
                print("\nSet loss", set_loss, '\n')
                torch.save(
                    {
                        'epoch': epoch,
                        'set': scene_num // set_size,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': set_loss
                    },
                    os.path.join(model_save_path,
                                 "set_" + str(scene_num // set_size)))
                # LR stepping and early stopping
                if set_loss > prev_best_loss - imp_thresh:
                    num_sets_no_imp += 1
                    if num_sets_no_imp == patience:
                        if imp_this_lr:
                            # decay the LR and restart the patience window
                            for g in optimizer.param_groups:
                                g['lr'] *= gamma
                            imp_this_lr = False
                            num_sets_no_imp = 0
                        else:
                            return  # no improvement at this LR either: stop early
                else:
                    num_sets_no_imp = 0
                    imp_this_lr = True
                    prev_best_loss = set_loss
                set_loss = 0

        epoch_loss /= len(scene_files)
        print("\nEpoch", epoch, "loss:", epoch_loss, "\n")
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss
            }, os.path.join(model_save_path, "epoch_" + str(epoch)))
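
As with Example 1, a hedged invocation sketch; the paths and hyperparameter values below are placeholders:

train(train_path='data/train',
      model_save_path='checkpoints',
      num_epochs=3,
      patience=2,        # sets without improvement before LR step / early stop
      gamma=0.1,         # LR decay factor
      imp_thresh=0.001,  # minimum loss improvement that counts
      set_size=25)       # scenes per checkpoint set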