shuffle=True,
        collate_fn=data.image_label_list_of_masks_collate_function)
    # Validation loader used for the FID metric: iterates Places365 'val.txt'
    # in validation mode, unshuffled, with the same batch size and collate
    # function as training. max_length=6000 presumably caps the number of
    # samples used — confirm in data.Places365.
    validation_dataset_fid = DataLoader(
        data.Places365(path_to_index_file=args.path_to_places365,
                       index_file_name='val.txt',
                       max_length=6000,
                       validation=True),
        batch_size=args.batch_size,
        # NOTE(review): the worker count is taken from args.batch_size rather
        # than a dedicated worker-count option — confirm this is intentional.
        num_workers=args.batch_size,
        shuffle=False,
        collate_fn=data.image_label_list_of_masks_collate_function)
    # Raw (not DataLoader-wrapped) dataset over the same 'val.txt' index.
    # Unlike the FID loader above, no max_length and no validation=True are
    # passed here.
    validation_dataset = data.Places365(
        path_to_index_file=args.path_to_places365, index_file_name='val.txt')
    # Init model wrapper: bundles the networks, the data sources and the two
    # optimizers (generator, discriminator, vgg16, training_dataset and the
    # optimizers are defined outside this excerpt).
    model_wrapper = ModelWrapper(
        generator=generator,
        discriminator=discriminator,
        vgg16=vgg16,
        training_dataset=training_dataset,
        validation_dataset=validation_dataset,
        validation_dataset_fid=validation_dataset_fid,
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer)
    # Perform training when the `train` flag is set.
    if args.train:
        model_wrapper.train(epochs=args.epochs, device=args.device)
    # Perform testing: print the value returned by validate() (labelled as
    # FID), then run inference on the same device.
    if args.test:
        print('FID=', model_wrapper.validate(device=args.device))
        model_wrapper.inference(device=args.device)
Exemplo n.º 2
0
    # Model wrapper: bundles the DETR model, its two optimizers (full model and
    # segmentation part), the training/validation/test datasets, the loss
    # function and the target device.
    model_wrapper = ModelWrapper(detr=detr,
                                 detr_optimizer=detr_optimizer,
                                 detr_segmentation_optimizer=detr_segmentation_optimizer,
                                 training_dataset=training_dataset,
                                 validation_dataset=validation_dataset,
                                 test_dataset=test_dataset,
                                 # OHEM is configured from args.ohem / args.ohem_fraction.
                                 # NOTE(review): the keyword is spelled `ohem_faction`; keep it
                                 # unless InstanceSegmentationLoss's signature is confirmed to
                                 # use `ohem_fraction` instead.
                                 loss_function=InstanceSegmentationLoss(
                                     segmentation_loss=SegmentationLoss(),
                                     ohem=args.ohem,
                                     ohem_faction=args.ohem_fraction),
                                 device=device)

    # Perform training. Judging by the keyword name, only the segmentation
    # head is optimized after epoch args.only_train_segmentation_head_after_epoch
    # — confirm in ModelWrapper.train.
    if args.train:
        model_wrapper.train(epochs=args.epochs,
                            optimize_only_segmentation_head_after_epoch=args.only_train_segmentation_head_after_epoch)
    # Perform validation with number_of_plots=30.
    if args.val:
        model_wrapper.validate(number_of_plots=30)

    # Perform testing
    if args.test:
        model_wrapper.test()

    # Perform inference
    if args.inference:
        model_wrapper.inference()