Example 1
def run(run_id, train_patients, test_patients, args):
    print('Train patient ids:', train_patients)
    print('Test patient ids:', test_patients)

    if args.data_name == 'SEED':
        input_size = 200
    elif args.data_name == 'DEAP':
        input_size = 128
    elif args.data_name == 'AMIGOS':
        input_size = 128
    else:
        raise ValueError(f'Unsupported dataset: {args.data_name}')

    model = RelativePosition(input_size=input_size, input_channels=args.input_channel, hidden_channels=16,
                             feature_dim=args.feature_dim, device=args.device)
    model.cuda(args.device)

    train_dataset = RPDataset(data_path=args.data_path, data_name=args.data_name, num_sampling=args.num_sampling,
                              dis=args.dis, patients=train_patients)

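    # Self-supervised pretraining with the relative-positioning pretext task on the training subjects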
    pretrain(model, train_dataset, args.device, run_id, args)

    del train_dataset

    train_dataset = eval(f'{args.data_name}Dataset')(args.data_path, args.num_seq, train_patients,
                                                     label_dim=args.label_dim)

    if args.finetune_mode == 'freeze':
        use_dropout = False
        use_l2_norm = False
        use_final_bn = True
    else:
        use_dropout = True
        use_l2_norm = False
        use_final_bn = False

    classifier = SimpleClassifier(input_size=input_size, input_channels=args.input_channel,
                                  feature_dim=args.feature_dim, num_classes=args.classes,
                                  use_dropout=use_dropout, use_l2_norm=use_l2_norm, use_batch_norm=use_final_bn,
                                  device=args.device)
    classifier.cuda(args.device)

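    # Copy the pretrained encoder weights into the classifier; strict=False skips keys
    # that exist in only one of the two models (e.g. the classification head)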
    classifier.load_state_dict(model.state_dict(), strict=False)

    # Evaluation
    del train_dataset
    test_dataset = eval(f'{args.data_name}Dataset')(args.data_path, args.num_seq, test_patients,
                                                    label_dim=args.label_dim)
    print(test_dataset)
    scores, targets = evaluate(classifier, test_dataset, args.device, args)
    performance = get_performance(scores, targets)
    with open(os.path.join(args.save_path, f'statistics_{run_id}.pkl'), 'wb') as f:
        pickle.dump({'performance': performance, 'args': vars(args), 'cmd': sys.argv}, f)
    print(performance)
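
The driver below is a minimal sketch, not part of the listing: it shows how a function with this signature might be invoked in a leave-one-subject-out loop. The module name pretrain_rp, the subject-id range, and every field of the argparse.Namespace are placeholder assumptions, not the authors' configuration; fields consumed only inside pretrain() and evaluate() (batch size, learning rate, epochs, ...) are omitted.

import argparse

from pretrain_rp import run  # hypothetical module holding the run() above


def main():
    # SEED is commonly distributed with 15 subjects; adjust to your data.
    subjects = list(range(15))

    # Placeholder hyper-parameters mirroring the fields accessed by run().
    args = argparse.Namespace(
        data_name='SEED', data_path='./data/SEED', save_path='./runs',
        device=0, input_channel=62, feature_dim=128, num_seq=10,
        num_sampling=200, dis=3, label_dim=0, finetune_mode='freeze',
        classes=3)

    # Leave-one-subject-out: each subject is held out once for evaluation.
    for run_id, test_subject in enumerate(subjects):
        train_subjects = [s for s in subjects if s != test_subject]
        run(run_id, train_subjects, [test_subject], args)


if __name__ == '__main__':
    main()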
Example 2
def main_worker(run_id, train_patients, test_patients, args):
    print('Train patient ids:', train_patients)
    print('Test patient ids:', test_patients)

    if args.data_name == 'SEED' or args.data_name == 'SEED-IV':
        input_size = 200
    elif args.data_name == 'DEAP':
        input_size = 128
    elif args.data_name == 'AMIGOS':
        input_size = 128
    elif args.data_name == 'ISRUC':
        input_size = 200
    elif args.data_name == 'SLEEPEDF':
        input_size = 100
    else:
        raise ValueError(f'Unsupported dataset: {args.data_name}')

    if args.data_name == 'SEED':
        train_dataset = SEEDDataset(args.data_path,
                                    args.num_seq,
                                    train_patients,
                                    label_dim=args.label_dim)
    elif args.data_name == 'SEED-IV':
        train_dataset = SEEDIVDataset(args.data_path,
                                      args.num_seq,
                                      train_patients,
                                      label_dim=args.label_dim)
    elif args.data_name == 'DEAP':
        train_dataset = DEAPDataset(args.data_path,
                                    args.num_seq,
                                    train_patients,
                                    label_dim=args.label_dim)
    elif args.data_name == 'AMIGOS':
        train_dataset = AMIGOSDataset(args.data_path,
                                      args.num_seq,
                                      train_patients,
                                      label_dim=args.label_dim)
    elif args.data_name == 'ISRUC':
        train_dataset = SleepDataset(args.data_path,
                                     'isruc',
                                     args.num_seq,
                                     train_patients,
                                     preprocessing=args.preprocessing)
    elif args.data_name == 'SLEEPEDF':
        train_dataset = SleepDataset(args.data_path,
                                     'sleepedf',
                                     args.num_seq,
                                     train_patients,
                                     preprocessing=args.preprocessing)
    else:
        raise ValueError(f'Unsupported dataset: {args.data_name}')

    # Finetuning
    if args.finetune_mode == 'freeze':
        use_dropout = False
        use_l2_norm = True
        use_final_bn = True
    else:
        use_dropout = True
        use_l2_norm = False
        use_final_bn = False

    classifier = DCCClassifier(input_size=input_size,
                               input_channels=args.input_channel,
                               feature_dim=args.feature_dim,
                               num_class=args.classes,
                               use_dropout=use_dropout,
                               use_l2_norm=use_l2_norm,
                               use_batch_norm=use_final_bn,
                               device=args.device)
    classifier.cuda(args.device)

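    # Initialize the classifier from the pretrained checkpoint; strict=False ignores non-matching keys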
    classifier.load_state_dict(torch.load(args.load_path), strict=False)

    print('[INFO] Start fine-tuning...')
    finetune(classifier, train_dataset, args.device, args)

    if args.data_name == 'SEED':
        test_dataset = SEEDDataset(args.data_path,
                                   args.num_seq,
                                   test_patients,
                                   label_dim=args.label_dim)
    elif args.data_name == 'SEED-IV':
        test_dataset = SEEDIVDataset(args.data_path,
                                     args.num_seq,
                                     test_patients,
                                     label_dim=args.label_dim)
    elif args.data_name == 'DEAP':
        test_dataset = DEAPDataset(args.data_path,
                                   args.num_seq,
                                   test_patients,
                                   label_dim=args.label_dim)
    elif args.data_name == 'AMIGOS':
        test_dataset = AMIGOSDataset(args.data_path,
                                     args.num_seq,
                                     test_patients,
                                     label_dim=args.label_dim)
    elif args.data_name == 'ISRUC':
        test_dataset = SleepDataset(args.data_path,
                                    'isruc',
                                    args.num_seq,
                                    test_patients,
                                    preprocessing=args.preprocessing)
    elif args.data_name == 'SLEEPEDF':
        test_dataset = SleepDataset(args.data_path,
                                    'sleepedf',
                                    args.num_seq,
                                    test_patients,
                                    preprocessing=args.preprocessing)
    else:
        raise ValueError(f'Unsupported dataset: {args.data_name}')

    scores, targets = evaluate(classifier, test_dataset, args.device, args)
    performance = get_performance(scores, targets)
    with open(os.path.join(args.save_path, f'statistics_{run_id}.pkl'),
              'wb') as f:
        pickle.dump({'performance': performance, 'args': vars(args)}, f)
    print(performance)
Example 3
def run(gpu, ngpus_per_node, run_id, train_patients, test_patients, args):
    if args.use_dist:
        print(f'[INFO] Process ({gpu}) invoked among {ngpus_per_node} gpus...')

    # Give each worker process its own random seed
    if args.seed is not None:
        setup_seed(args.seed + gpu)

    if args.use_dist:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=ngpus_per_node,
                                rank=gpu)

    if gpu == 0:
        print('Train patient ids:', train_patients)
        print('Test patient ids:', test_patients)

    if args.data_name == 'SEED' or args.data_name == 'SEED-IV':
        input_size = 200
    elif args.data_name == 'DEAP':
        input_size = 128
    elif args.data_name == 'AMIGOS':
        input_size = 128
    else:
        raise ValueError(f'Unsupported dataset: {args.data_name}')

    if args.use_dist:
        torch.cuda.set_device(gpu)

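    # Input shape depends on the feature mode: 200-sample raw segments vs. 5-channel
    # feature maps, each laid out on a grid_res x grid_res spatial grid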
    if args.feature_mode == 'raw':
        model = DCC((200, args.grid_res, args.grid_res), 1, args.feature_dim, True, 0.07,
                    gpu if args.use_dist else args.device, mode='sst', strides=(1, 2, 2, 2), use_dist=args.use_dist)
    else:
        model = DCC((5, args.grid_res, args.grid_res), 1, args.feature_dim, True, 0.07,
                    gpu if args.use_dist else args.device, mode='sst', strides=(1, 1, 2, 2), use_dist=args.use_dist)
    if args.use_dist:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model.cuda(gpu)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
        model_without_ddp = model.module
    else:
        model.cuda(args.device)

    if args.data_name == 'SEED':
        train_dataset = SEEDSSTDataset(args.data_path, args.num_seq, train_patients, label_dim=args.label_dim)
    elif args.data_name == 'SEED-IV':
        train_dataset = SEEDIVSSTDataset(args.data_path, args.num_seq, train_patients, label_dim=args.label_dim)
    elif args.data_name == 'DEAP':
        train_dataset = DEAPSSTDataset(args.data_path, args.num_seq, train_patients, label_dim=args.label_dim)
    else:
        raise ValueError(f'Unsupported dataset: {args.data_name}')

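    # Pretrain the model, then save its weights (unwrapping the DDP module when distributed)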
    pretrain(run_id, model, train_dataset, gpu if args.use_dist else args.device, args)
    if args.use_dist:
        torch.save(model.module.state_dict(),
                   os.path.join(args.save_path, f'dcc_{args.feature_mode}_pretrained.pth.tar'))
    else:
        torch.save(model.state_dict(),
                   os.path.join(args.save_path, f'dcc_{args.feature_mode}_pretrained.pth.tar'))

    if not args.only_pretrain and gpu == 0:
        print('[INFO] Start finetuning on the first gpu...')

        # Finetuning
        if args.finetune_mode == 'freeze':
            use_dropout = False
            use_l2_norm = True
            use_final_bn = True
        else:
            use_dropout = True
            use_l2_norm = False
            use_final_bn = False

        if args.feature_mode == 'raw':
            classifier = DCCClassifier(input_size=(200, args.grid_res, args.grid_res), input_channels=1,
                                       feature_dim=args.feature_dim,
                                       num_class=args.classes,
                                       use_dropout=use_dropout, use_l2_norm=use_l2_norm, use_batch_norm=use_final_bn,
                                       device=gpu if args.use_dist else args.device, mode='sst', strides=(1, 2, 2, 2))
        else:
            classifier = DCCClassifier(input_size=(5, args.grid_res, args.grid_res), input_channels=1,
                                       feature_dim=args.feature_dim,
                                       num_class=args.classes,
                                       use_dropout=use_dropout, use_l2_norm=use_l2_norm, use_batch_norm=use_final_bn,
                                       device=gpu if args.use_dist else args.device, mode='sst', strides=(1, 1, 2, 2))

        classifier.cuda(gpu)

        if args.use_dist:
            classifier.load_state_dict(model.module.state_dict(), strict=False)
        else:
            classifier.load_state_dict(model.state_dict(), strict=False)

        finetune(classifier, train_dataset, gpu if args.use_dist else args.device, args)

        if args.data_name == 'SEED':
            test_dataset = SEEDSSTDataset(args.data_path, args.num_seq, test_patients, label_dim=args.label_dim)
        elif args.data_name == 'SEED-IV':
            test_dataset = SEEDIVSSTDataset(args.data_path, args.num_seq, test_patients, label_dim=args.label_dim)
        # elif args.data_name == 'DEAP':
        #     test_dataset = DEAPDataset(args.data_path, args.num_seq, test_patients, label_dim=args.label_dim)
        else:
            raise ValueError(f'Unsupported dataset: {args.data_name}')

        # test_dataset = eval(f'{args.data_name}Dataset')(args.data_path, args.num_seq, test_patients,
        #                                                    label_dim=args.label_dim)
        scores, targets = evaluate(classifier, test_dataset, gpu if args.use_dist else args.device, args)
        performance = get_performance(scores, targets)
        with open(os.path.join(args.save_path, f'statistics_{run_id}.pkl'), 'wb') as f:
            pickle.dump({'performance': performance, 'args': vars(args)}, f)
        print(performance)
Example 4
def run(gpu, ngpus_per_node, run_id, train_patients, test_patients, args):
    if args.use_dist:
        print(f'[INFO] Process ({gpu}) invoked among {ngpus_per_node} gpus...')

    # Give each worker process its own random seed
    if args.seed is not None:
        setup_seed(args.seed + gpu)

    if args.use_dist:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=ngpus_per_node,
                                rank=gpu)

    if gpu == 0:
        print('Train patient ids:', train_patients)
        print('Test patient ids:', test_patients)

    if args.data_name == 'SEED' or args.data_name == 'SEED-IV':
        input_size = 200
    elif args.data_name == 'DEAP':
        input_size = 128
    elif args.data_name == 'AMIGOS':
        input_size = 128
    else:
        raise ValueError(f'Unsupported dataset: {args.data_name}')

    if args.use_dist:
        torch.cuda.set_device(gpu)

    if args.data_name == 'SEED':
        train_dataset_v1 = SEEDSSTDataset(args.data_path_v1,
                                          args.num_seq,
                                          train_patients,
                                          label_dim=args.label_dim)
        train_dataset_v2 = SEEDSSTDataset(args.data_path_v2,
                                          args.num_seq,
                                          train_patients,
                                          label_dim=args.label_dim)
    elif args.data_name == 'SEED-IV':
        train_dataset_v1 = SEEDIVSSTDataset(args.data_path_v1,
                                            args.num_seq,
                                            train_patients,
                                            label_dim=args.label_dim)
        train_dataset_v2 = SEEDIVSSTDataset(args.data_path_v2,
                                            args.num_seq,
                                            train_patients,
                                            label_dim=args.label_dim)
    elif args.data_name == 'DEAP':
        train_dataset_v1 = DEAPSSTDataset(args.data_path_v1,
                                          args.num_seq,
                                          train_patients,
                                          label_dim=args.label_dim)
        train_dataset_v2 = DEAPSSTDataset(args.data_path_v2,
                                          args.num_seq,
                                          train_patients,
                                          label_dim=args.label_dim)
    else:
        raise ValueError(f'Unsupported dataset: {args.data_name}')

    train_dataset = TwoDataset(train_dataset_v1, train_dataset_v2)

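    # Either load warm-started encoder checkpoints for both views, or warm them up from scratch with a joint model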
    if args.load_path_v1 is not None and args.load_path_v2 is not None:
        assert os.path.isfile(
            args.load_path_v1), f'Invalid file path {args.load_path_v1}!'
        assert os.path.isfile(
            args.load_path_v2), f'Invalid file path {args.load_path_v2}!'

        state_dict_v1 = torch.load(args.load_path_v1)
        state_dict_v2 = torch.load(args.load_path_v2)
    else:
        print(f'[INFO] Training from scratch...')
        if args.first_view == 'raw':
            model = SSTDIS(input_size_v1=(200, args.grid_res, args.grid_res),
                           input_size_v2=(5, args.grid_res, args.grid_res),
                           input_channels=1,
                           feature_dim=args.feature_dim,
                           use_temperature=False,
                           temperature=1,
                           device=gpu if args.use_dist else args.device,
                           strides=None,
                           first_view='raw')
        else:
            model = SSTDIS(input_size_v1=(5, args.grid_res, args.grid_res),
                           input_size_v2=(200, args.grid_res, args.grid_res),
                           input_channels=1,
                           feature_dim=args.feature_dim,
                           use_temperature=False,
                           temperature=1,
                           device=gpu if args.use_dist else args.device,
                           strides=None,
                           first_view='freq')
        model = model.cuda(args.device)

        warmup(run_id, model, train_dataset, args.device, args)

        state_dict_v1 = model.state_dict()
        state_dict_v2 = copy.deepcopy(state_dict_v1)

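        # Split the warmed-up joint model into per-view encoder checkpoints
        # (encoder_q -> first view, encoder_s -> second view)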
        new_state_dict_v1 = {}
        for key, value in state_dict_v1.items():
            if 'encoder_q.' in key:
                key = key.replace('encoder_q.', 'encoder.')
                new_state_dict_v1[key] = value
        state_dict_v1 = new_state_dict_v1
        torch.save(
            state_dict_v1,
            os.path.join(args.save_path,
                         f'dcc_warmup_{args.first_view}.pth.tar'))

        new_state_dict_v2 = {}
        for key, value in state_dict_v2.items():
            if 'encoder_s.' in key:
                key = key.replace('encoder_s.', 'encoder.')
                new_state_dict_v2[key] = value
        state_dict_v2 = new_state_dict_v2
        torch.save(
            state_dict_v2,
            os.path.join(
                args.save_path,
                f"dcc_warmup_{'freq' if args.first_view == 'raw' else 'raw'}.pth.tar"
            ))

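    # The number of alternating rounds must be odd so that the last round trains the
    # first view, whose weights initialize the classifier built below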
    assert args.iteration % 2 == 1

    for it in range(args.iteration):
        reverse = it % 2 == 1

        if reverse:
            print(f'[INFO] Iteration {it + 1}, train the second view...')
        else:
            print(f'[INFO] Iteration {it + 1}, train the first view...')

        if not reverse:
            train_dataset = TwoDataset(train_dataset_v1, train_dataset_v2)
            if args.first_view == 'raw':
                model = SSTMMD(input_size_v1=(200, args.grid_res,
                                              args.grid_res),
                               input_size_v2=(5, args.grid_res, args.grid_res),
                               input_channels=1,
                               feature_dim=args.feature_dim,
                               use_temperature=False,
                               temperature=1,
                               device=gpu if args.use_dist else args.device,
                               strides=None,
                               first_view='raw')
            else:
                model = SSTMMD(input_size_v1=(5, args.grid_res, args.grid_res),
                               input_size_v2=(200, args.grid_res,
                                              args.grid_res),
                               input_channels=1,
                               feature_dim=args.feature_dim,
                               use_temperature=False,
                               temperature=1,
                               device=gpu if args.use_dist else args.device,
                               strides=None,
                               first_view='freq')
        else:
            train_dataset = TwoDataset(train_dataset_v2, train_dataset_v1)
            if args.first_view == 'raw':
                model = SSTMMD(input_size_v1=(5, args.grid_res, args.grid_res),
                               input_size_v2=(200, args.grid_res,
                                              args.grid_res),
                               input_channels=1,
                               feature_dim=args.feature_dim,
                               use_temperature=False,
                               temperature=1,
                               device=gpu if args.use_dist else args.device,
                               strides=None,
                               first_view='freq')
            else:
                model = SSTMMD(input_size_v1=(200, args.grid_res,
                                              args.grid_res),
                               input_size_v2=(5, args.grid_res, args.grid_res),
                               input_channels=1,
                               feature_dim=args.feature_dim,
                               use_temperature=False,
                               temperature=1,
                               device=gpu if args.use_dist else args.device,
                               strides=None,
                               first_view='raw')

        if args.use_dist:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model.cuda(gpu)
            model = torch.nn.parallel.DistributedDataParallel(model,
                                                              device_ids=[gpu])
            model_without_ddp = model.module
        else:
            model.cuda(args.device)

        # Second view as sampler
        new_dict = {}
        new_state_dict_v2 = copy.deepcopy(state_dict_v2)
        for k, v in new_state_dict_v2.items():
            if 'encoder.' in k:
                k = k.replace('encoder.', 'sampler.')
                new_dict[k] = v
        new_state_dict_v2 = new_dict

        # First view as encoder k
        new_state_dict_v1 = copy.deepcopy(state_dict_v1)

        state_dict = {**new_state_dict_v1, **new_state_dict_v2}
        try:
            model.load_state_dict(state_dict, strict=False)
        except Exception as e:
            print(e)
            # print(list(state_dict.keys()))
            exit(-1)

        pretrain(run_id, model, train_dataset,
                 gpu if args.use_dist else args.device, args)

        # Swap roles: the freshly trained weights become next round's sampler,
        # while the previous sampler weights seed the encoder
        state_dict_v1 = model.state_dict()
        state_dict_v1, state_dict_v2 = state_dict_v2, state_dict_v1

    if gpu == 0:
        print('[INFO] Start finetuning on the first gpu...')

        # Finetuning
        if args.finetune_mode == 'freeze':
            use_dropout = False
            use_l2_norm = True
            use_final_bn = True
        else:
            use_dropout = True
            use_l2_norm = False
            use_final_bn = False

        if args.first_view == 'raw':
            classifier = SSTClassifier(
                input_size_v1=(200, args.grid_res, args.grid_res),
                input_size_v2=(5, args.grid_res, args.grid_res),
                input_channels=1,
                feature_dim=args.feature_dim,
                num_class=args.classes,
                use_dropout=use_dropout,
                use_l2_norm=use_l2_norm,
                use_batch_norm=use_final_bn,
                device=gpu if args.use_dist else args.device,
                strides=(1, 2, 2, 2),
                first_view=args.first_view)
        else:
            classifier = SSTClassifier(
                input_size_v1=(5, args.grid_res, args.grid_res),
                input_size_v2=(200, args.grid_res, args.grid_res),
                input_channels=1,
                feature_dim=args.feature_dim,
                num_class=args.classes,
                use_dropout=use_dropout,
                use_l2_norm=use_l2_norm,
                use_batch_norm=use_final_bn,
                device=gpu if args.use_dist else args.device,
                strides=(1, 1, 2, 2),
                first_view=args.first_view)

        classifier.cuda(gpu)

        if args.use_dist:
            classifier.load_state_dict(model.module.state_dict(), strict=False)
        else:
            classifier.load_state_dict(model.state_dict(), strict=False)

        finetune(classifier, train_dataset,
                 gpu if args.use_dist else args.device, args)

        del train_dataset
        del train_dataset_v1
        del train_dataset_v2

        if args.data_name == 'SEED':
            test_dataset_v1 = SEEDSSTDataset(args.data_path_v1,
                                             args.num_seq,
                                             test_patients,
                                             label_dim=args.label_dim)
            test_dataset_v2 = SEEDSSTDataset(args.data_path_v2,
                                             args.num_seq,
                                             test_patients,
                                             label_dim=args.label_dim)
        elif args.data_name == 'SEED-IV':
            test_dataset_v1 = SEEDIVSSTDataset(args.data_path_v1,
                                               args.num_seq,
                                               test_patients,
                                               label_dim=args.label_dim)
            test_dataset_v2 = SEEDIVSSTDataset(args.data_path_v2,
                                               args.num_seq,
                                               test_patients,
                                               label_dim=args.label_dim)
        elif args.data_name == 'DEAP':
            test_dataset_v1 = DEAPSSTDataset(args.data_path_v1,
                                             args.num_seq,
                                             test_patients,
                                             label_dim=args.label_dim)
            test_dataset_v2 = DEAPSSTDataset(args.data_path_v2,
                                             args.num_seq,
                                             test_patients,
                                             label_dim=args.label_dim)
        else:
            raise ValueError(f'Unsupported dataset: {args.data_name}')

        test_dataset = TwoDataset(test_dataset_v1, test_dataset_v2)

        scores, targets = evaluate(classifier, test_dataset,
                                   gpu if args.use_dist else args.device, args)
        performance = get_performance(scores, targets)
        with open(os.path.join(args.save_path, f'statistics_{run_id}.pkl'),
                  'wb') as f:
            pickle.dump({'performance': performance, 'args': vars(args)}, f)
        print(performance)