Beispiel #1
0
def main():
    """Pretrain a ResNet18 with BYOL on STL10 using ./config/config.yaml."""
    # Context manager closes the config file handle (previously a bare
    # open() was passed to yaml.load and never closed).
    with open("./config/config.yaml", "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Training with: {device}")

    data_transform = get_simclr_data_transforms(**config['data_transforms'])

    # Two identical augmentation pipelines produce the two views BYOL compares.
    train_dataset = datasets.STL10('/home/thalles/Downloads/',
                                   split='train+unlabeled',
                                   download=True,
                                   transform=MultiViewDataInjector(
                                       [data_transform, data_transform]))

    # online network
    online_network = ResNet18(**config['network']).to(device)
    pretrained_folder = config['network']['fine_tune_from']

    # load pre-trained model if defined
    if pretrained_folder:
        try:
            checkpoints_folder = os.path.join('./runs', pretrained_folder,
                                              'checkpoints')

            # load pre-trained parameters (redundant nested os.path.join
            # and double torch.device wrapping removed)
            load_params = torch.load(
                os.path.join(checkpoints_folder, 'model.pth'),
                map_location=torch.device(device))

            if 'online_network_state_dict' in load_params:
                online_network.load_state_dict(
                    load_params['online_network_state_dict'])
                print("Parameters successfully loaded.")

        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

    # predictor network, sized from the projection head's output width.
    # NOTE(review): 'projetion' (sic) matches the attribute name defined by
    # this project's ResNet18 — do not "fix" the spelling here alone.
    predictor = MLPHead(
        in_channels=online_network.projetion.net[-1].out_features,
        **config['network']['projection_head']).to(device)

    # target encoder — note it is NOT in the optimizer's parameter list below
    target_network = ResNet18(**config['network']).to(device)

    # Only the online network and the predictor receive gradient updates.
    optimizer = torch.optim.SGD(
        list(online_network.parameters()) + list(predictor.parameters()),
        **config['optimizer']['params'])

    trainer = BYOLTrainer(online_network=online_network,
                          target_network=target_network,
                          optimizer=optimizer,
                          predictor=predictor,
                          device=device,
                          **config['trainer'])

    trainer.train(train_dataset)
    def __init__(self, dataset, options, *args, **kwargs):
        """Build a ResNet encoder plus an MLP projection head.

        Args:
            dataset: dataset name; "cifar10" selects the modified stem
                (3x3 conv, no max-pool) from the SimCLR paper, Appendix B.9.
            options: forwarded to MLPHead as its ``options`` argument.
            **kwargs: must contain 'name' ('resnet18' or 'resnet50') and
                'projection_head' (keyword arguments for MLPHead).

        Raises:
            ValueError: if kwargs['name'] is not a supported backbone.
        """
        super(ResNet18, self).__init__()

        if kwargs['name'] == 'resnet18':
            resnet = models.resnet18(pretrained=False)
        elif kwargs['name'] == 'resnet50':
            resnet = models.resnet50(pretrained=False)
        else:
            # Previously fell through and crashed later with a confusing
            # UnboundLocalError; fail fast with a clear message instead.
            raise ValueError(f"Unsupported backbone name: {kwargs['name']!r}")

        if dataset == "cifar10":
            # smaller kernel size in conv2d and no max pooling according to SimCLR paper (Appendix B.9)
            # https://arxiv.org/pdf/2002.05709.pdf
            self.f = []
            for name, module in resnet.named_children():
                if name == 'conv1':
                    # Replace the 7x7/stride-2 stem with 3x3/stride-1,
                    # suitable for 32x32 CIFAR images.
                    module = nn.Conv2d(3,
                                       64,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1,
                                       bias=False)
                if not isinstance(module, nn.Linear) and not isinstance(
                        module, nn.MaxPool2d):
                    self.f.append(module)
            # encoder
            self.encoder = nn.Sequential(*self.f)
        else:
            # Keep everything except the final fully-connected classifier.
            self.encoder = torch.nn.Sequential(*list(resnet.children())[:-1])
        self.feature_dim = resnet.fc.in_features
        # NOTE: attribute is spelled 'projetion' (sic) project-wide;
        # callers rely on this exact name.
        self.projetion = MLPHead(in_channels=self.feature_dim,
                                 **kwargs['projection_head'],
                                 options=options)
Beispiel #3
0
    def __init__(self, *args, **kwargs):
        """Build a ResNet encoder (minus its classifier) plus an MLP projection head.

        Args:
            **kwargs: must contain 'name' ('resnet18' or 'resnet50') and
                'projection_head' (keyword arguments for MLPHead).

        Raises:
            ValueError: if kwargs['name'] is not a supported backbone.
        """
        super(ResNet18, self).__init__()
        if kwargs['name'] == 'resnet18':
            resnet = models.resnet18(pretrained=False)
        elif kwargs['name'] == 'resnet50':
            resnet = models.resnet50(pretrained=False)
        else:
            # Previously fell through and crashed later with a confusing
            # UnboundLocalError; fail fast with a clear message instead.
            raise ValueError(f"Unsupported backbone name: {kwargs['name']!r}")

        # Drop the final fully-connected layer; keep the convolutional trunk.
        self.encoder = torch.nn.Sequential(*list(resnet.children())[:-1])
        # 'projetion' (sic) is the attribute name used project-wide.
        self.projetion = MLPHead(in_channels=resnet.fc.in_features, **kwargs['projection_head'])
    def __init__(self, *args, **kwargs):
        """Wrap an MLP-Mixer backbone with an MLP projection head."""
        super(MLPmixer, self).__init__()

        # Fixed mixer configuration for 96x96 RGB inputs.
        mixer_config = dict(in_channels=3,
                            image_size=96,
                            patch_size=16,
                            num_classes=1000,
                            dim=512,
                            depth=8,
                            token_dim=256,
                            channel_dim=2048)
        self.encoder = MLPMixer(**mixer_config)

        # 'projetion' (sic) is the attribute name used project-wide; the
        # in_channels value matches the mixer's feature dimension above.
        self.projetion = MLPHead(in_channels=512, **kwargs['projection_head'])
Beispiel #5
0
    def __init__(self, flag_ova, *args, **kwargs):
        """Build a ResNet encoder with a projection head and a 10-way linear head.

        Args:
            flag_ova: when truthy, use a Distance_1D one-vs-all head instead
                of a plain linear classifier.
            **kwargs: must contain 'name' ('resnet18', 'resnet50', or
                'wideresenet') and 'projection_head' (MLPHead kwargs).

        Raises:
            ValueError: if kwargs['name'] is not a supported backbone.
        """
        super(Multi_ResNet18, self).__init__()
        if kwargs['name'] == 'resnet18':
            resnet = models.resnet18(pretrained=False)
        elif kwargs['name'] == 'resnet50':
            resnet = models.resnet50(pretrained=False)
        elif kwargs['name'] == 'wideresenet':  # (sic) key is misspelled in configs
            resnet = models.wide_resnet50_2(pretrained=False)
        else:
            # Previously fell through and crashed later with a confusing
            # UnboundLocalError; fail fast with a clear message instead.
            raise ValueError(f"Unsupported backbone name: {kwargs['name']!r}")

        # Drop the final fully-connected layer; keep the convolutional trunk.
        self.encoder = torch.nn.Sequential(*list(resnet.children())[:-1])
        # 'projetion' (sic) is the attribute name used project-wide.
        self.projetion = MLPHead(in_channels=resnet.fc.in_features, **kwargs['projection_head'])

        if flag_ova:
            print("ova is training!")
            self.linear = Distance_1D(out_features=resnet.fc.in_features,
                                   num_classes = 10)
        else:
            self.linear = torch.nn.Linear(resnet.fc.in_features, 10, bias=True)
Beispiel #6
0
 def __init__(self, num_class, normal_channel=True):
     """Build a PointNet-style point-cloud backbone with an MLP projection head.

     Args:
         num_class: number of target classes (currently unused; the final
             classification layer is commented out below).
         normal_channel: when True, three extra input channels are expected
             — presumably per-point normals; confirm against the dataloader.
     """
     super(get_model, self).__init__()
     # Extra input channels when normals are provided.
     in_channel = 3 if normal_channel else 0
     self.normal_channel = normal_channel
     # Multi-scale set-abstraction layers (npoint, radii, nsample counts,
     # input channels, per-scale MLP widths).
     self.sa1 = PointNetSetAbstractionMsg(
         512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,
         [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
     self.sa2 = PointNetSetAbstractionMsg(
         128, [0.2, 0.4, 0.8], [32, 64, 128], 320,
         [[64, 64, 128], [128, 128, 256], [128, 128, 256]])
     # Final abstraction producing a 1024-wide feature (group_all=True).
     self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3,
                                       [256, 512, 1024], True)
     # Classification-style MLP trunk: 1024 -> 512 -> 256 with BN + dropout.
     self.fc1 = nn.Linear(1024, 512)
     self.bn1 = nn.BatchNorm1d(512)
     self.drop1 = nn.Dropout(0.4)
     self.fc2 = nn.Linear(512, 256)
     self.bn2 = nn.BatchNorm1d(256)
     self.drop2 = nn.Dropout(0.5)
     #         self.fc3 = nn.Linear(256, 40)
     # self.encoder = torch.nn.Sequential(*list(self.children())[:-1])
     # NOTE: 'projetion' (sic) is the attribute name used project-wide;
     # in_channels=256 matches fc2's output width above.
     self.projetion = MLPHead(in_channels=256,
                              mlp_hidden_size=512,
                              projection_size=128)
Beispiel #7
0
def main():
    """Pretrain BYOL on CIFAR10, optionally with incremental class exposure.

    Command-line flags:
        --incr: train the representation incrementally (classes in exposures).
        --id:   experiment id appended to saved files.
    """
    parser = ArgumentParser()
    parser.add_argument('--incr', action='store_true', help='train representation incrementally')
    parser.add_argument('--id', type=str, default='', dest='experiment_id',
                        help='experiment id appended to saved files')
    args = parser.parse_args()

    # Context manager closes the config file handle (previously leaked).
    with open("./config/config.yaml", "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    n_class_epochs = config['other']['n_class_epochs']
    eval_step = config['other']['eval_step']

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Training with: {device}")

    data_transform = get_simclr_data_transforms(**config['data_transforms'])

    #train_dataset = datasets.STL10('../../data', split='train+unlabeled', download=False,
    #                               transform=MultiViewDataInjector([data_transform, data_transform]))

    # online network
    online_network = ResNet18(**config['network']).to(device)
    pretrained_folder = config['network']['fine_tune_from']

    # load pre-trained model if defined
    if pretrained_folder:
        try:
            checkpoints_folder = os.path.join('./runs', pretrained_folder, 'checkpoints')

            # load pre-trained parameters (redundant nested os.path.join and
            # double torch.device wrapping removed)
            load_params = torch.load(os.path.join(checkpoints_folder, 'model.pth'),
                                     map_location=torch.device(device))

            online_network.load_state_dict(load_params['online_network_state_dict'])

        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

    # predictor network
    # NOTE(review): this accesses '.projection' while the sibling examples'
    # ResNet18 defines '.projetion' (sic) — confirm the attribute name of
    # the ResNet18 class imported by this script.
    predictor = MLPHead(in_channels=online_network.projection.net[-1].out_features,
                        **config['network']['projection_head']).to(device)

    # target encoder — not included in the optimizer's parameter list below
    target_network = ResNet18(**config['network']).to(device)

    optimizer = torch.optim.SGD(list(online_network.parameters()) + list(predictor.parameters()),
                                **config['optimizer']['params'])

    trainer = BYOLTrainer(online_network=online_network,
                          target_network=target_network,
                          optimizer=optimizer,
                          predictor=predictor,
                          device=device,
                          **config['trainer'])

    num_workers = config['trainer']['num_workers']
    batch_size_train = 100
    batch_size_test = 200

    if args.incr:
        # Incremental regime: classes arrive two at a time per exposure.
        incr_train_loaders, incr_val_loaders = get_dataloader_incr(data_dir='../../data', base='CIFAR10', num_classes=10,
                                                                   img_size=224, classes_per_exposure=2, train=True,
                                                                   num_workers=num_workers, batch_size_train=batch_size_train,
                                                                   batch_size_test=batch_size_test,
                                                                   transform=MultiViewDataInjector([data_transform, data_transform]))

        # get train and val indices sampled
        train_indices = np.concatenate([ldr.sampler.indices for ldr in incr_train_loaders])

        # A single loader over the union of all exposure indices.
        train_class_dataloader = DataLoader(incr_train_loaders[0].dataset, sampler=SubsetRandomSampler(train_indices),
                                            batch_size=batch_size_train, num_workers=num_workers)
        #trainer.train(train_dataset)
        trainer.train_incr(incr_train_loaders, incr_val_loaders,
                           n_class_epochs=n_class_epochs,
                           train_class_dataloader=train_class_dataloader,
                           experiment_id=args.experiment_id,
                           eval_step=eval_step)
    else:
        train_loader, val_loader = get_dataloader(data_dir='../../data', base='CIFAR10', num_classes=10,
                                                  img_size=224, train=True, num_workers=num_workers,
                                                  batch_size_train=batch_size_train, batch_size_test=batch_size_test,
                                                  transform=MultiViewDataInjector([data_transform, data_transform]))
        trainer.train(train_loader, val_loader, n_class_epochs=n_class_epochs, experiment_id=args.experiment_id,
                      eval_step=eval_step)
Beispiel #8
0
def main():
    """Pretrain BYOL on STL10 using two different augmentation pipelines."""
    # Context manager closes the config file handle (previously leaked).
    with open("./config/config.yaml", "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Training with: {device}")

    # Second pipeline forces blur on, so the two views differ in augmentation.
    data_transform = get_simclr_data_transforms(**config['data_transforms'])
    data_transform2 = get_simclr_data_transforms(**config['data_transforms'],
                                                 blur=1.)

    # data_transform = get_simclr_data_transforms_randAugment(config['data_transforms']['input_shape'])
    # data_transform2 = get_simclr_data_transforms_randAugment(config['data_transforms']['input_shape'])

    train_dataset = datasets.STL10('/media/snowflake/Data/',
                                   split='train+unlabeled',
                                   download=True,
                                   transform=MultiViewDataInjector(
                                       [data_transform, data_transform2]))
    # train_dataset = STL("/home/snowflake/Descargas/STL_data/unlabeled_images", 96,
    #                                  transform=MultiViewDataInjector([data_transform, data_transform2]))

    # online network (the one that is trained)
    online_network = ResNet(**config['network']).to(device)
    # online_network = MLPmixer(**config['network']).to(device)

    # target encoder — not included in the optimizer's parameter list below
    target_network = ResNet(**config['network']).to(device)
    # target_network = MLPmixer(**config['network']).to(device)

    pretrained_folder = config['network']['fine_tune_from']

    # load pre-trained model if defined
    if pretrained_folder:
        try:
            checkpoints_folder = os.path.join('./runs', pretrained_folder,
                                              'checkpoints')

            # load pre-trained parameters (redundant nested os.path.join
            # and double torch.device wrapping removed)
            load_params = torch.load(
                os.path.join(checkpoints_folder, 'model.pth'),
                map_location=torch.device(device))

            online_network.load_state_dict(
                load_params['online_network_state_dict'])
            target_network.load_state_dict(
                load_params['target_network_state_dict'])

        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

    # predictor network, sized from the projection head's output width.
    # NOTE: 'projetion' (sic) matches the attribute name defined by this
    # project's model classes.
    predictor = MLPHead(
        in_channels=online_network.projetion.net[-1].out_features,
        **config['network']['projection_head']).to(device)

    # Only the online network and the predictor receive gradient updates.
    optimizer = torch.optim.SGD(
        list(online_network.parameters()) + list(predictor.parameters()),
        **config['optimizer']['params'])

    trainer = BYOLTrainer(online_network=online_network,
                          target_network=target_network,
                          optimizer=optimizer,
                          predictor=predictor,
                          device=device,
                          **config['trainer'])

    trainer.train(train_dataset)
Beispiel #9
0
def main(args):
    """Pretrain a PointNet-style model with BYOL on ModelNet40.

    Sets up log/checkpoint directories, builds train/test dataloaders,
    instantiates the online, target, and predictor networks, optionally
    resumes from 'checkpoints/model.pth', and runs trainer.train_pointnet.
    """
    # Context manager closes the config file handle (previously leaked).
    with open("./config/config.yaml", "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    def log_string(str):
        # Mirror every message to both the file logger and stdout.
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('classification')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    # NOTE(review): this rebinds the 'args' parameter with a fresh
    # parse_args() call, discarding the caller's args — verify intent.
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''DATA LOADING'''
    log_string('Load dataset ...')
    DATA_PATH = 'data/modelnet40_normal_resampled/'

    TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH,
                                       npoint=args.num_point,
                                       split='train',
                                       normal_channel=args.normal)
    TEST_DATASET = ModelNetDataLoader(root=DATA_PATH,
                                      npoint=args.num_point,
                                      split='test',
                                      normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=8)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=8)
    '''MODEL LOADING'''
    num_class = 40
    # The model module is chosen by name and its source archived with the run.
    MODEL = importlib.import_module(args.model)
    shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('./models/pointnet_util.py', str(experiment_dir))

    # online network
    online_network = MODEL.get_model(num_class,
                                     normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()
    # predictor network, sized from the projection head's output width.
    # NOTE: 'projetion' (sic) matches the attribute name defined by the
    # project's model classes.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    predictor = MLPHead(
        in_channels=online_network.projetion.net[-1].out_features,
        **config['network']['projection_head']).to(device)

    # target encoder
    target_network = MODEL.get_model(num_class,
                                     normal_channel=args.normal).cuda()
    # load pre-trained model if defined; best-effort resume. The previous
    # bare `except:` also swallowed KeyboardInterrupt/SystemExit — narrowed
    # to Exception so interrupts still propagate.
    try:
        checkpoint = torch.load('checkpoints/model.pth')
        online_network.load_state_dict(checkpoint['online_network_state_dict'])
        target_network.load_state_dict(checkpoint['target_network_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')

#    if args.optimizer == 'Adam':
#        optimizer = torch.optim.Adam(
#            classifier.parameters(),
#            lr=args.learning_rate,
#            betas=(0.9, 0.999),
#            eps=1e-08,
#            weight_decay=args.decay_rate
#        )
#    else:
#        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

    # Only the online network and the predictor receive gradient updates.
    optimizer = torch.optim.SGD(
        list(online_network.parameters()) + list(predictor.parameters()),
        **config['optimizer']['params'])
    trainer = BYOLTrainer(online_network=online_network,
                          target_network=target_network,
                          predictor=predictor,
                          optimizer=optimizer,
                          device=device,
                          **config['trainer'])

    trainer.train_pointnet(trainDataLoader, testDataLoader)
Beispiel #10
0
def main():
    """BYOL-style gait-recognition entry point (SetNet backbone, CASIA-B-like data)."""
    # Context manager closes the config file handle (previously leaked);
    # explicit FullLoader avoids the unsafe/deprecated default loader and
    # matches the other entry points in this project.
    with open("./config/config.yaml", "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    os.environ["CUDA_VISIBLE_DEVICES"] = config['CUDA_VISIBLE_DEVICES']
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Training with: {device}")

    #data_transform = get_simclr_data_transforms(**config['data_transforms'])

    train_source,ft_source,test_source = load_data(config)
    '''
    train_source,_,_ = load_data(config)
    config_test=config
    config_test['dataset_path']="/home/yqliu/Dataset/CASIA-B/silhouettes"
    config_test['dataset']='CASIA-B'
    ft_source,_,test_source = load_data(config_test)
    print(config_test['dataset'])
    '''
    #print(type(train_source),type(test_source))
    print(train_source.__len__(),ft_source.__len__(),test_source.__len__())
    # Eager pre-loading is currently disabled (both branches are no-ops
    # beyond the print).
    train=True
    test=False
    Train_Set=Test_Set=None
    if train:
        print("Loading training data...")
        #Train_Set = train_source.load_all_data()
    if test:
        print("Loading test data...")
        #Test_Set = test_source.load_all_data()
    #print(type(Train_Set),type(Test_Set))

    # online network
    #online_network = ResNet18(**config['network']).to(device)
    online_network = SetNet(config['network']['hidden_dim'])
    online_network=nn.DataParallel(online_network).cuda()
    pretrained_folder = config['network']['fine_tune_from']

    # Auxiliary modules trained alongside the online network.
    conv_trans=MCM_NOTP(in_channels=128, out_channels=128, p=31, div=4)
    conv_trans=nn.DataParallel(conv_trans).cuda()
    #conv_trans=MCM_NOTP(in_channels=config['network']['hidden_dim'], out_channels=config['network']['hidden_dim'], p=16, div=4).to(device)
    TP_1=TP_FULL(hidden_dim=config['network']['hidden_dim'])
    TP_1=nn.DataParallel(TP_1).cuda()

    # load pre-trained model if defined ('None' is the YAML sentinel string)
    if pretrained_folder!='None':
        try:
            checkpoints_folder = os.path.join('./runs', pretrained_folder, 'checkpoints')

            # load pre-trained parameters (redundant nested os.path.join and
            # double torch.device wrapping removed)
            load_params = torch.load(os.path.join(checkpoints_folder, 'model.pth'),
                                     map_location=torch.device(device))

            online_network.load_state_dict(load_params['online_network_state_dict'])

        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

    # projection / predictor heads for the online branch
    projection_1 = MLPHead(in_channels=config['network']['hidden_dim'],**config['network']['projection_head'])
    projection_1=nn.DataParallel(projection_1).cuda()
    #projection = MLPHead(in_channels=128,**config['network']['projection_head']).to(device)
    #predictor_1=predictor(hidden_dim=config['network']['hidden_dim']).cuda()
    predictor = MLPHead(in_channels=config['network']['hidden_dim'],**config['network']['projection_head'])
    predictor=nn.DataParallel(predictor).cuda()

    # target encoder branch (not optimized below)
    #target_network = ResNet18(**config['network']).to(device)
    target_network = SetNet(config['network']['hidden_dim'])
    target_network=nn.DataParallel(target_network).cuda()
    TP_2=TP_FULL(hidden_dim=config['network']['hidden_dim'])
    TP_2=nn.DataParallel(TP_2).cuda()
    #predictor_2=predictor(hidden_dim=config['network']['hidden_dim']).cuda()
    projection_2 = MLPHead(in_channels=config['network']['hidden_dim'],**config['network']['projection_head'])
    projection_2=nn.DataParallel(projection_2).cuda()

    # Only the online branch (encoder, conv_trans, TP_1, projection_1,
    # predictor) receives gradient updates.
    optimizer = torch.optim.SGD(list(online_network.parameters()) + list(conv_trans.parameters()) + list(TP_1.parameters()) + list(projection_1.parameters()) + list(predictor.parameters()),**config['optimizer']['params'])
    #print(type(projection_1),type(projection_2))
    trainer = BYOLTrainer(online_network=online_network,
                          target_network=target_network,
                          conv_trans=conv_trans,
                          TP_1=TP_1,
                          TP_2=TP_2,
                          projection_1=projection_1,
                          projection_2=projection_2,
                          predictor=predictor,
                          optimizer=optimizer,
                          device=device,
                          config=config,
                          train_source=train_source,
                          test_source=test_source,
                          **config['trainer'])

    # NOTE(review): resumes from a hard-coded absolute checkpoint path and
    # fine-tunes on the training source — confirm this is the intended mode.
    #trainer.train_ft(train_source,maxiter=10000)
    #trainer.finetune(ft_source,maxiter=80000)
    trainer.load_model("/home/yqliu/PJs/BYOL_MCM/PyTorch-BYOL/runs/Oct20_11-37-30_123.pami.group/checkpoints/model_9000.pth")
    trainer.finetune(train_source,maxiter=80000)
Beispiel #11
0
def main(args):
    """Pretrain a Multi_ResNet18 with BYOL on CIFAR10 with two dataset views."""
    # Context manager closes the config file handle (previously leaked).
    with open("./config/config.yaml", "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Training with: {device}")

    # Dataset 1: SimCLR augmentations applied twice.
    data_transform = get_simclr_data_transforms(**config['data_transforms'])
    train_dataset_1 = datasets.CIFAR10(root='../STL_model/data',
                                       train=True,
                                       download=True,
                                       transform=MultiViewDataInjector(
                                           [data_transform, data_transform]))

    # Dataset 2: plain resize + tensor conversion (no augmentation).
    data_transforms = torchvision.transforms.Compose(
        [transforms.Resize(96), transforms.ToTensor()])
    train_dataset_2 = datasets.CIFAR10(root='../STL_model/data',
                                       train=True,
                                       download=True,
                                       transform=MultiViewDataInjector(
                                           [data_transforms, data_transforms]))

    # online network
    online_network = Multi_ResNet18(args.flag_ova,
                                    **config['network']).to(device)
    pretrained_folder = config['network']['fine_tune_from']

    # load pre-trained model if defined
    if pretrained_folder:
        try:
            checkpoints_folder = os.path.join('./runs', pretrained_folder,
                                              'checkpoints')

            # load pre-trained parameters (redundant nested os.path.join
            # and double torch.device wrapping removed)
            load_params = torch.load(
                os.path.join(checkpoints_folder, 'model.pth'),
                map_location=torch.device(device))

            online_network.load_state_dict(
                load_params['online_network_state_dict'])

        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

    # predictor network, sized from the projection head's output width.
    # NOTE: 'projetion' (sic) matches the attribute name defined by
    # Multi_ResNet18.
    predictor = MLPHead(
        in_channels=online_network.projetion.net[-1].out_features,
        **config['network']['projection_head']).to(device)

    # target encoder — not included in the optimizer's parameter list below
    target_network = Multi_ResNet18(args.flag_ova,
                                    **config['network']).to(device)

    optimizer = torch.optim.SGD(
        list(online_network.parameters()) + list(predictor.parameters()),
        **config['optimizer']['params'])

    trainer = BYOLTrainer(online_network=online_network,
                          target_network=target_network,
                          optimizer=optimizer,
                          predictor=predictor,
                          device=device,
                          model_path=args.model_path,
                          **config['trainer'])

    trainer.train((train_dataset_1, train_dataset_2), args.flag_ova)
Beispiel #12
0
def main(args):
    """Run SimCLR or BYOL pretraining according to the (hydra) config in args."""
    log.info("Command line: \n\n" + common_utils.pretty_print_cmd(sys.argv))
    log.info(f"Working dir: {os.getcwd()}")
    log.info("\n" + common_utils.get_git_hash())
    log.info("\n" + common_utils.get_git_diffs())

    os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu}"
    torch.manual_seed(args.seed)
    log.info(args.pretty())

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    log.info(f"Training with: {device}")

    data_transform = get_simclr_data_transforms_train(args['dataset'])
    data_transform_identity = get_simclr_data_transforms_test(args['dataset'])

    # Each sample yields two augmented views plus one identity (test) view.
    if args["dataset"] == "stl10":
        train_dataset = datasets.STL10(args.dataset_path, split='train+unlabeled', download=True,
                                    transform=MultiViewDataInjector([data_transform, data_transform, data_transform_identity]))
    elif args["dataset"] == "cifar10":
        train_dataset = datasets.CIFAR10(args.dataset_path, train=True, download=True,
                                    transform=MultiViewDataInjector([data_transform, data_transform, data_transform_identity]))
    else:
        raise RuntimeError(f"Unknown dataset! {args['dataset']}")

    args = hydra2dict(args)
    train_params = args["trainer"]
    if train_params["projector_same_as_predictor"]:
        train_params["projector_params"] = train_params["predictor_params"]

    # online network
    online_network = ResNet18(dataset=args["dataset"], options=train_params["projector_params"], **args['network']).to(device)
    if torch.cuda.device_count() > 1:
        online_network = torch.nn.parallel.DataParallel(online_network)

    pretrained_path = args['network']['pretrained_path']
    if pretrained_path:
        try:
            load_params = torch.load(pretrained_path, map_location=torch.device(device))
            online_network.load_state_dict(load_params['online_network_state_dict'])
            # Fix: a second load_state_dict(load_params) call was removed —
            # loading the full checkpoint dict as a state dict would raise
            # on unexpected keys after the line above already loaded the
            # parameters.
            log.info("Load from {}.".format(pretrained_path))
        except FileNotFoundError:
            log.info("Pre-trained weights not found. Training from scratch.")

    # predictor network (BYOL only)
    if train_params["has_predictor"] and args["method"] == "byol":
        predictor = MLPHead(in_channels=args['network']['projection_head']['projection_size'],
                            **args['network']['predictor_head'], options=train_params["predictor_params"]).to(device)
        if torch.cuda.device_count() > 1:
            predictor = torch.nn.parallel.DataParallel(predictor)
    else:
        predictor = None

    # target encoder
    target_network = ResNet18(dataset=args["dataset"], options=train_params["projector_params"], **args['network']).to(device)
    if torch.cuda.device_count() > 1:
        target_network = torch.nn.parallel.DataParallel(target_network)

    params = online_network.parameters()

    # Save network and parameters.
    torch.save(args, "args.pt")

    if args["eval_after_each_epoch"]:
        evaluator = Evaluator(args["dataset"], args["dataset_path"], args["test"]["batch_size"])
    else:
        evaluator = None

    if args["use_optimizer"] == "adam":
        optimizer = torch.optim.Adam(params, lr=args['optimizer']['params']["lr"], weight_decay=args["optimizer"]["params"]['weight_decay'])
    elif args["use_optimizer"] == "sgd":
        optimizer = torch.optim.SGD(params, **args['optimizer']['params'])
    else:
        raise RuntimeError(f"Unknown optimizer! {args['use_optimizer']}")

    if args["predictor_optimizer_same"]:
        args["predictor_optimizer"] = args["optimizer"]

    # Fix: previously unbound when the predictor is absent or frozen,
    # causing a NameError at the BYOLTrainer call below.
    predictor_optimizer = None
    if predictor and train_params["train_predictor"]:
        predictor_optimizer = torch.optim.SGD(predictor.parameters(), **args['predictor_optimizer']['params'])

    ## SimCLR scheduler
    if args["method"] == "simclr":
        trainer = SimCLRTrainer(log_dir="./", model=online_network, optimizer=optimizer, evaluator=evaluator, device=device, params=args["trainer"])
    elif args["method"] == "byol":
        trainer = BYOLTrainer(log_dir="./",
                              online_network=online_network,
                              target_network=target_network,
                              optimizer=optimizer,
                              predictor_optimizer=predictor_optimizer,
                              predictor=predictor,
                              device=device,
                              evaluator=evaluator,
                              **args['trainer'])
    else:
        raise RuntimeError(f'Unknown method {args["method"]}')

    trainer.train(train_dataset)

    # Deferred evaluation mode: run the linear probe once after training.
    if not args["eval_after_each_epoch"]:
        result_eval = linear_eval(args["dataset"], args["dataset_path"], args["test"]["batch_size"], ["./"], [])
        torch.save(result_eval, "eval.pth")