# Source and target datasets differ only in dataset name and split;
# share the remaining construction arguments.
_dataset_kwargs = dict(img_transform=img_transform,
                       label_transform=label_transform,
                       test=False,
                       input_ch=args.input_ch)

src_dataset = get_dataset(dataset_name=args.src_dataset,
                          split=args.src_split,
                          **_dataset_kwargs)
tgt_dataset = get_dataset(dataset_name=args.tgt_dataset,
                          split=args.tgt_split,
                          **_dataset_kwargs)

# One loader iterates source and target samples jointly, reshuffled
# every epoch; pinned memory speeds up host-to-GPU copies.
train_loader = torch.utils.data.DataLoader(
    ConcatDataset(src_dataset, tgt_dataset),
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=True)

# Per-class loss weights read from file; background weight is included
# only when args.add_bg_loss is set.
weight = get_class_weight_from_file(n_class=args.n_class,
                                    weight_filename=args.loss_weights_file,
                                    add_bg_loss=args.add_bg_loss)

# Move the networks and the class-weight tensor to the GPU if present.
if torch.cuda.is_available():
    model_g_3ch.cuda()
    model_g_1ch.cuda()
    model_f1.cuda()
    model_f2.cuda()
    weight = weight.cuda()

criterion = CrossEntropyLoss2d(
# Optional joint (image + label) augmentation, only when cropping is on.
joint_transform = (get_joint_transform(crop_size=args.crop_size,
                                       rotate_angle=args.rotate_angle)
                   if use_crop else None)

# Input-image and label preprocessing pipelines.
img_transform = get_img_transform(img_shape=train_img_shape,
                                  normalize_way=args.normalize_way,
                                  use_crop=use_crop)
label_transform = get_lbl_transform(img_shape=train_img_shape,
                                    n_class=args.n_class,
                                    background_id=args.background_id,
                                    use_crop=use_crop)

# Source/target datasets share every argument except name and split.
_ds_common = dict(img_transform=img_transform,
                  label_transform=label_transform,
                  test=False,
                  input_ch=args.input_ch)
src_dataset = get_dataset(dataset_name=args.src_dataset,
                          split=args.src_split, **_ds_common)
tgt_dataset = get_dataset(dataset_name=args.tgt_dataset,
                          split=args.tgt_split, **_ds_common)

# Single loader over the concatenated source+target data.
train_loader = torch.utils.data.DataLoader(
    ConcatDataset(src_dataset, tgt_dataset),
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=True)

# Class weighting for the segmentation loss.
weight = get_class_weight_from_file(n_class=args.n_class,
                                    weight_filename=args.loss_weights_file,
                                    add_bg_loss=args.add_bg_loss)

# GPU placement for the encoder/decoder pair and the weight tensor.
if torch.cuda.is_available():
    model_enc.cuda()
    model_dec.cuda()
    weight = weight.cuda()

# Put both halves of the network into training mode.
model_enc.train()
model_dec.train()
def train_and_test(args: argparse.Namespace):
    """Train the multi-modal Medusa network (unless ``args.test``) and
    evaluate it on the test split with a k-NN classifier over embeddings.

    Configuration is read from the YAML file at ``args.param_file``;
    selected values (epochs, out_size, dropout, hidden size, margin,
    semi_hard, lr) can be overridden from the command line.  Returns the
    test accuracy.
    """
    param_config = load_yaml(args.param_file, append=False)

    # Select device
    cuda_device = 'cuda:%d' % args.gpu
    device = torch.device(cuda_device if torch.cuda.is_available() else 'cpu')

    # Generic arguments
    # CLI --epochs takes precedence over the YAML setting when given.
    num_epochs = param_config.get('general').get(
        'num_epochs') if args.epochs is None else args.epochs
    num_neighbors = param_config.get('general').get('num_neighbors')

    # Load the selected dataset
    # Resolve the dataset class by name from the `datasets` module.
    selected_dataset = getattr(datasets,
                               param_config.get('dataset').get('class_name'))

    # Initiate datasets and loaders for each modality
    train_inertial, val_inertial, test_inertial = get_train_val_test_datasets(
        selected_dataset, 'inertial', param_config)
    train_sdfdi, val_sdfdi, test_sdfdi = get_train_val_test_datasets(
        selected_dataset, 'sdfdi', param_config)
    # The skeleton modality is optional and gated by the config.
    if param_config.get('modalities').get('skeleton'):
        train_skeleton, val_skeleton, test_skeleton = get_train_val_test_datasets(
            selected_dataset, 'skeleton', param_config)
    train_datasets = [train_inertial, train_sdfdi]
    val_datasets = [val_inertial, val_sdfdi]
    test_datasets = [test_inertial, test_sdfdi]
    if param_config.get('modalities').get('skeleton'):
        train_datasets.append(train_skeleton)
        val_datasets.append(val_skeleton)
        test_datasets.append(test_skeleton)
    # Prepare concat datasets and loaders
    train_dataset = ConcatDataset(*train_datasets)
    val_dataset = ConcatDataset(*val_datasets)
    test_dataset = ConcatDataset(*test_datasets)
    num_actions = len(train_dataset.datasets[0].actions)
    batch_size = param_config.get('general').get('batch_size')
    shuffle = param_config.get('general').get('shuffle')
    # Class-balanced batches for triplet mining: n_samples per class per
    # batch (batch_sampler replaces batch_size/shuffle for this loader).
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_sampler=BalancedSampler(
            labels=train_dataset.labels,
            n_classes=num_actions,
            n_samples=param_config.get('general').get('num_samples')))
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=batch_size,
                            shuffle=shuffle)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=batch_size,
                             shuffle=shuffle)
    class_names = train_dataset.get_class_names()

    # Load medusa network
    # Per-modality sub-network kwargs come straight from the YAML config.
    n1_kwargs = param_config.get('modalities').get('inertial').get(
        'model').get('kwargs')
    n2_kwargs = param_config.get('modalities').get('sdfdi').get('model').get(
        'kwargs')
    n3_kwargs = None
    if param_config.get('modalities').get('skeleton'):
        n3_kwargs = param_config.get('modalities').get('skeleton').get(
            'model').get('kwargs')
    mlp_kwargs = param_config.get('general').get('mlp_kwargs')
    if args.out_size:
        n1_kwargs['out_size'] = args.out_size
        n2_kwargs['out_size'] = args.out_size
        if param_config.get('modalities').get('skeleton'):
            n3_kwargs['out_size'] = args.out_size
        mlp_kwargs['out_size'] = args.out_size
        # Also adjust the input of the mlp due to the change in out_size
        # NOTE(review): the factor 3 assumes three concatenated modality
        # embeddings even when the skeleton modality is disabled — confirm
        # against Medusa's forward() expectations.
        mlp_kwargs['input_size'] = 3 * args.out_size
    if args.dr:
        mlp_kwargs['dropout_rate'] = args.dr
    if args.mlp_hidden_size:
        mlp_kwargs['hidden_size'] = args.mlp_hidden_size

    model = Medusa(mlp_kwargs, n1_kwargs, n2_kwargs, n3_kwargs)
    # In test-only mode, restore previously trained weights.
    if args.test:
        model.load_state_dict(torch.load(args.saved_state))
    model = model.to(device)

    # Criterion, optimizer
    # `criterion`/`optimizer` first hold class-name strings from the YAML,
    # then are rebound below to the instantiated objects.
    criterion = param_config.get('general').get('criterion').get('class_name')
    criterion_from = param_config.get('general').get('criterion').get(
        'from_module')
    criterion_kwargs = param_config.get('general').get('criterion').get(
        'kwargs')
    optimizer = param_config.get('general').get('optimizer').get('class_name')
    optimizer_from = param_config.get('general').get('optimizer').get(
        'from_module')
    optimizer_kwargs = param_config.get('general').get('optimizer').get(
        'kwargs')
    if args.margin:
        criterion_kwargs['margin'] = args.margin
    if args.semi_hard is not None:
        criterion_kwargs['semi_hard'] = args.semi_hard
    if args.lr:
        optimizer_kwargs['lr'] = args.lr
    criterion = getattr(importlib.import_module(criterion_from),
                        criterion)(**criterion_kwargs)
    optimizer = getattr(importlib.import_module(optimizer_from),
                        optimizer)(model.parameters(), **optimizer_kwargs)

    if not args.test:
        # Timestamped experiment name unless one was given explicitly.
        # (The local name `datetime` shadows the stdlib module name.)
        if args.experiment is None:
            datetime = time.strftime("%Y%m%d_%H%M", time.localtime())
            experiment = '%s_medusa' % datetime
        else:
            experiment = args.experiment
        writer = SummaryWriter('../logs/' + experiment)

        train_losses, val_losses, val_accuracies, train_accuracies = train_triplet_loss(
            model,
            criterion,
            optimizer,
            class_names,
            train_loader,
            val_loader,
            num_epochs,
            device,
            experiment,
            num_neighbors,
            writer,
            verbose=True,
            skip_accuracy=args.skip_accuracy)

        # Save last state of model
        save_model(model, '%s_last_state.pt' % experiment)

    # Embed train/test sets and classify each test embedding by k-NN
    # over the training embeddings.
    cm, test_acc, test_scores, test_labels = get_predictions_with_knn(
        n_neighbors=num_neighbors,
        train_loader=train_loader,
        test_loader=test_loader,
        model=model,
        device=device)

    cm_image = plot_confusion_matrix(cm=cm,
                                     title='Confusion Matrix- Test Loader',
                                     normalize=False,
                                     save=False,
                                     show_figure=False,
                                     classes=test_dataset.get_class_names())
    # TensorBoard artifacts only exist when we actually trained.
    if not args.test:
        writer.add_images('ConfusionMatrix/Test',
                          cm_image,
                          dataformats='CHW',
                          global_step=num_epochs - 1)
        writer.add_embedding(
            test_scores,
            metadata=[class_names[idx] for idx in test_labels.int().tolist()],
            tag="test (%f%%)" % test_acc)
        writer.add_text('config', json.dumps(param_config, indent=2))
        writer.add_text('args', json.dumps(args.__dict__, indent=2))
        writer.flush()
        writer.close()

    if args.print_tsne or args.save_tsne:
        train_scores, train_labels = get_predictions(train_loader,
                                                     model,
                                                     device,
                                                     apply_softmax=False)
        # t-SNE runs on CPU tensors.
        if device.type == 'cuda':
            train_scores = train_scores.cpu()
            train_labels = train_labels.cpu()
        # train_labels appear to be one-hot (hence argmax); test_labels are
        # already class indices — TODO confirm with get_predictions.
        run_tsne(train_scores,
                 train_labels.argmax(1),
                 class_names,
                 filename='train_medusa_embeddings.png',
                 save=args.save_tsne,
                 show=args.print_tsne)
        run_tsne(test_scores,
                 test_labels,
                 class_names,
                 filename='test_medusa_embeddings.png',
                 save=args.save_tsne,
                 show=args.print_tsne)
    print('Test acc: %.5f' % test_acc)

    return test_acc
Example #4
0
File: solver.py  Project: kukuruza/MCD_DA
    def __init__(self,
                 args,
                 batch_size=64,
                 source='svhn',
                 target='mnist',
                 learning_rate=0.0002,
                 interval=100,
                 optimizer='adam',
                 num_k=4,
                 all_use=False,
                 checkpoint_dir=None,
                 save_epoch=10):
        """Set up datasets/loaders and the G/C1/C2 networks for MCD training.

        ``source``/``target`` select the domain pair (the 'citycam' pair gets
        a dedicated dataset pipeline), ``num_k`` is the number of generator
        updates per step, ``interval`` the logging interval and ``save_epoch``
        the checkpoint frequency.  When ``args.eval_only`` is set, previously
        saved weights for all three sub-networks are restored from
        ``checkpoint_dir``.
        """
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        # SVHN inputs are rescaled; other sources are used as-is.
        if self.source == 'svhn':
            self.scale = True
        else:
            self.scale = False
        print('dataset loading')
        if self.source == 'citycam' or self.target == 'citycam':
            # The citycam pipeline lives in the sibling segmentation package,
            # hence the local sys.path manipulation and deferred imports.
            import sys, os
            sys.path.append(
                os.path.join(os.path.dirname(__file__), '..', 'segmentation'))
            from transform import ReLabel, ToLabel, Scale, RandomSizedCrop, RandomHorizontalFlip, RandomRotation
            from PIL import Image
            from torchvision.transforms import Compose, Normalize, ToTensor
            from datasets import ConcatDataset, get_dataset, check_src_tgt_ok
            from models.model_util import get_models, get_optimizer

            train_img_shape = (
                64, 64)  #  tuple([int(x) for x in args.train_img_shape])
            # ImageNet-style normalization after resize + tensor conversion.
            img_transform_list = [
                Scale(train_img_shape, Image.BILINEAR),
                ToTensor(),
                Normalize([.485, .456, .406], [.229, .224, .225])
            ]
            #            if args.augment:
            #                aug_list = [
            #                    RandomRotation(),
            #                    RandomHorizontalFlip(),
            #                    RandomSizedCrop()
            #                ]
            #                img_transform_list = aug_list + img_transform_list

            img_transform = Compose(img_transform_list)

            label_transform = Compose([
                Scale(train_img_shape, Image.NEAREST),
                ToLabel(),
                ReLabel(
                    255, 12
                )  # args.n_class - 1),  # Last Class is "Void" or "Background" class
            ])

            # Labelled synthetic source split.
            src_dataset = get_dataset(dataset_name='citycam',
                                      split='synthetic-Sept19',
                                      img_transform=img_transform,
                                      label_transform=label_transform,
                                      test=False,
                                      input_ch=3,
                                      keys_dict={
                                          'image': 'S_image',
                                          'yaw': 'S_label'
                                      })

            # Unlabelled real target split (images only).
            tgt_dataset = get_dataset(dataset_name='citycam',
                                      split='real-Sept23-train',
                                      img_transform=img_transform,
                                      label_transform=label_transform,
                                      test=False,
                                      input_ch=3,
                                      keys_dict={'image': 'T_image'})

            self.datasets = torch.utils.data.DataLoader(
                ConcatDataset([src_dataset, tgt_dataset]),
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=True)

            # Test split restricted to objects that have a yaw annotation.
            dataset_test = get_dataset(
                dataset_name='citycam',
                split=
                'real-Sept23-test, objectid IN (SELECT objectid FROM properties WHERE key="yaw")',
                img_transform=img_transform,
                label_transform=label_transform,
                test=False,
                input_ch=3,
                keys_dict={
                    'image': 'T_image',
                    'yaw': 'T_label',
                    'yaw_raw': 'T_label_deg'
                })

            self.dataset_test = torch.utils.data.DataLoader(
                dataset_test,
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=True)

        else:
            from datasets_dir.dataset_read import dataset_read
            self.datasets, self.dataset_test = dataset_read(
                source,
                target,
                self.batch_size,
                scale=self.scale,
                all_use=self.all_use)

        self.G = Generator(source=source, target=target)
        print('load finished!')
        self.C1 = Classifier(source=source, target=target)
        self.C2 = Classifier(source=source, target=target)
        if args.eval_only:
            # Restore all three sub-networks from the requested epoch's
            # checkpoints.  (The original called the nonexistent attribute
            # ``self.G.torch.load``, re-"loaded" the G checkpoint three
            # times, and the second format string had five arguments for
            # four placeholders — a guaranteed runtime error.)
            self.G.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_G.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))
            self.C1.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_C1.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))
            self.C2.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_C2.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))

        if torch.cuda.is_available():
            self.G.cuda()
            self.C1.cuda()
            self.C2.cuda()
        self.interval = interval

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate
Example #5
0
                              input_ch=args.input_ch,
                              keys_dict={
                                  'image': 'T_image',
                                  'mask': 'T_mask',
                                  'index': 'T_index'
                              })
else:
    tgt_dataset = get_dataset(dataset_name=args.tgt_dataset,
                              split=args.tgt_split,
                              img_transform=img_transform,
                              label_transform=label_transform,
                              test=False,
                              input_ch=args.input_ch,
                              keys_dict={'image': 'T_image'})

# Joint source+target dataset; one loader shuffles across both domains.
concat_dataset = ConcatDataset([src_dataset, tgt_dataset])
_loader_opts = dict(batch_size=args.batch_size,
                    shuffle=True,
                    pin_memory=True)
train_loader = torch.utils.data.DataLoader(concat_dataset, **_loader_opts)

# Shared feature generator plus the two task classifiers.
model_g, model_f1, model_f2 = get_models(
    net_name=args.net,
    res=args.res,
    input_ch=args.input_ch,
    n_class=args.n_class,
    is_data_parallel=args.is_data_parallel,
    yaw_loss=args.yaw_loss)

optimizer_g = get_optimizer(model_g.parameters(),
                            lr=args.lr,
Example #6
0
])

# Labelled source-domain data.
source_dataset = get_dataset(dataset_name='source',
                             img_lists=args.source_list,
                             label_lists=args.source_label_list,
                             img_transform=img_transform,
                             label_transform=label_transform,
                             test=False)
# Unlabelled target-domain data: no label lists and no label transform.
target_dataset = get_dataset(dataset_name='target',
                             img_lists=args.target_list,
                             label_lists=None,
                             img_transform=img_transform,
                             label_transform=None,
                             test=False)

# Single shuffled loader over both domains with pinned host memory.
train_loader = torch.utils.data.DataLoader(
    ConcatDataset(source_dataset, target_dataset),
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=True)

# start training
# background weight: 1  shoe weight: 1
class_weighted = torch.Tensor([args.b_weight, args.s_weight]).cuda()
criterion_c = CrossEntropyLoss2d(class_weighted)
criterion_d = DiscrepancyLoss2d()

# Move all three networks to the GPU, then put the generator in
# training mode.
for _net in (G, F1, F2):
    _net.cuda()
G.train()