Example #1
def main(device, args):
    train_directory = '../data/train'
    image_name_file = '../data/original.csv'
    val_directory = '../data/train'
    train_loader = torch.utils.data.DataLoader(
        dataset=get_dataset('random', train_directory, image_name_file,
            transform=get_aug(train=True, **args.aug_kwargs),
            train=True,
            **args.dataset_kwargs),
        # dataset=datasets.ImageFolder(root=train_directory, transform=get_aug(train=True, **args.aug_kwargs)),
        shuffle=True,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )

    memory_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(root=val_directory, transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs)),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(root=val_directory, transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs)),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )

    # define model
    model = get_model(args.model).to(device)
    model = torch.nn.DataParallel(model)
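    # GradScaler performs loss scaling for mixed-precision (AMP) training, used with autocast in the loop below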
    scaler = torch.cuda.amp.GradScaler()

    # define optimizer
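    # learning rates follow the linear scaling rule: lr = base_lr * batch_size / 256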
    optimizer = get_optimizer(
        args.train.optimizer.name, model,
        lr=args.train.base_lr * args.train.batch_size / 256,
        momentum=args.train.optimizer.momentum,
        weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs, args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs, args.train.base_lr * args.train.batch_size / 256,
                                  args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 predictor
    )

    RESUME = False
    start_epoch = 0
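    # Optional resume: reload backbone weights from args.eval_from, then restore the
    # full model, optimizer state and starting epoch from a local checkpoint.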

    if RESUME:
        model = get_backbone(args.model.backbone)
        classifier = nn.Linear(in_features=model.output_dim, out_features=9, bias=True).to(args.device)

        assert args.eval_from is not None
        save_dict = torch.load(args.eval_from, map_location='cpu')
        msg = model.load_state_dict({k[9:]: v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')},
                                    strict=True)

        path_checkpoint = "./checkpoint/simsiam-TCGA-0218-nearby_0221134812.pth"  # path to the resume checkpoint
        checkpoint = torch.load(path_checkpoint)  # load the checkpoint

        model.load_state_dict(checkpoint['net'])  # restore the model's learnable parameters

        optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer state
        start_epoch = checkpoint['epoch']  # set the starting epoch

    logger = Logger(tensorboard=args.logger.tensorboard, matplotlib=args.logger.matplotlib, log_dir=args.log_dir)
    accuracy = 0
    # Start training
    global_progress = tqdm(range(start_epoch, args.train.stop_at_epoch), desc=f'Training')
    for epoch in global_progress:
        model.train()

        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}', disable=args.hide_progress)
        for idx, (images1, images2, images3, labels) in enumerate(local_progress):
            model.zero_grad()
            with torch.cuda.amp.autocast():
                data_dict = model.forward(images1.to(device, non_blocking=True), images2.to(device, non_blocking=True),
                                          images3.to(device, non_blocking=True))
                loss = data_dict['loss'].mean()  # ddp
            # loss.backward()
            scaler.scale(loss).backward()
            # optimizer.step()
            scaler.step(optimizer)
            scaler.update()

            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})

            local_progress.set_postfix(data_dict)
            logger.update_scalers(data_dict)

        if args.train.knn_monitor and epoch % args.train.knn_interval == 0:
            accuracy = knn_monitor(model.module.backbone, memory_loader, test_loader, device,
                                   k=min(args.train.knn_k, len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)
        logger.update_scalers(epoch_dict)

        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch
        }
        if (epoch % args.train.save_interval) == 0:
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict()
            }, './checkpoint/exp_0223_triple_400_proj3/ckpt_best_%s.pth' % (str(epoch)))

    # Save checkpoint
    model_path = os.path.join(args.ckpt_dir,
                              f"{args.name}_{datetime.now().strftime('%m%d%H%M%S')}.pth")  # datetime.now().strftime('%Y%m%d_%H%M%S')
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.module.state_dict()
    }, model_path)
    print(f"Model saved to {model_path}")
    with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f:
        f.write(f'{model_path}')


    if args.eval is not False:
        args.eval_from = model_path
        linear_eval(args)
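
The checkpoint written above stores the full SimSiam state_dict under 'state_dict'. A minimal sketch of reloading only the backbone afterwards, mirroring the prefix stripping in the RESUME block (it reuses get_backbone, args and model_path from the example above):

import torch
import torch.nn as nn

save_dict = torch.load(model_path, map_location='cpu')
backbone = get_backbone(args.model.backbone)
# keep only the 'backbone.*' entries and strip the prefix
backbone_state = {k[len('backbone.'):]: v
                  for k, v in save_dict['state_dict'].items()
                  if k.startswith('backbone.')}
backbone.load_state_dict(backbone_state, strict=True)
classifier = nn.Linear(backbone.output_dim, 9)  # 9 classes, as in the RESUME block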
Example #2
def main(device, args):
    train_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(transform=get_aug(train=True, **args.aug_kwargs),
                            train=True,
                            **args.dataset_kwargs),
        shuffle=True,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs)
    memory_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(transform=get_aug(train=False,
                                              train_classifier=False,
                                              **args.aug_kwargs),
                            train=True,
                            **args.dataset_kwargs),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(train=False,
                          train_classifier=False,
                          **args.aug_kwargs),
        train=False,
        **args.dataset_kwargs),
                                              shuffle=False,
                                              batch_size=args.train.batch_size,
                                              **args.dataloader_kwargs)

    # define model
    model = get_model(args.model).to(device)
    model = torch.nn.DataParallel(model)

    # define optimizer
    optimizer = get_optimizer(args.train.optimizer.name,
                              model,
                              lr=args.train.base_lr * args.train.batch_size /
                              256,
                              momentum=args.train.optimizer.momentum,
                              weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs,
        args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs,
        args.train.base_lr * args.train.batch_size / 256,
        args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 predictor
    )

    logger = Logger(tensorboard=args.logger.tensorboard,
                    matplotlib=args.logger.matplotlib,
                    log_dir=args.log_dir)
    accuracy = 0
    # Start training

    print("Trying to train model {}".format(model))
    print("Will run up to {} epochs of training".format(
        args.train.stop_at_epoch))

    global_progress = tqdm(range(0, args.train.stop_at_epoch),
                           desc=f'Training')
    for epoch in global_progress:
        model.train()

        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.train.num_epochs}',
                              disable=args.hide_progress)
        for idx, _data in enumerate(local_progress):
            # TODO looks like we might be missing the label?
            ((images1, images2), labels) = _data

            model.zero_grad()
            data_dict = model.forward(images1.to(device, non_blocking=True),
                                      images2.to(device, non_blocking=True))
            loss = data_dict['loss'].mean()  # ddp
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})

            local_progress.set_postfix(data_dict)
            logger.update_scalers(data_dict)

        # ignore KNN monitor since it's coded to work ONLY on cuda enabled devices unfortunately
        # check the mnist yaml to see
        if args.train.knn_monitor and epoch % args.train.knn_interval == 0:
            accuracy = knn_monitor(model.module.backbone,
                                   memory_loader,
                                   test_loader,
                                   device,
                                   k=min(args.train.knn_k,
                                         len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)
        logger.update_scalers(epoch_dict)

    # Save checkpoint
    model_path = os.path.join(
        args.ckpt_dir,
        f"{args.name}_{datetime.now().strftime('%m%d%H%M%S')}.pth"
    )  # datetime.now().strftime('%Y%m%d_%H%M%S')
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.module.state_dict()
    }, model_path)
    print(f"Model saved to {model_path}")
    with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f:
        f.write(f'{model_path}')

    if args.eval is not False:
        args.eval_from = model_path
        linear_eval(args)
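
The training loader built with get_aug(train=True, ...) yields two augmented views per sample plus the (unused) label, which is why the loop unpacks ((images1, images2), labels). A minimal sketch of such a two-view transform, under the assumption that get_aug behaves this way (the class name TwoViewTransform is hypothetical):

class TwoViewTransform:
    """Apply the same base augmentation twice, producing two views of one image."""
    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        return self.base_transform(x), self.base_transform(x)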
Example #3
def main(gpu, args):
    rank = args.nr * args.gpus + gpu
    dist.init_process_group("nccl", rank=rank, world_size=args.world_size)

    torch.manual_seed(0)
    torch.cuda.set_device(gpu)

    train_dataset = get_dataset(transform=get_aug(train=True,
                                                  **args.aug_kwargs),
                                train=True,
                                **args.dataset_kwargs)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        shuffle=False,
        batch_size=(args.train.batch_size // args.gpus),
        sampler=train_sampler,
        **args.dataloader_kwargs)

    memory_dataset = get_dataset(transform=get_aug(train=False,
                                                   train_classifier=False,
                                                   **args.aug_kwargs),
                                 train=True,
                                 **args.dataset_kwargs)

    memory_loader = torch.utils.data.DataLoader(
        dataset=memory_dataset,
        shuffle=False,
        batch_size=(args.train.batch_size // args.gpus),
        **args.dataloader_kwargs)

    test_dataset = get_dataset(transform=get_aug(train=False,
                                                 train_classifier=False,
                                                 **args.aug_kwargs),
                               train=False,
                               **args.dataset_kwargs)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        shuffle=False,
        batch_size=(args.train.batch_size // args.gpus),
        **args.dataloader_kwargs)
    print("Batch size:", (args.train.batch_size // args.gpus))
    # define model
    model = get_model(args.model).cuda(gpu)
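    # convert BatchNorm layers to SyncBatchNorm so statistics are synchronized across DDP processes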
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    # define optimizer
    optimizer = get_optimizer(args.train.optimizer.name,
                              model,
                              lr=args.train.base_lr * args.train.batch_size /
                              256,
                              momentum=args.train.optimizer.momentum,
                              weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs,
        args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs,
        args.train.base_lr * args.train.batch_size / 256,
        args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 predictor
    )
    if gpu == 0:
        logger = Logger(tensorboard=args.logger.tensorboard,
                        matplotlib=args.logger.matplotlib,
                        log_dir=args.log_dir)
    accuracy = 0
    # Start training
    global_progress = tqdm(range(0, args.train.stop_at_epoch),
                           desc=f'Training')
    for epoch in global_progress:
        model.train()
        train_sampler.set_epoch(epoch)  # reshuffle the DistributedSampler each epoch

        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.train.num_epochs}',
                              disable=args.hide_progress)
        for idx, ((images1, images2), labels) in enumerate(local_progress):

            model.zero_grad()
            data_dict = model.forward(images1.cuda(non_blocking=True),
                                      images2.cuda(non_blocking=True))
            loss = data_dict['loss']  # ddp
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})
            local_progress.set_postfix(data_dict)
            if gpu == 0:
                logger.update_scalers(data_dict)

        if args.train.knn_monitor and epoch % args.train.knn_interval == 0 and gpu == 0:
            accuracy = knn_monitor(model.module.backbone,
                                   memory_loader,
                                   test_loader,
                                   gpu,
                                   k=min(args.train.knn_k,
                                         len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)

        if gpu == 0:
            logger.update_scalers(epoch_dict)

        # Save checkpoint
        if gpu == 0 and epoch % args.train.knn_interval == 0:
            model_path = os.path.join(
                args.ckpt_dir, f"{args.name}_{epoch+1}.pth"
            )  # datetime.now().strftime('%Y%m%d_%H%M%S')
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.module.state_dict()
                }, model_path)
            print(f"Model saved to {model_path}")
            with open(os.path.join(args.log_dir, f"checkpoint_path.txt"),
                      'w+') as f:
                f.write(f'{model_path}')

    # if args.eval is not False and gpu == 0:
    #     args.eval_from = model_path
    #     linear_eval(args)

    dist.destroy_process_group()
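
A minimal sketch (assumed, not shown in the example) of how this DDP entry point is typically launched: one process per GPU via torch.multiprocessing.spawn, with the rendezvous variables that dist.init_process_group("nccl", ...) expects. It assumes args already carries world_size, gpus and nr as used in main() above; get_args() is a hypothetical argument parser.

import os
import torch.multiprocessing as mp

if __name__ == '__main__':
    args = get_args()                                  # hypothetical arg parser
    os.environ.setdefault('MASTER_ADDR', 'localhost')  # rendezvous address
    os.environ.setdefault('MASTER_PORT', '12355')      # any free TCP port
    mp.spawn(main, nprocs=args.gpus, args=(args,))     # calls main(gpu, args) for each local GPU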