Esempio n. 1
0
    # Evaluation loaders for the "val" and "test" sub-directories of
    # args.path_to_data. Both use identical loader settings: one sample per
    # batch, a single worker and no shuffling.
    evaluation_loader_options = dict(
        collate_fn=collate_function_cell_instance_segmentation,
        batch_size=1,
        num_workers=1,
        shuffle=False,
    )
    # augmentation_p=0.0 presumably disables augmentation for evaluation -- TODO confirm
    validation_split = CellInstanceSegmentation(
        path=os.path.join(args.path_to_data, "val"),
        augmentation_p=0.0,
        two_classes=not args.three_classes,
    )
    validation_dataset = DataLoader(validation_split, **evaluation_loader_options)
    test_split = CellInstanceSegmentation(
        path=os.path.join(args.path_to_data, "test"),
        augmentation_p=0.0,
        two_classes=not args.three_classes,
    )
    test_dataset = DataLoader(test_split, **evaluation_loader_options)
    # Model wrapper
    # The loss is assembled first so the wrapper call below stays flat.
    # NOTE(review): `ohem_faction` looks like a misspelling of "fraction", but it
    # is the keyword this call site uses -- confirm against
    # InstanceSegmentationLoss before renaming.
    instance_segmentation_loss = InstanceSegmentationLoss(
        segmentation_loss=SegmentationLoss(),
        ohem=args.ohem,
        ohem_faction=args.ohem_fraction,
    )
    model_wrapper = ModelWrapper(
        detr=detr,
        detr_optimizer=detr_optimizer,
        detr_segmentation_optimizer=detr_segmentation_optimizer,
        training_dataset=training_dataset,
        validation_dataset=validation_dataset,
        test_dataset=test_dataset,
        loss_function=instance_segmentation_loss,
        device=device,
    )
    # Each phase runs only when its `args` attribute is truthy; the checks are
    # independent, so several phases can run back to back in this order.

    # Training
    if args.train:
        model_wrapper.train(
            epochs=args.epochs,
            optimize_only_segmentation_head_after_epoch=(
                args.only_train_segmentation_head_after_epoch
            ),
        )

    # Validation (asks the wrapper for 30 plots)
    if args.val:
        model_wrapper.validate(number_of_plots=30)

    # Testing
    if args.test:
        model_wrapper.test()
        shuffle=True,
        collate_fn=data.image_label_list_of_masks_collate_function)
    # Validation data for the FID evaluation, read from the 'val.txt' index.
    # max_length=6000 presumably caps the number of samples -- TODO confirm
    # against data.Places365.
    fid_validation_subset = data.Places365(
        path_to_index_file=args.path_to_places365,
        index_file_name='val.txt',
        max_length=6000,
        validation=True,
    )
    # NOTE(review): num_workers is tied to the batch size here; confirm this is
    # intentional rather than a copy/paste of the batch_size argument.
    validation_dataset_fid = DataLoader(
        fid_validation_subset,
        batch_size=args.batch_size,
        num_workers=args.batch_size,
        shuffle=False,
        collate_fn=data.image_label_list_of_masks_collate_function,
    )
    # The same index file again, as a bare dataset (no DataLoader) built with
    # the dataset's default options.
    validation_dataset = data.Places365(
        path_to_index_file=args.path_to_places365,
        index_file_name='val.txt',
    )
    # Init model wrapper -- keyword arguments grouped by role.
    model_wrapper = ModelWrapper(
        # Networks
        generator=generator,
        discriminator=discriminator,
        vgg16=vgg16,
        # Optimizers
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        # Data
        training_dataset=training_dataset,
        validation_dataset=validation_dataset,
        validation_dataset_fid=validation_dataset_fid,
    )
    # Training runs when args.train is truthy.
    if args.train:
        model_wrapper.train(
            epochs=args.epochs,
            device=args.device,
        )
    # Testing runs when args.test is truthy: print the value returned by
    # validate() (labelled FID), then run inference on the same device.
    if args.test:
        fid_score = model_wrapper.validate(device=args.device)
        print('FID=', fid_score)
        model_wrapper.inference(device=args.device)