Example #1
def init_optim(
    model: type[BaseModelClass],
    optim: type[Optimizer] | Literal["SGD", "Adam", "AdamW"],
    learning_rate: float,
    weight_decay: float,
    momentum: float,
    device: type[torch.device] | Literal["cuda", "cpu"],
    milestones: Iterable = (),
    gamma: float = 0.3,
    resume: str | None = None,
    **kwargs,
) -> tuple[Optimizer, _LRScheduler]:
    """Initialize Optimizer and Scheduler.

    Args:
        model (type[BaseModelClass]): Model to be optimized.
        optim (type[Optimizer] | "SGD" | "Adam" | "AdamW"): Which optimizer to use
        learning_rate (float): Learning rate for optimization
        weight_decay (float): Weight decay for optimizer
        momentum (float): Momentum for optimizer
        device (type[torch.device] | "cuda" | "cpu"): Device the model will run on
        milestones (Iterable, optional): When to decay learning rate. Defaults to ().
        gamma (float, optional): Multiplier for learning rate decay. Defaults to 0.3.
        resume (str | None, optional): Path to model checkpoint to resume. Defaults to None.

    Returns:
        tuple[Optimizer, _LRScheduler]: Optimizer and scheduler for given model
    """
    # Select Optimiser
    if optim == "SGD":
        optimizer = SGD(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            momentum=momentum,
        )
    elif optim == "Adam":
        optimizer = Adam(model.parameters(),
                         lr=learning_rate,
                         weight_decay=weight_decay)
    elif optim == "AdamW":
        optimizer = AdamW(model.parameters(),
                          lr=learning_rate,
                          weight_decay=weight_decay)
    else:
        raise ValueError("Only SGD, Adam or AdamW are allowed as --optim")

    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

    if resume:
        # TODO work out how to ensure that we are using the same optimizer
        # when resuming such that the state dictionaries do not clash.
        # TODO breaking the function apart means we load the checkpoint twice.
        checkpoint = torch.load(resume, map_location=device)
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

    return optimizer, scheduler
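A short usage sketch may help tie this together; the Linear model, hyperparameter values, and checkpoint path below are illustrative assumptions, not part of the original function:

import torch

# Hypothetical call; any nn.Module with .parameters() works in place of Linear.
model = torch.nn.Linear(4, 2)
optimizer, scheduler = init_optim(
    model,
    optim="AdamW",
    learning_rate=3e-4,
    weight_decay=1e-2,
    momentum=0.9,        # only used by SGD; ignored for Adam/AdamW
    device="cpu",
    milestones=(30, 60),
    gamma=0.3,
)

# The resume branch above expects a checkpoint containing at least these keys:
torch.save(
    {"optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict()},
    "checkpoint.pth",  # hypothetical path later passed as resume=
)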
Example #2
def main():
    """Main pipeline."""
    # parse command line arguments
    pathSet = [
        "./myProject/trainData/Mild/", "./myProject/trainData/Aggressive/"
    ]
    for dataPath in pathSet:
        # parse argument
        args = parse_args(dataPath)

        print("Constructing data loaders...")
        # load training set and split
        trainset, valset = load_data(args.dataroot, args.train_size)

        trainloader, validationloader = data_loader(args.dataroot, trainset,
                                                    valset, args.batch_size,
                                                    args.shuffle,
                                                    args.num_workers)

        # define model
        print("Initialize model ...")
        if args.model_name == "default":
            model = NetworkNvidia()
        elif args.model_name == "modified":
            model = ModNet()
        else:
            raise ValueError("Unknown model name: {}".format(args.model_name))

        # define optimizer and loss function
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
        criterion = nn.MSELoss()

        # learning rate scheduler
        scheduler = MultiStepLR(optimizer, milestones=[30, 50], gamma=0.1)

        # resume
        if args.resume:
            print("Loading a model from checkpoint")
            # use pre-trained model
            checkpoint = torch.load("model.h5",
                                    map_location=lambda storage, loc: storage)

            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])

        # cuda or cpu
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print("Selected GPU: ", device)

        # training
        print("Training Neural Network...")
        trainer = Trainer(args.dataroot, args.ckptroot, model, device,
                          args.epochs, criterion, optimizer, scheduler,
                          args.start_epoch, trainloader, validationloader)
        trainer.train()
Example #3
def main():
    """Main pipeline."""
    # parse command line arguments
    args = parse_args()

    # load training set and split
    trainset, valset = load_data(args.dataroot, args.train_size)

    print("==> Preparing dataset ...")
    trainloader, validationloader = data_loader(args.dataroot, trainset,
                                                valset, args.batch_size,
                                                args.shuffle, args.num_workers)

    # define model
    print("==> Initialize model ...")
    if args.model_name == "nvidia":
        model = NetworkNvidia()
    elif args.model_name == "light":
        model = NetworkLight()
    else:
        raise ValueError("Unknown model name: {}".format(args.model_name))

    # define optimizer and criterion
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    criterion = nn.MSELoss()

    # learning rate scheduler
    scheduler = MultiStepLR(optimizer, milestones=[30, 50], gamma=0.1)

    # resume
    if args.resume:
        print("==> Loading checkpoint ...")
        # use pre-trained model
        checkpoint = torch.load("model.h5",
                                map_location=lambda storage, loc: storage)

        print("==> Loading checkpoint model successfully ...")
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    # cuda or cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("==> Use accelerator: ", device)

    # training
    print("==> Start training ...")
    trainer = Trainer(args.ckptroot, model, device, args.epochs, criterion,
                      optimizer, scheduler, args.start_epoch, trainloader,
                      validationloader)
    trainer.train()
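For completeness, a hedged sketch of the save side that would produce a checkpoint readable by the resume branch above; the function name, and reusing the hard-coded "model.h5" path, are assumptions:

import torch

def save_checkpoint(model, optimizer, scheduler, epoch, path="model.h5"):
    """Persist exactly the keys the resume branch reads back."""
    torch.save(
        {
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        path,
    )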
Example #4
def resume(args, dataset, device):
    """ Loads model and optimizer state from a previous checkpoint. """
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume, map_location=str(device))

    model = create_model(checkpoint['args'], dataset, device)
    model.load_state_dict(checkpoint['state_dict'])

    optimizer = create_optimizer(args, model)
    optimizer.load_state_dict(checkpoint['optimizer'])
    args.start_epoch = checkpoint['epoch']

    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_steps,
                            gamma=args.lr_decay)
    scheduler.load_state_dict(checkpoint['scheduler'])

    return model, optimizer, scheduler
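A possible call site for resume(); it assumes args, dataset, and device come from the surrounding script, and that create_model/create_optimizer match the factories used when the checkpoint was written:

# Hypothetical entry point; only the branching logic is the point here.
if args.resume:
    model, optimizer, scheduler = resume(args, dataset, device)
else:
    args.start_epoch = 0
    model = create_model(args, dataset, device)
    optimizer = create_optimizer(args, model)
    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_steps,
                            gamma=args.lr_decay)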
Example #5
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.GaussianBlur(),
        T.ToTensor(), normalize
    ])
    val_transform = T.Compose(
        [T.Resize(args.image_size),
         T.ToTensor(), normalize])
    image_size = (args.image_size, args.image_size)
    heatmap_size = (args.heatmap_size, args.heatmap_size)
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_source_dataset = source_dataset(root=args.source_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_source_loader = DataLoader(val_source_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(root=args.target_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    print("Source train:", len(train_source_loader))
    print("Target train:", len(train_target_loader))
    print("Source test:", len(val_source_loader))
    print("Target test:", len(val_target_loader))

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    model = models.__dict__[args.arch](
        num_keypoints=train_source_dataset.num_keypoints).to(device)
    criterion = JointsMSELoss()

    # define optimizer and lr scheduler
    optimizer = Adam(model.get_parameters(lr=args.lr))
    lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1

    # define visualization function
    tensor_to_image = Compose([
        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToPILImage()
    ])

    def visualize(image, keypoint2d, name):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            keypoint2d (tensor): keypoints in shape K x 2
            name: name of the saving image
        """
        train_source_dataset.visualize(
            tensor_to_image(image), keypoint2d,
            logger.get_image_path("{}.jpg".format(name)))

    if args.phase == 'test':
        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize, args)
        print("Source: {:4.3f} Target: {:4.3f}".format(source_val_acc['all'],
                                                       target_val_acc['all']))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
        return

    # start training
    best_acc = 0
    for epoch in range(start_epoch, args.epochs):
        logger.set_epoch(epoch)
        lr_scheduler.step()

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion,
              optimizer, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize if args.debug else None, args)

        # remember best acc and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if target_val_acc['all'] > best_acc:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
            best_acc = target_val_acc['all']
        print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(
            source_val_acc['all'], target_val_acc['all'], best_acc))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))

    logger.close()
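Every example on this page delegates learning-rate decay to MultiStepLR, so a small self-contained demonstration of how milestones and gamma interact may be useful; the toy model and numbers here are arbitrary:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(2, 2)
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = MultiStepLR(optimizer, milestones=[2, 4], gamma=0.1)

for epoch in range(6):
    optimizer.step()    # stand-in for one epoch of training
    scheduler.step()    # one scheduler step per epoch, after the optimizer
    print(epoch, optimizer.param_groups[0]["lr"])
# lr stays at 0.1, drops to 0.01 after the 2nd step and to 0.001 after the 4th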
Example #6
# Load Checkpoints
start_epoch = 0
if not os.path.exists(CHECKPOINT_DIR):
    os.mkdir(CHECKPOINT_DIR)
generator_checkpoint_file = os.path.join(CHECKPOINT_DIR,
                                         'checkpoint_generator.tar')
discriminator_checkpoint_file = os.path.join(CHECKPOINT_DIR,
                                             'checkpoint_discriminator.tar')
if os.path.isfile(generator_checkpoint_file) and os.path.isfile(
        discriminator_checkpoint_file):
    generator_checkpoint = torch.load(generator_checkpoint_file)
    generator.load_state_dict(generator_checkpoint['model_state_dict'])
    generator_optimizer.load_state_dict(
        generator_checkpoint['optimizer_state_dict'])
    G_epoch = generator_checkpoint['epoch']
    generator_lr_scheduler.load_state_dict(generator_checkpoint['scheduler'])
    print("Load checkpoint {} (epoch {})".format(generator_checkpoint_file,
                                                 G_epoch))
    discriminator_checkpoint = torch.load(discriminator_checkpoint_file)
    discriminator.load_state_dict(discriminator_checkpoint['model_state_dict'])
    discriminator_optimizer.load_state_dict(
        discriminator_checkpoint['optimizer_state_dict'])
    D_epoch = discriminator_checkpoint['epoch']
    discriminator_lr_scheduler.load_state_dict(
        discriminator_checkpoint['scheduler'])
    print("Load checkpoint {} (epoch {})".format(discriminator_checkpoint_file,
                                                 D_epoch))
    assert G_epoch == D_epoch
    start_epoch = G_epoch

# Data Parallelism
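The optimizer and scheduler state dictionaries are stored and reloaded together, as in the generator/discriminator checkpoints above, because MultiStepLR keeps its epoch counter in its own state. A small toy check (arbitrary values) illustrates that loading it resumes the schedule rather than restarting it:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

param = torch.nn.Parameter(torch.zeros(1))
opt = SGD([param], lr=0.1)
sched = MultiStepLR(opt, milestones=[2, 4], gamma=0.1)
for _ in range(3):
    opt.step()
    sched.step()
state = sched.state_dict()          # what the examples store under 'scheduler'

# A fresh scheduler picks up mid-schedule once the saved state is loaded.
opt_resumed = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
sched_resumed = MultiStepLR(opt_resumed, milestones=[2, 4], gamma=0.1)
sched_resumed.load_state_dict(state)
print(sched_resumed.last_epoch)     # 3, not 0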
Example #7
def main():

    # 1. argparser
    opts = parse(sys.argv[1:])
    print(opts)

    # 3. visdom
    vis = visdom.Visdom(port=opts.port)
    # 4. data set
    train_set = None
    test_set = None

    # train_set = KorEngDataset(root='./data', split='train')
    train_set = KorEngDataset(root='./data', split='valid')

    # 5. data loader
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=opts.batch_size,
                                               collate_fn=train_set.collate_fn,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)

    # test_loader = torch.utils.data.DataLoader(test_set,
    #                                           batch_size=1,
    #                                           collate_fn=test_set.collate_fn,
    #                                           shuffle=False,
    #                                           num_workers=2,
    #                                           pin_memory=True)

    # 6. network
    model = Transformer(num_vocab=110000,
                        model_dim=512,
                        max_seq_len=64,
                        num_head=8,
                        num_layers=6,
                        dropout=0.1).to(device)
    model = torch.nn.DataParallel(module=model, device_ids=device_ids)

    # 7. loss
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

    # 8. optimizer
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=opts.lr,
                                momentum=opts.momentum,
                                weight_decay=opts.weight_decay)

    # 9. scheduler
    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=[30, 45],
                            gamma=0.1)

    # 10. resume
    if opts.start_epoch != 0:

        checkpoint = torch.load(
            os.path.join(opts.save_path, opts.save_file_name) +
            '.{}.pth.tar'.format(opts.start_epoch - 1),
            map_location=device)  # load the previous epoch's checkpoint and resume training
        model.load_state_dict(
            checkpoint['model_state_dict'])  # load model state dict
        optimizer.load_state_dict(
            checkpoint['optimizer_state_dict'])  # load optim state dict
        scheduler.load_state_dict(
            checkpoint['scheduler_state_dict'])  # load sched state dict
        print('\nLoaded checkpoint from epoch %d.\n' %
              (int(opts.start_epoch) - 1))

    else:

        print('\nNo check point to resume.. train from scratch.\n')

    # for statement
    for epoch in range(opts.start_epoch, opts.epoch):

        # 11. train
        train(epoch=epoch,
              vis=vis,
              train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              scheduler=scheduler,
              opts=opts)

        scheduler.step()
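A hedged sketch of the per-epoch save that would produce the files this resume branch looks for; the dictionary keys and filename pattern mirror the loads above, while everything else (the helper name, where train() would call it) is an assumption:

import os
import torch

def save_epoch_checkpoint(opts, epoch, model, optimizer, scheduler):
    # Matches the '{save_file_name}.{epoch}.pth.tar' pattern read on resume.
    path = os.path.join(opts.save_path, opts.save_file_name) + '.{}.pth.tar'.format(epoch)
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        },
        path)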
Example #8
def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    cudnn.benchmark = True

    start_epoch = args.start_epoch

    lr_decay_step = list(map(int, args.lr_decay_step.split(',')))

    # Data loading
    print_logger.info('=> Preparing data..')
    loader = import_module('data.' + args.dataset).Data(args)

    num_classes = 0
    if args.dataset in ['cifar10']:
        num_classes = 10

    model = eval(args.block_type + 'ResNet56_od')(
        groups=args.group_num,
        expansion=args.expansion,
        num_stu=args.num_stu,
        num_classes=num_classes).cuda()

    if len(args.gpu) > 1:
        device_id = []
        for i in range((len(args.gpu) + 1) // 2):
            device_id.append(i)
        model = torch.nn.DataParallel(model, device_ids=device_id)

    best_prec = 0.0

    if not model:
        print_logger.info("Model arch Error")
        return

    print_logger.info(model)

    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer,
                            milestones=lr_decay_step,
                            gamma=args.lr_decay_factor)

    # Optionally resume from a checkpoint
    resume = args.resume
    if resume:
        print('=> Loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume)
        state_dict = checkpoint['state_dict']
        if args.adjust_ckpt:
            new_state_dict = {
                k.replace('module.', ''): v
                for k, v in state_dict.items()
            }
        else:
            new_state_dict = state_dict

        if args.start_epoch == 0:
            start_epoch = checkpoint['epoch']

        best_prec = checkpoint['best_prec']
        model.load_state_dict(new_state_dict)
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    if args.test_only:
        test_prec = test(args, loader.loader_test, model)
        print('=> Test Prec@1: {:.2f}'.format(test_prec[0]))
        return

    record_top5 = 0.
    for epoch in range(start_epoch, args.epochs):

        scheduler.step(epoch)

        train_loss, train_prec = train(args, loader.loader_train, model,
                                       criterion, optimizer, epoch)
        test_prec = test(args, loader.loader_test, model, epoch)

        is_best = best_prec < test_prec[0]
        if is_best:
            record_top5 = test_prec[1]
        best_prec = max(test_prec[0], best_prec)

        state = {
            'state_dict': model.state_dict(),
            'test_prec': test_prec[0],
            'best_prec': best_prec,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1
        }

        if epoch % args.save_freq == 0 or is_best:
            ckpt.save_model(state, epoch + 1, is_best)
        print_logger.info("=>Best accuracy {:.3f}, {:.3f}".format(
            best_prec, record_top5))
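The adjust_ckpt branch above strips the 'module.' prefixes that DataParallel adds to state-dict keys; two small standalone helpers (hypothetical names) show both directions of that conversion:

def strip_module_prefix(state_dict):
    """Make a DataParallel checkpoint loadable by a plain (unwrapped) model."""
    return {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in state_dict.items()
    }


def add_module_prefix(state_dict):
    """Make a plain checkpoint loadable by a DataParallel-wrapped model."""
    return {'module.' + k: v for k, v in state_dict.items()}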
Example #9
def main():
    # print the experiment configuration
    print('\33[91m\nCurrent time is {}. \33[0m'.format(str(time.asctime())))
    print('Parsed options: {}.'.format(vars(args)))
    print('Number of Speakers: {}\n'.format(len(train_dir.classes)))

    # instantiate model and initialize weights
    model = ResCNNSpeaker(embedding_size=args.embedding_size,
                          resnet_size=10,
                          num_classes=len(train_dir.classes))

    if args.cuda:
        model.cuda()

    optimizer = create_optimizer(model, args.optimizer, **opt_kwargs)
    # criterion = AngularSoftmax(in_feats=args.embedding_size,
    #                           num_classes=len(train_dir.classes))
    scheduler = MultiStepLR(optimizer, milestones=[20, 30], gamma=0.1)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            filtered = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'num_batches_tracked' not in k
            }
            model.load_state_dict(filtered)
            optimizer.load_state_dict(checkpoint['optimizer'])

            try:
                scheduler.load_state_dict(checkpoint['scheduler'])
            except KeyError:
                print('No scheduler found!')
            # criterion.load_state_dict(checkpoint['criterion'])

        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    start = args.start_epoch
    print('start epoch is : ' + str(start))
    # start = 0
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(
        train_dir,
        batch_size=args.batch_size,
        shuffle=True,
        # collate_fn=PadCollate(dim=2),
        **kwargs)
    valid_loader = torch.utils.data.DataLoader(
        valid_dir,
        batch_size=args.test_batch_size,
        shuffle=False,
        # collate_fn=PadCollate(dim=2),
        **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dir,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    ce = nn.CrossEntropyLoss().cuda()

    for epoch in range(start, end):
        # pdb.set_trace()
        train(train_loader, model, ce, optimizer, epoch)
        test(test_loader, valid_loader, model, epoch)
        scheduler.step()
        # break

    writer.close()
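As a hedged alternative to the manual filtering of 'num_batches_tracked' keys in the resume branch above, a non-strict load tolerates the extra or missing BatchNorm buffers and reports them; `model` and `checkpoint` here are the objects from that branch:

missing, unexpected = model.load_state_dict(checkpoint['state_dict'], strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)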
Example #10
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes.
            args.rank = args.rank * ngpus_per_node + gpu

        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank
        )

    # load model here
    # model = maskrcnn001(num_classes=2)

    model = arch(num_classes=2)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per 
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set.
            model = DistributedDataParallel(model) 
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        # model = torch.nn.DataParallel(model).cuda()
        model = model.cuda()

    if args.distributed:
        # model = DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # lr_scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=0.1)
    lr_scheduler = MultiStepLR(optimizer, milestones=[20000, 40000], gamma=0.1)

    # ================================
    # resume RESUME CHECKPOINT
    if IS_SM:  # load latest checkpoints 
        checkpoint_list = os.listdir(checkpoint_dir)

        logger.info("=> Checking checkpoints dir.. {}".format(checkpoint_dir))
        logger.info(checkpoint_list)

        latest_path_parent = ""
        latest_path = ""
        latest_iter_num = -1

        for checkpoint_path in natsorted(glob.glob(os.path.join(checkpoint_dir, "*.pth"))):
            checkpoint_name = os.path.basename(checkpoint_path)
            logger.info("Found checkpoint {}".format(checkpoint_name))
            iter_num = int(os.path.splitext(checkpoint_name)[0].split("_")[-1])

            if iter_num > latest_iter_num:
                latest_path_parent = latest_path
                latest_path = checkpoint_path
                latest_iter_num = iter_num 

        logger.info("> latest checkpoint is {}".format(latest_path))

        if latest_path_parent:
            logger.info("=> loading checkpoint {}".format(latest_path_parent))
            checkpoint = torch.load(latest_path_parent, map_location="cpu")
            model_without_ddp.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

            args.start_epoch = checkpoint["epoch"]
            args.iter_num = checkpoint["iter_num"]
            logger.info("==> args.iter_num is {}".format(args.iter_num))

    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return
    
    logger.info("==================================")
    logger.info("Create dataset with root_dir={}".format(args.train_data_path))
    assert os.path.exists(args.train_data_path), "root_dir does not exists!"
    train_set = TableBank(root_dir=args.train_data_path)

    if args.distributed:
        train_sampler = DistributedSampler(train_set)
    else:
        train_sampler = RandomSampler(train_set)

    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(
            train_set,
            k=args.aspect_ratio_group_factor
        )
        train_batch_sampler = GroupedBatchSampler(
            train_sampler,
            group_ids,
            args.batch_size
        )
    else:
        train_batch_sampler = BatchSampler(
            train_sampler,
            args.batch_size,
            drop_last=True
        )

    logger.info("Create data_loader.. with batch_size = {}".format(args.batch_size))
    train_loader = DataLoader(
        train_set,
        batch_sampler=train_batch_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
        pin_memory=True
    )

    logger.info("Start training.. ")

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_one_epoch(
            model=model,
            arch=arch,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            data_loader=train_loader,
            device=args.gpu,
            epoch=epoch,
            print_freq=args.print_freq,
            ngpus_per_node=ngpus_per_node,
            model_without_ddp=model_without_ddp,
            args=args
        )
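A hedged sketch of the save side implied by this resume logic: a checkpoint whose filename ends in the iteration number (so the natsorted lookup above can order it) and which carries the keys read back here; the helper name and exact filename are assumptions:

import os
import torch

def save_iter_checkpoint(checkpoint_dir, model_without_ddp, optimizer,
                         lr_scheduler, epoch, iter_num):
    # The filename must end in the iteration number for the "*_{iter}.pth" parsing above.
    path = os.path.join(checkpoint_dir, "model_{}.pth".format(iter_num))
    torch.save(
        {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch,
            "iter_num": iter_num,
        },
        path)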
Example #11
def main():
    # Views the training images and displays the distance on anchor-negative and anchor-positive
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Classes: {}\n'.format(len(train_dir.speakers)))

    # instantiate
    # model and initialize weights
    model = AttenSiResNet(layers=[3, 4, 6, 3], num_classes=len(train_dir.speakers))

    if args.cuda:
        model.cuda()

    optimizer = create_optimizer(model.parameters(), args.optimizer, **opt_kwargs)
    scheduler = MultiStepLR(optimizer, milestones=[18, 24], gamma=0.1)
    # criterion = AngularSoftmax(in_feats=args.embedding_size,
    #                           num_classes=len(train_dir.classes))
    start = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start = checkpoint['epoch']
            filtered = {k: v for k, v in checkpoint['state_dict'].items() if 'num_batches_tracked' not in k}
            model.load_state_dict(filtered)
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            # criterion.load_state_dict(checkpoint['criterion'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    start += args.start_epoch
    print('Start epoch is : ' + str(start))
    end = start + args.epochs

    # pdb.set_trace()
    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size,
                                               # collate_fn=PadCollate(dim=2, fix_len=True),
                                               shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=args.batch_size,
                                               # collate_fn=PadCollate(dim=2, fix_len=True),
                                               shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_part, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    criterion = nn.CrossEntropyLoss().cuda()
    check_path = '{}/checkpoint_{}.pth'.format(args.check_path, -1)
    torch.save({'epoch': -1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()},
               # 'criterion': criterion.state_dict()
               check_path)

    for epoch in range(start, end):
        # pdb.set_trace()
        for param_group in optimizer.param_groups:
            print('\n\33[1;34m Current \'{}\' learning rate is {}.\33[0m'.format(args.optimizer, param_group['lr']))

        train(train_loader, model, optimizer, criterion, scheduler, epoch)
        test(test_loader, valid_loader, model, epoch)

        scheduler.step()
        # break

    writer.close()
Example #12
def train(train_loop_func, logger, args):
    # Check that GPUs are actually available
    use_cuda = not args.no_cuda
    train_samples = 118287

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='smddp', init_method='env://')
        args.N_gpu = torch.distributed.get_world_size()
    else:
        args.N_gpu = 1

    if args.seed is None:
        args.seed = np.random.randint(10000)

    if args.distributed:
        args.seed = (args.seed + torch.distributed.get_rank()) % 2**32
    print("Using seed = {}".format(args.seed))
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(seed=args.seed)


    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    cocoGt = get_coco_ground_truth(args)

    train_loader = get_train_loader(args, args.seed - 2**31)

    val_dataset = get_val_dataset(args)
    val_dataloader = get_val_dataloader(val_dataset, args)

    ssd300 = SSD300(backbone=ResNet(args.backbone, args.backbone_path))
    args.learning_rate = args.learning_rate * args.N_gpu * (args.batch_size / 32)
    start_epoch = 0
    iteration = 0
    loss_func = Loss(dboxes)

    if use_cuda:
        ssd300.cuda()
        loss_func.cuda()

    optimizer = torch.optim.SGD(tencent_trick(ssd300), lr=args.learning_rate,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer=optimizer, milestones=args.multistep, gamma=0.1)
    if args.amp:
        ssd300, optimizer = amp.initialize(ssd300, optimizer, opt_level='O2')

    if args.distributed:
        ssd300 = DDP(ssd300)

    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            load_checkpoint(ssd300.module if args.distributed else ssd300, args.checkpoint)
            checkpoint = torch.load(args.checkpoint,
                                    map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device()))
            start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            scheduler.load_state_dict(checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('Provided checkpoint is not path to a file')
            return

    inv_map = {v: k for k, v in val_dataset.label_map.items()}

    total_time = 0

    if args.mode == 'evaluation':
        acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args)
        if args.local_rank == 0:
            print('Model precision {} mAP'.format(acc))

        return
    mean, std = generate_mean_std(args)

    for epoch in range(start_epoch, args.epochs):
        start_epoch_time = time.time()
        scheduler.step()
        iteration = train_loop_func(ssd300, loss_func, epoch, optimizer, train_loader, val_dataloader, encoder, iteration,
                                    logger, args, mean, std)
        end_epoch_time = time.time() - start_epoch_time
        total_time += end_epoch_time

        if not args.distributed or torch.distributed.get_rank() == 0:
            throughput = train_samples / end_epoch_time
            logger.update_epoch_time(epoch, end_epoch_time)
            logger.update_throughput_speed(epoch, throughput)

        if epoch in args.evaluation:
            acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args)

        if args.save and args.local_rank == 0:
            print("saving model...")
            obj = {'epoch': epoch + 1,
                   'iteration': iteration,
                   'optimizer': optimizer.state_dict(),
                   'scheduler': scheduler.state_dict(),
                   'label_map': val_dataset.label_info}
            if args.distributed:
                obj['model'] = ssd300.module.state_dict()
            else:
                obj['model'] = ssd300.state_dict()
            save_path = os.path.join(args.save, f'epoch_{epoch}.pt')
            torch.save(obj, save_path)
            logger.log('model path', save_path)
        train_loader.reset()

    if not args.distributed or torch.distributed.get_rank() == 0:
        DLLogger.log((), { 'Total training time': '%.2f' % total_time + ' secs' })
        logger.log_summary()
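The learning-rate line above applies a linear scaling rule across GPUs and batch size; a tiny worked example (arbitrary values) makes the arithmetic explicit:

# effective_lr = base_lr * N_gpu * (batch_size / 32); the numbers are illustrative.
base_lr = 2.6e-3
n_gpu = 8
batch_size = 64
effective_lr = base_lr * n_gpu * (batch_size / 32)
print(effective_lr)  # approximately 0.0416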
Example #13
def main():
    alpha = configs.v_loss_rate
    beta = configs.k_loss_rate
    gama = configs.cls_loss_rate
    # for exp
    for_exp = True

    pwd = os.getcwd()
    save_path = os.path.join(
        pwd, configs.save_path, "%d_mem_size_%d_way_%d_shot_%d_query" %
        (configs.mem_size, configs.n_way, configs.k_shot, configs.k_query))

    if not os.path.exists(save_path):
        os.makedirs(save_path)
    model_path = os.path.join(save_path, "model")
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # init dataloader
    print("init data loader")
    train_db, total_train = data_loader(configs,
                                        num_workers=configs.num_workers,
                                        split="train",
                                        use_dali=configs.use_dali)
    val_db, total_valid = data_loader(configs,
                                      num_workers=configs.num_workers,
                                      split="val",
                                      use_dali=configs.use_dali)

    # init neural networks
    backbone, gnn, mem = build_net()
    params = list(backbone.parameters()) + list(gnn.parameters()) + list(
        mem.parameters())

    print(repr(configs))
    # print(backbone, gnn, mem)

    # optimizer
    if configs.train_optim == 'adam':
        optimizer = optim.Adam(params,
                               lr=configs.lr,
                               weight_decay=configs.weight_decay)
    elif configs.train_optim == 'sgd':
        optimizer = optim.SGD(params,
                              lr=configs.lr,
                              weight_decay=configs.weight_decay,
                              momentum=configs.momentum)
    elif configs.train_optim == 'rmsprop':
        optimizer = optim.RMSprop(params,
                                  lr=configs.lr,
                                  weight_decay=configs.weight_decay,
                                  momentum=configs.momentum,
                                  alpha=0.9,
                                  centered=True)
    else:
        raise Exception("error optimizer")

    # learning rate decay policy
    if configs.lr_policy == 'multi_step':
        scheduler = MultiStepLR(optimizer,
                                milestones=list(
                                    map(int, configs.milestones.split(','))),
                                gamma=configs.lr_gama)
    elif configs.lr_policy == 'exp':
        scheduler = ExponentialLR(optimizer, gamma=configs.lr_gama)
    elif configs.lr_policy == 'plateau':
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode="min",
                                      factor=configs.lr_gama,
                                      patience=80,
                                      verbose=True)
    else:
        raise Exception('error lr decay policy')

    if configs.start_epoch:
        check_point = torch.load(
            os.path.join(model_path, "%d_model.pkl" % configs.start_epoch))
        backbone.load_state_dict(check_point['backbone_state_dict'])
        gnn.load_state_dict(check_point['gnn_state_dict'])
        mem.load_state_dict(check_point['mem_state_dict'])
        if for_exp:
            scheduler.load_state_dict(check_point['scheduler'])
            optimizer.load_state_dict(check_point['optimizer'])
        print('Loading Parameters from %d_model.pkl' % configs.start_epoch)

    # Train and validation
    best_acc = 0.0
    best_loss = np.inf
    wait = 0

    writer = SummaryWriter(
        os.path.join(save_path, "logs",
                     "%s" % time.strftime('%Y-%m-%d-%H-%M')))

    for ep in range(configs.start_epoch, configs.epochs):
        margin = 2

        thresh_train = [
            1 * (1 / (1 + exp(-ep / 100 + margin))),
            1 - (1 / (1 + exp(-ep / 100 + margin))),
            1 * (1 / (1 + exp(-ep / 100 + margin))),
            1 - (1 / (1 + exp(-ep / 100 + margin))),
        ]
        print("epoch:", ep, "thresh_train:", thresh_train)
        loss_print = defaultdict(list)

        train_loss_item = 0
        train_acc_item = 0
        train_loss_k = 0
        train_loss_v = 0
        train_loss_c = 0

        train_pbar = tqdm(train_db, total=total_train)
        for step, train_data in enumerate(train_pbar):
            train_pbar.set_description(
                'train_epoch:{}, total_loss:{:.5f}, acc:{:.5f}, loss_k:{:.5f}, loss_v:{:.5f}, loss_c:{:.5f}'
                .format(ep, train_loss_item, train_acc_item, train_loss_k,
                        train_loss_v, train_loss_c))

            # start to train
            backbone.train()
            gnn.train()
            mem.train()
            support_x, support_y, query_x, query_y = decompose_to_input(
                train_data)
            # propagation
            embedding, global_embedding = backbone(
                torch.cat([support_x, query_x], 0) / 255.0)
            support_embedding, query_embedding, support_y, query_y, sup_emb_glo, que_emb_glo = \
                decompose_embedding_from_backbone(embedding, [support_y, query_y], global_embedding)
            embedding, global_embedding, loss_k, loss_v, loss_s = mem(
                [support_embedding, query_embedding],
                [sup_emb_glo, que_emb_glo], thresh_train)
            loss_cls, acc = gnn(embedding, global_embedding,
                                [support_y, query_y])
            loss = alpha * loss_v + beta * loss_k + gama * loss_cls + loss_s
            # for visual images and labels
            '''
            imgs = torch.cat([support_x, query_x], 0)
            labels = torch.cat([support_y, query_y], 0)
            labels = torch.argmax(labels, -1)
            import matplotlib.pyplot as plt
            import matplotlib.gridspec as gridspec
            from collections import defaultdict
            rows = configs.n_way
            batch_size = imgs.size(0)
            cols = batch_size // rows
            gs = gridspec.GridSpec(rows, cols)
            fig = plt.figure(figsize=(84 * cols, 84 * rows), dpi=2)
            plt.rc('font', size=8)
            nums = defaultdict(int)
            for j in range(batch_size):
                label = int(labels[j] + 0)
                plt.subplot(gs[label*cols+nums[label]])
                nums[label] += 1
                plt.axis('off')
                img = imgs[j].type(torch.uint8).permute(1, 2, 0).cpu().numpy()
                plt.imshow(img)
            print(repr(labels))
            plt.savefig('test.jpg')
            '''

            train_loss_item = loss.item()
            train_acc_item = acc.item()
            train_loss_c = loss_cls.item() * gama
            train_loss_k = loss_k.item() * alpha
            train_loss_v = loss_v.item() * beta
            loss_print["train_loss"].append(train_loss_item)
            loss_print["train_loss_c"].append(train_loss_c)
            loss_print["train_loss_k"].append(train_loss_k)
            loss_print["train_loss_v"].append(train_loss_v)
            loss_print["train_acc"].append(train_acc_item)

            optimizer.zero_grad()
            loss.backward()
            # torch.nn.utils.clip_grad_norm(model.parameters(), 4.0)
            optimizer.step()

        # for valid
        valid_loss_item = 0
        valid_acc_item = 0
        valid_loss_k = 0
        valid_loss_v = 0
        valid_loss_c = 0
        valid_pbar = tqdm(val_db, total=total_valid)
        for step, train_data in enumerate(valid_pbar):
            valid_pbar.set_description(
                'valid_epoch:{}, total_loss:{:.5f}, acc:{:.5f}, loss_k:{:.5f}, loss_v:{:.5f}, loss_c:{:.5f}'
                .format(ep, valid_loss_item, valid_acc_item, valid_loss_k,
                        valid_loss_v, valid_loss_c))
            # start to valid
            backbone.eval()
            gnn.eval()
            mem.eval()
            support_x, support_y, query_x, query_y = decompose_to_input(
                train_data)
            # propagation
            with torch.no_grad():
                embedding, global_embedding = backbone(
                    torch.cat([support_x, query_x], 0) / 255.0)
                support_embedding, query_embedding, support_y, query_y, sup_emb_glo, que_emb_glo = \
                    decompose_embedding_from_backbone(embedding, [support_y, query_y], global_embedding)
                embedding, global_embedding, loss_k, loss_v, loss_s = mem(
                    [support_embedding, query_embedding],
                    [sup_emb_glo, que_emb_glo], thresh_train)
                loss_cls, acc = gnn(embedding, global_embedding,
                                    [support_y, query_y])
                loss = alpha * loss_v + beta * loss_k + gama * loss_cls + loss_s

            valid_loss_item = loss.item()
            valid_acc_item = acc.item()
            valid_loss_c = loss_cls.item() * gama
            valid_loss_k = loss_k.item() * alpha
            valid_loss_v = loss_v.item() * beta
            loss_print["valid_loss"].append(train_loss_item)
            loss_print["valid_loss_c"].append(valid_loss_c)
            loss_print["valid_loss_k"].append(valid_loss_k)
            loss_print["valid_loss_v"].append(valid_loss_v)
            loss_print["valid_acc"].append(valid_acc_item)

        # Only ReduceLROnPlateau accepts a metric; the other schedulers step per epoch.
        if configs.lr_policy == 'plateau':
            scheduler.step(np.mean(loss_print["valid_loss"]))
        else:
            scheduler.step()
        print('epoch:{}, lr:{:.6f}'.format(ep,
                                           optimizer.param_groups[0]['lr']))
        print([
            "{}:{:.6f}".format(key, np.mean(loss_print[key]))
            for key in loss_print.keys()
        ])

        # tensorboard
        # writer.add_graph(net, (inputs,))
        for key in loss_print.keys():
            writer.add_scalar(key, np.mean(loss_print[key]), ep)
        writer.add_scalar('Loss/train', np.mean(loss_print["train_loss"]), ep)
        writer.add_scalar('Loss/val', np.mean(loss_print["valid_loss"]), ep)
        writer.add_scalar('Accuracy/train', np.mean(loss_print["train_acc"]),
                          ep)
        writer.add_scalar('Accuracy/val', np.mean(loss_print["valid_acc"]), ep)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], ep)

        # Model Save and Stop Criterion
        cond1 = (np.mean(loss_print["valid_acc"]) > best_acc)
        cond2 = (np.mean(loss_print["valid_loss"]) < best_loss)

        if cond1 or cond2:
            best_acc = np.mean(loss_print["valid_acc"])
            best_loss = np.mean(loss_print["valid_loss"])
            print('best val loss:{:.5f}, acc:{:.5f}'.format(
                best_loss, best_acc))

            # save model
            torch.save(
                save_state(for_exp, backbone, gnn, mem, optimizer, scheduler),
                os.path.join(save_path, "model", '%d_model.pkl' % ep))
            wait = 0

        else:
            wait += 1
            if (ep + 1) % 100 == 0:
                torch.save(
                    save_state(for_exp, backbone, gnn, mem, optimizer,
                               scheduler),
                    os.path.join(save_path, "model", '%d_model.pkl' % ep))

        if wait > configs.patience:
            break
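save_state() is not shown in this example; based on the keys read back in the resume block above, it presumably packs something like the dictionary below. This is a reconstruction under that assumption, not the author's code:

def save_state(for_exp, backbone, gnn, mem, optimizer, scheduler):
    state = {
        'backbone_state_dict': backbone.state_dict(),
        'gnn_state_dict': gnn.state_dict(),
        'mem_state_dict': mem.state_dict(),
    }
    if for_exp:
        # optimizer/scheduler state is only reloaded above when for_exp is True
        state['optimizer'] = optimizer.state_dict()
        state['scheduler'] = scheduler.state_dict()
    return state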
Example #14
class AuxModel:
    def __init__(self, config, logger, wandb):
        self.config = config
        self.logger = logger
        self.writer = SummaryWriter(config.log_dir)
        self.wandb = wandb
        cudnn.enabled = True

        # set up model
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = get_model(config)
        if len(config.gpus) > 1:
            self.model = nn.DataParallel(self.model)
        self.model = self.model.to(self.device)
        self.best_acc = 0
        self.best_AUC = 0
        self.class_loss_func = nn.CrossEntropyLoss()
        self.pixel_loss = nn.L1Loss()
        if config.mode == 'train':
            # set up optimizer, lr scheduler and loss functions
            lr = config.lr
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=lr,
                                              betas=(.5, .999))
            self.scheduler = MultiStepLR(self.optimizer,
                                         milestones=[50, 150],
                                         gamma=0.1)
            self.wandb.watch(self.model)
            self.start_iter = 0

            # resume
            if config.training_resume:
                self.load(config.model_dir + '/' + config.training_resume)

            cudnn.benchmark = True
        elif config.mode == 'val':
            self.load(os.path.join(config.testing_model))
        else:
            self.load(os.path.join(config.testing_model))

    def entropy_loss(self, x):
        return torch.sum(-F.softmax(x, 1) * F.log_softmax(x, 1), 1).mean()

    def train_epoch_main_task(self, src_loader, tar_loader, epoch, print_freq):
        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        main_loss = AverageMeter()
        top1 = AverageMeter()

        for it, src_batch in enumerate(src_loader['main_task']):
            t = time.time()
            self.optimizer.zero_grad()
            src = src_batch
            src = to_device(src, self.device)
            src_imgs, src_cls_lbls = src
            self.optimizer.zero_grad()
            src_main_logits = self.model(src_imgs, 'main_task')
            src_main_loss = self.class_loss_func(src_main_logits, src_cls_lbls)
            loss = src_main_loss * self.config.loss_weight['main_task']
            main_loss.update(loss.item(), src_imgs.size(0))
            precision1_train, precision2_train = accuracy(src_main_logits,
                                                          src_cls_lbls,
                                                          topk=(1, 2))
            top1.update(precision1_train[0], src_imgs.size(0))

            loss.backward()
            self.optimizer.step()

            losses.update(loss.item(), src_imgs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - t)

            self.start_iter += 1

            if self.start_iter % print_freq == 0:
                print_string = 'Epoch {:>2} | iter {:>4} | loss:{:.3f} | acc:{:.3f} | src_main: {:.3f} | {:4.2f} s/it'
                self.logger.info(
                    print_string.format(epoch, self.start_iter, losses.avg,
                                        top1.avg, main_loss.avg,
                                        batch_time.avg))
                self.writer.add_scalar('losses/all_loss', losses.avg,
                                       self.start_iter)
                self.writer.add_scalar('losses/src_main_loss', src_main_loss,
                                       self.start_iter)
        self.scheduler.step()
        self.wandb.log({"Train Loss": main_loss.avg})

        # del loss, src_class_loss, src_aux_loss, tar_aux_loss, tar_entropy_loss
        # del src_aux_logits, src_class_logits
        # del tar_aux_logits, tar_class_logits

    def train_epoch_all_tasks(self, src_loader, tar_loader, epoch, print_freq):
        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        main_loss = AverageMeter()
        top1 = AverageMeter()
        start_steps = epoch * len(tar_loader['main_task'])
        total_steps = self.config.num_epochs * len(tar_loader['main_task'])

        max_num_iter_src = max([
            len(src_loader[task_name]) for task_name in self.config.task_names
        ])
        for it in range(max_num_iter_src):
            t = time.time()

            # this is based on DANN paper
            p = float(it + start_steps) / total_steps
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            self.optimizer.zero_grad()

            src = next(iter(src_loader['main_task']))
            tar = next(iter(tar_loader['main_task']))
            src = to_device(src, self.device)
            tar = to_device(tar, self.device)
            src_imgs, src_cls_lbls = src
            tar_imgs, _ = tar

            src_main_logits = self.model(src_imgs, 'main_task')
            src_main_loss = self.class_loss_func(src_main_logits, src_cls_lbls)
            loss = src_main_loss * self.config.loss_weight['main_task']
            main_loss.update(loss.item(), src_imgs.size(0))
            tar_main_logits = self.model(tar_imgs, 'main_task')
            tar_main_loss = self.entropy_loss(tar_main_logits)
            loss += tar_main_loss
            tar_aux_loss = {}
            src_aux_loss = {}

            # TODO: separate dataloaders and iterate over tasks
            for task in self.config.task_names:
                if self.config.tasks[task]['type'] == 'classification_adapt':
                    r = torch.randperm(src_imgs.size()[0] + tar_imgs.size()[0])
                    src_tar_imgs = torch.cat((src_imgs, tar_imgs), dim=0)
                    src_tar_imgs = src_tar_imgs[r, :, :, :]
                    src_tar_img = src_tar_imgs[:src_imgs.size()[0], :, :, :]
                    src_tar_lbls = torch.cat((torch.zeros(
                        (src_imgs.size()[0])), torch.ones(
                            (tar_imgs.size()[0]))),
                                             dim=0)
                    src_tar_lbls = src_tar_lbls[r]
                    src_tar_lbls = src_tar_lbls[:src_imgs.size()[0]]
                    src_tar_lbls = src_tar_lbls.long().cuda()
                    src_tar_logits = self.model(src_tar_img,
                                                'domain_classifier', alpha)
                    tar_aux_loss['domain_classifier'] = self.class_loss_func(
                        src_tar_logits, src_tar_lbls)
                    loss += tar_aux_loss[
                        'domain_classifier'] * self.config.loss_weight[
                            'domain_classifier']
                if self.config.tasks[task]['type'] == 'classification_self':
                    src = next(iter(src_loader[task]))
                    tar = next(iter(tar_loader[task]))
                    src = to_device(src, self.device)
                    tar = to_device(tar, self.device)
                    src_aux_imgs, src_aux_lbls = src
                    tar_aux_imgs, tar_aux_lbls = tar
                    tar_aux_logits = self.model(tar_aux_imgs, task)
                    src_aux_logits = self.model(src_aux_imgs, task)
                    tar_aux_loss[task] = self.class_loss_func(
                        tar_aux_logits, tar_aux_lbls)
                    src_aux_loss[task] = self.class_loss_func(
                        src_aux_logits, src_aux_lbls)
                    loss += src_aux_loss[task] * self.config.loss_weight[
                        task]  # todo: magnification weight
                    loss += tar_aux_loss[task] * self.config.loss_weight[
                        task]  # todo: main task weight
                if self.config.tasks[task]['type'] == 'pixel_self':
                    src = next(iter(src_loader[task]))
                    tar = next(iter(tar_loader[task]))
                    src = to_device(src, self.device)
                    tar = to_device(tar, self.device)
                    src_aux_imgs, src_aux_lbls = src
                    tar_aux_imgs, tar_aux_lbls = tar
                    tar_aux_mag_logits = self.model(tar_aux_imgs, task)
                    src_aux_mag_logits = self.model(src_aux_imgs, task)
                    tar_aux_loss[task] = self.pixel_loss(
                        tar_aux_mag_logits, tar_aux_lbls)
                    src_aux_loss[task] = self.pixel_loss(
                        src_aux_mag_logits, src_aux_lbls)
                    loss += src_aux_loss[task] * self.config.loss_weight[
                        task]  # todo: magnification weight
                    loss += tar_aux_loss[task] * self.config.loss_weight[task]

            precision1_train, precision2_train = accuracy(src_main_logits,
                                                          src_cls_lbls,
                                                          topk=(1, 2))
            top1.update(precision1_train[0], src_imgs.size(0))
            loss.backward()
            self.optimizer.step()
            losses.update(loss.item(), src_imgs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - t)
            self.start_iter += 1
            if self.start_iter % print_freq == 0:
                printt = ''
                for task_name in self.config.aux_task_names:
                    if task_name == 'domain_classifier':
                        printt = printt + ' | tar_aux_' + task_name + ': {:.3f} |'
                    else:
                        printt = printt + 'src_aux_' + task_name + ': {:.3f} | tar_aux_' + task_name + ': {:.3f}'
                print_string = 'Epoch {:>2} | iter {:>4} | loss:{:.3f} |  acc: {:.3f} | src_main: {:.3f} |' + printt + '{:4.2f} s/it'
                src_aux_loss_all = [
                    loss.item() for loss in src_aux_loss.values()
                ]
                tar_aux_loss_all = [
                    loss.item() for loss in tar_aux_loss.values()
                ]
                self.logger.info(
                    print_string.format(epoch, self.start_iter, losses.avg,
                                        top1.avg, main_loss.avg,
                                        *src_aux_loss_all, *tar_aux_loss_all,
                                        batch_time.avg))
                self.writer.add_scalar('losses/all_loss', losses.avg,
                                       self.start_iter)
                self.writer.add_scalar('losses/src_main_loss', src_main_loss,
                                       self.start_iter)
                for task_name in self.config.aux_task_names:
                    if task_name == 'domain_classifier':
                        # self.writer.add_scalar('losses/src_aux_loss_'+task_name, src_aux_loss[task_name], i_iter)
                        self.writer.add_scalar(
                            'losses/tar_aux_loss_' + task_name,
                            tar_aux_loss[task_name], self.start_iter)
                    else:
                        self.writer.add_scalar(
                            'losses/src_aux_loss_' + task_name,
                            src_aux_loss[task_name], self.start_iter)
                        self.writer.add_scalar(
                            'losses/tar_aux_loss_' + task_name,
                            tar_aux_loss[task_name], self.start_iter)
            self.scheduler.step()
        self.wandb.log({"Train Loss": main_loss.avg})

        # del loss, src_class_loss, src_aux_loss, tar_aux_loss, tar_entropy_loss
        # del src_aux_logits, src_class_logits
        # del tar_aux_logits, tar_class_logits

    def train(self, src_loader, tar_loader, val_loader, test_loader):
        num_batches = len(src_loader['main_task'])
        print_freq = max(num_batches // self.config.training_num_print_epoch,
                         1)
        start_epoch = self.start_iter // num_batches
        num_epochs = self.config.num_epochs
        for epoch in range(start_epoch, num_epochs):
            if len(self.config.task_names) == 1:
                self.train_epoch_main_task(src_loader, tar_loader, epoch,
                                           print_freq)
            else:
                self.train_epoch_all_tasks(src_loader, tar_loader, epoch,
                                           print_freq)
            self.logger.info('learning rate: %f ' % get_lr(self.optimizer))
            # validation
            self.save(self.config.model_dir, 'last')

            if val_loader is not None:
                self.logger.info('validating...')
                class_acc, AUC = self.test(val_loader)
                # self.writer.add_scalar('val/aux_acc', class_acc, i_iter)
                self.writer.add_scalar('val/class_acc', class_acc,
                                       self.start_iter)
                if class_acc > self.best_acc:
                    self.best_acc = class_acc
                    self.save(self.config.best_model_dir, 'best_acc')
                if AUC > self.best_AUC:
                    self.best_AUC = AUC
                    self.save(self.config.best_model_dir, 'best_AUC')
                    # todo copy current model to best model
                self.logger.info('Best validation accuracy: {:.2f} %'.format(
                    self.best_acc))

            if test_loader is not None:
                self.logger.info('testing...')
                class_acc, _ = self.test(test_loader)
                # self.writer.add_scalar('test/aux_acc', class_acc, i_iter)
                self.writer.add_scalar('test/class_acc', class_acc,
                                       self.start_iter)
                # if class_acc > self.best_acc:
                #     self.best_acc = class_acc
                # todo copy current model to best model
                self.logger.info(
                    'Best testing accuracy: {:.2f} %'.format(class_acc))

        self.logger.info('Best validation accuracy: {:.2f} %'.format(
            self.best_acc))
        self.logger.info('Finished Training.')

    def save(self, path, ext):
        state = {
            "iter": self.start_iter + 1,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict(),
            "best_acc": self.best_acc,
        }
        save_path = os.path.join(path, f'model_{ext}.pth')
        self.logger.info('Saving model to %s' % save_path)
        torch.save(state, save_path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state'])
        self.logger.info('Loaded model from: ' + path)

        if self.config.mode == 'train':
            self.optimizer.load_state_dict(checkpoint['optimizer_state'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state'])
            self.start_iter = checkpoint['iter']
            self.best_acc = checkpoint['best_acc']
            self.logger.info('Start iter: %d ' % self.start_iter)

    def test(self, val_loader):
        val_loader_iterator = iter(val_loader)
        num_val_iters = len(val_loader)
        tt = tqdm(range(num_val_iters), total=num_val_iters, desc="Validating")
        loss = AverageMeter()
        kk = 1
        aux_correct = 0
        class_correct = 0
        total = 0
        if self.config.dataset == 'kather':
            soft_labels = np.zeros((1, 9))
        if self.config.dataset == 'oscc' or self.config.dataset == 'cam':
            soft_labels = np.zeros((1, 2))
        true_labels = []
        self.model.eval()
        with torch.no_grad():
            for cur_it in tt:
                data = next(val_loader_iterator)
                data = to_device(data, self.device)
                imgs, cls_lbls = data
                # Get the inputs
                logits = self.model(imgs, 'main_task')
                test_loss = self.class_loss_func(logits, cls_lbls)
                loss.update(test_loss.item(), imgs.size(0))
                if self.config.save_output:
                    smax = nn.Softmax(dim=1)
                    smax_out = smax(logits)
                    soft_labels = np.concatenate(
                        (soft_labels, smax_out.cpu().numpy()), axis=0)
                    true_labels = np.append(true_labels,
                                            cls_lbls.cpu().numpy())
                    pred_trh = smax_out.cpu().numpy()[:, 1]
                    pred_trh[pred_trh >= 0.5] = 1
                    pred_trh[pred_trh < 0.5] = 0
                    compare = cls_lbls.cpu().numpy() - pred_trh

                    kk += 1
                _, cls_pred = logits.max(dim=1)

                class_correct += torch.sum(cls_pred == cls_lbls)
                total += imgs.size(0)

            tt.close()
        self.wandb.log({"Test Loss": loss.avg})
        # if self.config.save_output == True:
        soft_labels = soft_labels[1:, :]
        if self.config.dataset == 'oscc' or self.config.dataset == 'cam':
            AUC = calculate_stat(soft_labels,
                                 true_labels,
                                 2,
                                 self.config.class_names,
                                 type='binary',
                                 thresh=0.5)
        if self.config.dataset == 'kather':
            AUC = calculate_stat(soft_labels,
                                 true_labels,
                                 9,
                                 self.config.class_names,
                                 type='multi',
                                 thresh=0.5)
        class_acc = 100 * float(class_correct) / total
        self.logger.info('class_acc: {:.2f} %'.format(class_acc))
        self.wandb.log({"Test acc": class_acc, "Test AUC": 100 * AUC})
        return class_acc, AUC
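
The trainer above calls self.entropy_loss on unlabeled target logits, but that function is not part of this snippet. A minimal sketch of a typical entropy-minimization loss is given below; the function name and the mean reduction are assumptions for illustration, not taken from the original code.

import torch
import torch.nn.functional as F


def entropy_loss(logits: torch.Tensor) -> torch.Tensor:
    """Mean Shannon entropy of the softmax distribution over a batch of logits."""
    probs = F.softmax(logits, dim=1)
    log_probs = F.log_softmax(logits, dim=1)
    # Minimizing this pushes the model toward confident (low-entropy)
    # predictions on the unlabeled target images.
    return -(probs * log_probs).sum(dim=1).mean()
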
Example #15
0
        'logs/' + opt.train_dataroot + '-' + str(64) + '-' + str(opt.generatorLR) + '-' + str(opt.discriminatorLR),
        flush_secs=5)

    # ESRGAN training
    optim_generator = optim.Adam(generator.parameters(), lr=opt.generatorLR, betas=(opt.b1, opt.b2))
    optim_discriminator = optim.Adam(discriminator.parameters(), lr=opt.discriminatorLR, betas=(opt.b1, opt.b2))
    scheduler_generator = MultiStepLR(optim_generator, milestones=[3, 7], gamma=0.5)
    scheduler_discriminator = MultiStepLR(optim_discriminator, milestones=[3, 7], gamma=0.5)

    print('ESRGAN training')
    saved_results = {'epoch': [], 'psnr': [], 'ssim': []}  # wiped when training from scratch
    if opt.generatorWeights != '' and opt.trian_from_scratch:
        checkpoint = torch.load(opt.generatorWeights)
        generator.load_state_dict(checkpoint['generator_model'])
        optim_generator.load_state_dict(checkpoint['generator_optimizer'])
        scheduler_generator.load_state_dict(checkpoint['scheduler_generator'])
        start_epoch = checkpoint['epoch']
        print('Load Generator epoch {} successfully!'.format(start_epoch))
    else:
        start_epoch = 0
        print('start training generator from scratch!')

    # load discriminator model
    if opt.discriminatorWeights != '' and opt.trian_from_scratch:
        checkpoint = torch.load(opt.discriminatorWeights)
        discriminator.load_state_dict(checkpoint['discriminator_model'])
        optim_discriminator.load_state_dict(checkpoint['discriminator_optimizer'])
        scheduler_discriminator.load_state_dict(checkpoint['scheduler_discriminator'])
        start_epoch = checkpoint['epoch']
        print('Load Discriminator epoch {} successfully!'.format(start_epoch))
    else:
Example #16
0
def main(args):
    args.cuda = torch.cuda.is_available()

    os.makedirs(args.save_path, exist_ok=True)

    ut.logger(args)
    tb_sw = SummaryWriter(log_dir=args.tensorboard_dir)
    wandb.init(project='ghkg', sync_tensorboard=True, dir=args.wandb_dir)

    e2id = ut.index('entities.dict', args)
    r2id = ut.index('relations.dict', args)
    args.nentity = len(e2id)
    args.nrelation = len(r2id)

    for k, v in sorted(vars(args).items()):
        logging.info(f'{k} = {v}')

    tr_q = ut.read(os.path.join(args.dataset, 'train.txt'), e2id, r2id)
    vd_q = ut.read(os.path.join(args.dataset, 'valid.txt'), e2id, r2id)
    ts_q = ut.read(os.path.join(args.dataset, 'test.txt'), e2id, r2id)
    logging.info(f'# Train = {len(tr_q)}')
    logging.info(f'# Valid = {len(vd_q)}')
    logging.info(f'# Test = {len(ts_q)}')

    al_q = tr_q + vd_q + ts_q

    tp_ix, tp_rix = ut.type_index(args) if args.negative_type_sampling or args.type_evaluation else (None, None)
    e_ix, u_ix = ut.users_index(args) if args.heuristic_evaluation else (None, None)

    mdl = nn.DataParallel(KGEModel(args))
    if args.cuda:
        mdl = mdl.cuda()

    logging.info('Model Parameter Configuration:')
    for name, param in mdl.named_parameters():
        logging.info(f'Parameter {name}: {param.size()}')

    ev_ix = ut.event_index(tr_q)

    if args.do_train:
        tr_dl_s = DataLoader(TrainDataset(tr_q, tp_ix, tp_rix, ev_ix, 's', args),
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=max(1, os.cpu_count() // 2))
        tr_dl_o = DataLoader(TrainDataset(tr_q, tp_ix, tp_rix, ev_ix, 'o', args),
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=max(1, os.cpu_count() // 2))
        tr_it = BidirectionalOneShotIterator(tr_dl_s, tr_dl_o)
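        # BidirectionalOneShotIterator presumably alternates between the two loaders,
        # yielding batches whose negatives corrupt the subject ('s') and the object
        # ('o') slot in turn.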

        lr = args.learning_rate
        wd = args.weight_decay
        opt = torch.optim.Adam(filter(lambda p: p.requires_grad, mdl.parameters()), lr=lr, weight_decay=wd)
        opt_sc = MultiStepLR(opt, milestones=list(map(int, args.learning_rate_steps.split(','))))

    if args.checkpoint != '':
        logging.info(f'Loading checkpoint {args.checkpoint} ...')
        chk = torch.load(os.path.join(args.checkpoint, 'checkpoint.chk'))
        init_stp = chk['step']
        mdl.load_state_dict(chk['mdl_state_dict'])
        if args.do_train:
            lr = chk['opt_state_dict']['param_groups'][0]['lr']
            opt.load_state_dict(chk['opt_state_dict'])
            opt_sc.load_state_dict(chk['opt_sc_state_dict'])
    else:
        logging.info('Randomly Initializing ...')
        init_stp = 1

    stp = init_stp

    logging.info('Start Training ...')
    logging.info(f'init_stp = {init_stp}')

    if args.do_train:
        logging.info(f'learning_rate = {lr}')

        logs = []
        bst_mtrs = {}
        for stp in range(init_stp, args.max_steps + 1):
            log = train_step(mdl, opt, opt_sc, tr_it, args)
            logs.append(log)

            if stp % args.log_steps == 0:
                mtrs = {}
                for mtr in logs[0].keys():
                    mtrs[mtr] = sum([log[mtr] for log in logs]) / len(logs)
                ut.log('Training average', stp, mtrs)
                logs.clear()
            ut.tensorboard_scalars(tb_sw, 'train', stp, log)

            if args.do_valid and stp % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset ...')
                mtrs = test_step(mdl, vd_q, al_q, ev_ix, tp_ix, tp_rix, e_ix, u_ix, args)
                if bst_mtrs.get(args.metric, None) is None or mtrs[args.metric] > bst_mtrs[args.metric]:
                    bst_mtrs = mtrs.copy()
                    var_ls = {'step': stp}
                    ut.save(mdl, opt, opt_sc, var_ls, args)
                ut.log('Valid', stp, mtrs)
                ut.tensorboard_scalars(tb_sw, 'valid', stp, mtrs)

        ut.tensorboard_hparam(tb_sw, bst_mtrs, args)

    if args.do_eval:
        logging.info('Evaluating on Training Dataset ...')
        mtrs = test_step(mdl, tr_q, al_q, ev_ix, tp_ix, tp_rix, e_ix, u_ix, args)
        ut.log('Test', stp, mtrs)
        ut.tensorboard_scalars(tb_sw, 'eval', stp, mtrs)

    if args.do_test:
        valid_approximation, args.valid_approximation = args.valid_approximation, 0
        test_log_steps, args.test_log_steps = args.test_log_steps, 100
        logging.info('Evaluating on Test Dataset ...')
        mdl.load_state_dict(torch.load(os.path.join(args.save_path, 'checkpoint.chk'))['mdl_state_dict'])
        mtrs = test_step(mdl, ts_q, al_q, ev_ix, tp_ix, tp_rix, e_ix, u_ix, args)
        ut.log('Test', stp, mtrs)
        ut.tensorboard_scalars(tb_sw, 'test', stp, mtrs)
        args.valid_approximation = valid_approximation
        args.test_log_steps = test_log_steps

    tb_sw.flush()
    tb_sw.close()
Example #17
0
def main(opt):
    if torch.cuda.is_available():
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        num_gpus = torch.distributed.get_world_size()
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
        num_gpus = 1

    train_params = {
        "batch_size": opt.batch_size * num_gpus,
        "shuffle": True,
        "drop_last": False,
        "num_workers": opt.num_workers,
        "collate_fn": collate_fn
    }

    test_params = {
        "batch_size": opt.batch_size * num_gpus,
        "shuffle": False,
        "drop_last": False,
        "num_workers": opt.num_workers,
        "collate_fn": collate_fn
    }

    if opt.model == "ssd":
        dboxes = generate_dboxes(model="ssd")
        model = SSD(backbone=ResNet(), num_classes=len(coco_classes))
    else:
        dboxes = generate_dboxes(model="ssdlite")
        model = SSDLite(backbone=MobileNetV2(), num_classes=len(coco_classes))
    train_set = CocoDataset(opt.data_path, 2017, "train",
                            SSDTransformer(dboxes, (300, 300), val=False))
    train_loader = DataLoader(train_set, **train_params)
    test_set = CocoDataset(opt.data_path, 2017, "val",
                           SSDTransformer(dboxes, (300, 300), val=True))
    test_loader = DataLoader(test_set, **test_params)

    encoder = Encoder(dboxes)

    opt.lr = opt.lr * num_gpus * (opt.batch_size / 32)
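    # The line above applies the linear scaling rule: the base learning rate is
    # multiplied by the number of GPUs and by batch_size / 32, so the effective LR
    # grows with the global batch size.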
    criterion = Loss(dboxes)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=opt.lr,
                                momentum=opt.momentum,
                                weight_decay=opt.weight_decay,
                                nesterov=True)
    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=opt.multistep,
                            gamma=0.1)

    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()

        if opt.amp:
            from apex import amp
            from apex.parallel import DistributedDataParallel as DDP
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        else:
            from torch.nn.parallel import DistributedDataParallel as DDP
        # It is recommended to use DistributedDataParallel, instead of DataParallel
        # to do multi-GPU training, even if there is only a single node.
        model = DDP(model)

    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)

    if not os.path.isdir(opt.save_folder):
        os.makedirs(opt.save_folder)
    checkpoint_path = os.path.join(opt.save_folder, "SSD.pth")

    writer = SummaryWriter(opt.log_path)

    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        first_epoch = checkpoint["epoch"] + 1
        model.module.load_state_dict(checkpoint["model_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        first_epoch = 0

    for epoch in range(first_epoch, opt.epochs):
        train(model, train_loader, epoch, writer, criterion, optimizer,
              scheduler, opt.amp)
        evaluate(model, test_loader, epoch, writer, encoder, opt.nms_threshold)

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.module.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict()
        }
        torch.save(checkpoint, checkpoint_path)
Example #18
0
    def fit(self, dataset, mode='fit', **kwargs):
        from sklearn.metrics import accuracy_score

        assert self.model is not None

        params = self.model.parameters()
        val_loader = None
        if 'refit' in mode:
            train_loader = DataLoader(dataset=dataset.train_dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=NUM_WORKERS)
            if mode == 'refit_test':
                val_loader = DataLoader(dataset=dataset.test_dataset,
                                        batch_size=self.batch_size,
                                        shuffle=False,
                                        num_workers=NUM_WORKERS)
        else:
            if not dataset.subset_sampler_used:
                train_loader = DataLoader(dataset=dataset.train_dataset,
                                          batch_size=self.batch_size,
                                          shuffle=True,
                                          num_workers=NUM_WORKERS)
                val_loader = DataLoader(dataset=dataset.val_dataset,
                                        batch_size=self.batch_size,
                                        shuffle=False,
                                        num_workers=NUM_WORKERS)
            else:
                train_loader = DataLoader(dataset=dataset.train_dataset,
                                          batch_size=self.batch_size,
                                          sampler=dataset.train_sampler,
                                          num_workers=NUM_WORKERS)
                val_loader = DataLoader(dataset=dataset.train_for_val_dataset,
                                        batch_size=self.batch_size,
                                        sampler=dataset.val_sampler,
                                        num_workers=NUM_WORKERS)

        if self.optimizer == 'SGD':
            optimizer = SGD(params=params,
                            lr=self.sgd_learning_rate,
                            momentum=self.sgd_momentum)
        elif self.optimizer == 'Adam':
            optimizer = Adam(params=params,
                             lr=self.adam_learning_rate,
                             betas=(self.beta1, 0.999))
        else:
            raise ValueError("Optimizer %s not supported!" % self.optimizer)

        scheduler = MultiStepLR(
            optimizer,
            milestones=[int(self.max_epoch * 0.5),
                        int(self.max_epoch * 0.75)],
            gamma=self.lr_decay)
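        # The milestones above decay the learning rate by lr_decay at 50% and 75%
        # of max_epoch.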
        loss_func = nn.CrossEntropyLoss()
        early_stop = EarlyStop(patience=5, mode='min')

        if self.load_path:
            checkpoint = torch.load(self.load_path)
            self.model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            self.cur_epoch_num = checkpoint['epoch_num']
            early_stop = checkpoint['early_stop']
            if early_stop.if_early_stop:
                print("Early stop!")
                self.optimizer_ = optimizer
                self.epoch_num = int(self.epoch_num) + int(self.cur_epoch_num)
                self.scheduler = scheduler
                self.early_stop = early_stop
                return self

        profile_iter = kwargs.get('profile_iter', None)
        profile_epoch = kwargs.get('profile_epoch', None)
        assert not (profile_iter and profile_epoch)

        if profile_epoch or profile_iter:  # Profile mode
            self.model.train()
            if profile_epoch:
                for epoch in range(int(profile_epoch)):
                    for i, data in enumerate(train_loader):
                        batch_x, batch_y = data[0], data[1]
                        masks = torch.Tensor(
                            np.array([[float(i != 0) for i in sample]
                                      for sample in batch_x]))
                        logits = self.model(batch_x.long().to(self.device),
                                            masks.to(self.device))
                        optimizer.zero_grad()
                        loss = loss_func(logits, batch_y.to(self.device))
                        loss.backward()
                        optimizer.step()
            else:
                num_iter = 0
                stop_flag = False
                for epoch in range(int(self.epoch_num)):
                    if stop_flag:
                        break
                    for i, data in enumerate(train_loader):
                        batch_x, batch_y = data[0], data[1]
                        masks = torch.Tensor(
                            np.array([[float(i != 0) for i in sample]
                                      for sample in batch_x]))
                        logits = self.model(batch_x.long().to(self.device),
                                            masks.to(self.device))
                        optimizer.zero_grad()
                        loss = loss_func(logits, batch_y.to(self.device))
                        loss.backward()
                        optimizer.step()
                        num_iter += 1
                        if num_iter > profile_iter:
                            stop_flag = True
                            break
            return self

        for epoch in range(int(self.cur_epoch_num),
                           int(self.cur_epoch_num) + int(self.epoch_num)):
            self.model.train()
            # print('Current learning rate: %.5f' % optimizer.state_dict()['param_groups'][0]['lr'])
            epoch_avg_loss = 0
            epoch_avg_acc = 0
            val_avg_loss = 0
            val_avg_acc = 0
            num_train_samples = 0
            num_val_samples = 0
            for i, data in enumerate(train_loader):
                batch_x, batch_y = data[0], data[1]
                num_train_samples += len(batch_x)
                masks = torch.Tensor(
                    np.array([[float(i != 0) for i in sample]
                              for sample in batch_x]))
                logits = self.model(batch_x.long().to(self.device),
                                    masks.to(self.device))
                loss = loss_func(logits, batch_y.to(self.device))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_avg_loss += loss.to('cpu').detach() * len(batch_x)
                prediction = np.argmax(logits.to('cpu').detach().numpy(),
                                       axis=-1)
                epoch_avg_acc += accuracy_score(
                    prediction,
                    batch_y.to('cpu').detach().numpy()) * len(batch_x)

            epoch_avg_loss /= num_train_samples
            epoch_avg_acc /= num_train_samples
            # TODO: logger
            print('Epoch %d: Train loss %.4f, train acc %.4f' %
                  (epoch, epoch_avg_loss, epoch_avg_acc))

            if val_loader is not None:
                self.model.eval()
                with torch.no_grad():
                    for i, data in enumerate(val_loader):
                        batch_x, batch_y = data[0], data[1]
                        masks = torch.Tensor(
                            np.array([[float(i != 0) for i in sample]
                                      for sample in batch_x]))
                        logits = self.model(batch_x.long().to(self.device),
                                            masks.to(self.device))
                        val_loss = loss_func(logits, batch_y.to(self.device))
                        num_val_samples += len(batch_x)
                        val_avg_loss += val_loss.to('cpu').detach() * len(
                            batch_x)

                        prediction = np.argmax(
                            logits.to('cpu').detach().numpy(), axis=-1)
                        val_avg_acc += accuracy_score(
                            prediction,
                            batch_y.to('cpu').detach().numpy()) * len(batch_x)

                    val_avg_loss /= num_val_samples
                    val_avg_acc /= num_val_samples
                    print('Epoch %d: Val loss %.4f, val acc %.4f' %
                          (epoch, val_avg_loss, val_avg_acc))

                    # Early stop
                    if 'refit' not in mode:
                        early_stop.update(val_avg_loss)
                        if early_stop.if_early_stop:
                            self.early_stop_flag = True
                            print("Early stop!")
                            break

        scheduler.step()

        self.optimizer_ = optimizer
        self.epoch_num = int(self.epoch_num) + int(self.cur_epoch_num)
        self.scheduler = scheduler

        return self
Example #19
0
def main(meta_dir: str,
         save_dir: str,
         save_prefix: str,
         pretrained_path: str = '',
         batch_size: int = 32,
         num_workers: int = 8,
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.5, 0.9),
         weight_decay: float = 0.0,
         pretrain_step: int = 200000,
         max_step: int = 1000000,
         save_interval: int = 10000,
         log_scala_interval: int = 20,
         log_heavy_interval: int = 1000,
         gamma: float = 0.5,
         seed: int = 1234):
    #
    # prepare training
    #
    # create model
    mb_generator = build_model('generator_mb').cuda()
    discriminator = build_model('discriminator_base').cuda()

    # Multi-gpu is not required.

    # create optimizers
    mb_opt = torch.optim.Adam(mb_generator.parameters(),
                              lr=lr,
                              betas=betas,
                              weight_decay=weight_decay)
    dis_opt = torch.optim.Adam(discriminator.parameters(),
                               lr=lr,
                               betas=betas,
                               weight_decay=weight_decay)

    # make scheduler
    mb_scheduler = MultiStepLR(mb_opt,
                               list(range(300000, 900000 + 1, 100000)),
                               gamma=gamma)
    dis_scheduler = MultiStepLR(dis_opt,
                                list(range(100000, 700000 + 1, 100000)),
                                gamma=gamma)

    # get datasets
    train_loader, valid_loader = get_datasets(meta_dir,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              crop_length=settings.SAMPLE_RATE,
                                              random_seed=seed)

    # repeat
    train_loader = repeat(train_loader)

    # build mel function
    mel_func, stft_funcs_for_loss = build_stft_functions()

    # build pqmf
    pqmf_func = PQMF().cuda()
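    # PQMF (pseudo-QMF filter bank): analysis() splits a waveform into subband
    # signals that the multi-band generator predicts, and synthesis() recombines
    # predicted subbands back into a full-band waveform.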

    # prepare logging
    writer, model_dir = prepare_logging(save_dir, save_prefix)

    # Training Saving Attributes
    best_loss = np.finfo(np.float32).max
    initial_step = 0

    # load model
    if pretrained_path:
        log(f'Pretrained path is given : {pretrained_path} . Loading...')
        chk = torch.load(pretrained_path)
        gen_chk, dis_chk = chk['generator'], chk['discriminator']
        gen_opt_chk, dis_opt_chk = chk['gen_opt'], chk['dis_opt']
        initial_step = int(chk['step'])
        l = chk['loss']

        mb_generator.load_state_dict(gen_chk)
        discriminator.load_state_dict(dis_chk)
        mb_opt.load_state_dict(gen_opt_chk)
        dis_opt.load_state_dict(dis_opt_chk)
        if 'dis_scheduler' in chk:
            dis_scheduler_chk = chk['dis_scheduler']
            gen_scheduler_chk = chk['gen_scheduler']
            mb_scheduler.load_state_dict(gen_scheduler_chk)
            dis_scheduler.load_state_dict(dis_scheduler_chk)

        mb_opt._step_count = initial_step
        mb_scheduler._step_count = initial_step
        dis_opt._step_count = initial_step - pretrain_step
        dis_scheduler._step_count = initial_step - pretrain_step

        mb_scheduler.step(initial_step)
        dis_scheduler.step(initial_step - pretrain_step)
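        # The private _step_count counters and the explicit scheduler.step(...) calls
        # above fast-forward the optimizers and schedulers to the resumed step; the
        # discriminator side is offset by pretrain_step because it only starts
        # training after generator pretraining.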
        best_loss = l

    #
    # Training !
    #
    # Pretraining generator
    for step in range(initial_step, pretrain_step):
        # data
        wav, _ = next(train_loader)
        wav = wav.cuda()

        # to mel
        mel = mel_func(wav)

        # pqmf
        target_subbands = pqmf_func.analysis(wav.unsqueeze(1))  # N, SUBBAND, T

        # forward
        pred_subbands = mb_generator(mel)
        pred_subbands, _ = match_dim(pred_subbands, target_subbands)

        # pqmf synthesis
        pred = pqmf_func.synthesis(pred_subbands)
        pred, wav = match_dim(pred, wav)

        # get multi-resolution stft loss   eq 9)
        loss, mb_loss, fb_loss = get_stft_loss(pred, wav, pred_subbands,
                                               target_subbands,
                                               stft_funcs_for_loss)

        # backward and update
        loss.backward()
        mb_opt.step()
        mb_scheduler.step()

        mb_opt.zero_grad()
        mb_generator.zero_grad()

        #
        # logging! save!
        #
        if step % log_scala_interval == 0 and step > 0:
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('train/pretrain_loss',
                              loss.item(),
                              global_step=step)
            writer.add_scalar('train/mb_loss',
                              mb_loss.item(),
                              global_step=step)
            writer.add_scalar('train/fb_loss',
                              fb_loss.item(),
                              global_step=step)

            if step % log_heavy_interval == 0:
                writer.add_audio('train/pred_audio',
                                 pred_audio,
                                 sample_rate=settings.SAMPLE_RATE,
                                 global_step=step)
                writer.add_audio('train/target_audio',
                                 target_audio,
                                 sample_rate=settings.SAMPLE_RATE,
                                 global_step=step)

            # console
            msg = f'train: step: {step} / loss: {loss.item()} / mb_loss: {mb_loss.item()} / fb_loss: {fb_loss.item()}'
            log(msg)

        if step % save_interval == 0 and step > 0:
            #
            # Validation Step !
            #
            valid_loss = 0.
            valid_mb_loss, valid_fb_loss = 0., 0.
            count = 0
            mb_generator.eval()

            for idx, (wav, _) in enumerate(valid_loader):
                # setup data
                wav = wav.cuda()
                mel = mel_func(wav)

                with torch.no_grad():
                    # pqmf
                    target_subbands = pqmf_func.analysis(
                        wav.unsqueeze(1))  # N, SUBBAND, T

                    # forward
                    pred_subbands = mb_generator(mel)
                    pred_subbands, _ = match_dim(pred_subbands,
                                                 target_subbands)

                    # pqmf synthesis
                    pred = pqmf_func.synthesis(pred_subbands)
                    pred, wav = match_dim(pred, wav)

                    # get stft loss
                    loss, mb_loss, fb_loss = get_stft_loss(
                        pred, wav, pred_subbands, target_subbands,
                        stft_funcs_for_loss)

                valid_loss += loss.item()
                valid_mb_loss += mb_loss.item()
                valid_fb_loss += fb_loss.item()
                count = idx

            valid_loss /= (count + 1)
            valid_mb_loss /= (count + 1)
            valid_fb_loss /= (count + 1)
            mb_generator.train()

            # log validation
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('valid/pretrain_loss',
                              valid_loss,
                              global_step=step)
            writer.add_scalar('valid/mb_loss', valid_mb_loss, global_step=step)
            writer.add_scalar('valid/fb_loss', valid_fb_loss, global_step=step)
            writer.add_audio('valid/pred_audio',
                             pred_audio,
                             sample_rate=settings.SAMPLE_RATE,
                             global_step=step)
            writer.add_audio('valid/target_audio',
                             target_audio,
                             sample_rate=settings.SAMPLE_RATE,
                             global_step=step)

            # console
            log(f'---- Valid loss: {valid_loss} / mb_loss: {valid_mb_loss} / fb_loss: {valid_fb_loss} ----'
                )

            #
            # save checkpoint
            #
            is_best = valid_loss < best_loss
            if is_best:
                best_loss = valid_loss
            save_checkpoint(mb_generator,
                            discriminator,
                            mb_opt,
                            dis_opt,
                            mb_scheduler,
                            dis_scheduler,
                            model_dir,
                            step,
                            valid_loss,
                            is_best=is_best)

    #
    # Train GAN
    #
    dis_block_layers = 6
    lambda_gen = 2.5
    best_loss = np.finfo(np.float32).max

    for step in range(max(pretrain_step, initial_step), max_step):

        # data
        wav, _ = next(train_loader)
        wav = wav.cuda()

        # to mel
        mel = mel_func(wav)

        # pqmf
        target_subbands = pqmf_func.analysis(wav.unsqueeze(1))  # N, SUBBAND, T

        #
        # Train Discriminator
        #

        # forward
        pred_subbands = mb_generator(mel)
        pred_subbands, _ = match_dim(pred_subbands, target_subbands)

        # pqmf synthesis
        pred = pqmf_func.synthesis(pred_subbands)
        pred, wav = match_dim(pred, wav)

        with torch.no_grad():
            pred_mel = mel_func(pred.squeeze(1).detach())
            mel_err = F.l1_loss(mel, pred_mel).item()

        # if terminate_step > step:
        d_fake_det = discriminator(pred.detach())
        d_real = discriminator(wav.unsqueeze(1))

        # calculate discriminator losses  eq 1)
        loss_D = 0

        # LSGAN discriminator objective (eq 1): push D(real) toward 1 and D(fake) toward 0.
        for idx in range(dis_block_layers - 1, len(d_real), dis_block_layers):
            loss_D += torch.mean((d_real[idx] - 1)**2)

        for idx in range(dis_block_layers - 1, len(d_fake_det),
                         dis_block_layers):
            loss_D += torch.mean(d_fake_det[idx]**2)

        # train
        discriminator.zero_grad()
        loss_D.backward()
        dis_opt.step()
        dis_scheduler.step()

        #
        # Train Generator
        #
        d_fake = discriminator(pred)

        # calc generator loss   eq 8)
        loss_G = 0
        for idx in range(dis_block_layers - 1, len(d_fake), dis_block_layers):
            loss_G += ((d_fake[idx] - 1)**2).mean()

        loss_G *= lambda_gen

        # get multi-resolution stft loss
        loss_G += get_stft_loss(pred, wav, pred_subbands, target_subbands,
                                stft_funcs_for_loss)[0]
        # loss_G += get_spec_losses(pred, wav, stft_funcs_for_loss)[0]

        mb_generator.zero_grad()
        loss_G.backward()
        mb_opt.step()
        mb_scheduler.step()

        #
        # logging! save!
        #
        if step % log_scala_interval == 0 and step > 0:
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            writer.add_scalar('train/loss_G', loss_G.item(), global_step=step)
            writer.add_scalar('train/loss_D', loss_D.item(), global_step=step)
            writer.add_scalar('train/mel_err', mel_err, global_step=step)
            if step % log_heavy_interval == 0:
                target_mel = imshow_to_buf(mel[0].detach().cpu().numpy())
                pred_mel = imshow_to_buf(
                    mel_func(pred[:1, 0])[0].detach().cpu().numpy())

                writer.add_image('train/target_mel',
                                 target_mel,
                                 global_step=step)
                writer.add_image('train/pred_mel', pred_mel, global_step=step)
                writer.add_audio('train/pred_audio',
                                 pred_audio,
                                 sample_rate=settings.SAMPLE_RATE,
                                 global_step=step)
                writer.add_audio('train/target_audio',
                                 target_audio,
                                 sample_rate=settings.SAMPLE_RATE,
                                 global_step=step)

            # console
            msg = f'train: step: {step} / loss_G: {loss_G.item()} / loss_D: {loss_D.item()} / ' \
                f' mel_err: {mel_err}'
            log(msg)

        if step % save_interval == 0 and step > 0:
            #
            # Validation Step !
            #
            valid_g_loss, valid_d_loss, valid_mel_loss = 0., 0., 0.
            count = 0
            mb_generator.eval()
            discriminator.eval()

            for idx, (wav, _) in enumerate(valid_loader):
                # setup data
                wav = wav.cuda()
                mel = mel_func(wav)

                with torch.no_grad():
                    # pqmf
                    target_subbands = pqmf_func.analysis(
                        wav.unsqueeze(1))  # N, SUBBAND, T

                    # Discriminator
                    pred_subbands = mb_generator(mel)
                    pred_subbands, _ = match_dim(pred_subbands,
                                                 target_subbands)

                    # pqmf synthesis
                    pred = pqmf_func.synthesis(pred_subbands)
                    pred, wav = match_dim(pred, wav)

                    # Mel Error
                    pred_mel = mel_func(pred.squeeze(1).detach())
                    mel_err = F.l1_loss(mel, pred_mel).item()

                    #
                    # discriminator part
                    #
                    d_fake_det = discriminator(pred.detach())
                    d_real = discriminator(wav.unsqueeze(1))

                    loss_D = 0

                    for layer_idx in range(dis_block_layers - 1, len(d_real),
                                           dis_block_layers):
                        loss_D += torch.mean((d_real[layer_idx] - 1)**2)

                    for layer_idx in range(dis_block_layers - 1,
                                           len(d_fake_det), dis_block_layers):
                        loss_D += torch.mean(d_fake_det[layer_idx]**2)

                    #
                    # generator part
                    #
                    d_fake = discriminator(pred)

                    # calc generator loss
                    loss_G = 0
                    for layer_idx in range(dis_block_layers - 1, len(d_fake),
                                           dis_block_layers):
                        loss_G += ((d_fake[layer_idx] - 1)**2).mean()

                    loss_G *= lambda_gen

                    # get stft loss
                    stft_loss = get_stft_loss(pred, wav, pred_subbands,
                                              target_subbands,
                                              stft_funcs_for_loss)[0]
                    loss_G += stft_loss

                valid_d_loss += loss_D.item()
                valid_g_loss += loss_G.item()
                valid_mel_loss += mel_err
                count = idx

            valid_d_loss /= (count + 1)
            valid_g_loss /= (count + 1)
            valid_mel_loss /= (count + 1)

            mb_generator.train()
            discriminator.train()

            # log validation
            # log writer
            pred_audio = pred[0, 0]
            target_audio = wav[0]
            target_mel = imshow_to_buf(mel[0].detach().cpu().numpy())
            pred_mel = imshow_to_buf(
                mel_func(pred[:1, 0])[0].detach().cpu().numpy())

            writer.add_image('valid/target_mel', target_mel, global_step=step)
            writer.add_image('valid/pred_mel', pred_mel, global_step=step)
            writer.add_scalar('valid/loss_G', valid_g_loss, global_step=step)
            writer.add_scalar('valid/loss_D', valid_d_loss, global_step=step)
            writer.add_scalar('valid/mel_err',
                              valid_mel_loss,
                              global_step=step)
            writer.add_audio('valid/pred_audio',
                             pred_audio,
                             sample_rate=settings.SAMPLE_RATE,
                             global_step=step)
            writer.add_audio('valid/target_audio',
                             target_audio,
                             sample_rate=settings.SAMPLE_RATE,
                             global_step=step)

            # console
            log(f'---- loss_G: {valid_g_loss} / loss_D: {valid_d_loss} / mel loss : {valid_mel_loss} ----'
                )

            #
            # save checkpoint
            #
            is_best = valid_g_loss < best_loss
            if is_best:
                best_loss = valid_g_loss
            save_checkpoint(mb_generator,
                            discriminator,
                            mb_opt,
                            dis_opt,
                            mb_scheduler,
                            dis_scheduler,
                            model_dir,
                            step,
                            valid_g_loss,
                            is_best=is_best)

    log('----- Finish ! -----')
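
get_stft_loss above combines multi-band and full-band spectral terms, but it is not defined in this snippet. Below is a minimal single-resolution sketch of the spectral-convergence plus log-magnitude STFT loss commonly used in such vocoders; the function names, FFT parameters, and the assumption of (batch, samples) waveform inputs are illustrative, not taken from the original code.

import torch
import torch.nn.functional as F


def stft_magnitude(x: torch.Tensor, n_fft: int = 1024, hop: int = 256,
                   win: int = 1024) -> torch.Tensor:
    # (batch, samples) waveform -> (batch, freq, frames) magnitude spectrogram
    window = torch.hann_window(win, device=x.device)
    spec = torch.stft(x, n_fft, hop_length=hop, win_length=win,
                      window=window, return_complex=True)
    return spec.abs().clamp(min=1e-7)


def single_resolution_stft_loss(pred: torch.Tensor,
                                target: torch.Tensor) -> torch.Tensor:
    p, t = stft_magnitude(pred), stft_magnitude(target)
    sc_loss = torch.norm(t - p, p='fro') / torch.norm(t, p='fro')  # spectral convergence
    mag_loss = F.l1_loss(torch.log(p), torch.log(t))  # log-magnitude L1
    return sc_loss + mag_loss
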
Example #20
0
def main():

    # 1. argparser
    opts = parse(sys.argv[1:])
    print(opts)

    # 3. visdom
    vis = visdom.Visdom(port=opts.port)
    # 4. data set
    train_set = None
    test_set = None

    if opts.data_type == 'voc':
        train_set = VOC_Dataset(root=opts.data_root, split='train', resize=opts.resize)
        test_set = VOC_Dataset(root=opts.data_root, split='test', resize=opts.resize)
        opts.num_classes = 20

    elif opts.data_type == 'coco':
        train_set = COCO_Dataset(root=opts.data_root, set_name='train2017', split='train', resize=opts.resize)
        test_set = COCO_Dataset(root=opts.data_root, set_name='val2017', split='test', resize=opts.resize)
        opts.num_classes = 80

    # 5. data loader
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=opts.batch_size,
                                               collate_fn=train_set.collate_fn,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              collate_fn=test_set.collate_fn,
                                              shuffle=False,
                                              num_workers=2,
                                              pin_memory=True)

    # 6. network
    model = RetinaNet(num_classes=opts.num_classes).to(device)
    model = torch.nn.DataParallel(module=model, device_ids=device_ids)
    coder = RETINA_Coder(opts=opts)  # there is center_anchor in coder.

    # 7. loss
    criterion = Focal_Loss(coder=coder)

    # 8. optimizer
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=opts.lr,
                                momentum=opts.momentum,
                                weight_decay=opts.weight_decay)

    # 9. scheduler
    scheduler = MultiStepLR(optimizer=optimizer, milestones=[30, 45], gamma=0.1)

    # 10. resume
    if opts.start_epoch != 0:

        checkpoint = torch.load(os.path.join(opts.save_path, opts.save_file_name) + '.{}.pth.tar'
                                .format(opts.start_epoch - 1), map_location=device)        # load the checkpoint from one epoch earlier and resume training
        model.load_state_dict(checkpoint['model_state_dict'])                              # load model state dict
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])                      # load optim state dict
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])                      # load sched state dict
        print('\nLoaded checkpoint from epoch %d.\n' % (int(opts.start_epoch) - 1))

    else:

        print('\nNo check point to resume.. train from scratch.\n')

    # for statement
    for epoch in range(opts.start_epoch, opts.epoch):

        # 11. train
        train(epoch=epoch,
              vis=vis,
              train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              scheduler=scheduler,
              opts=opts)

        # 12. test
        test(epoch=epoch,
             vis=vis,
             test_loader=test_loader,
             model=model,
             criterion=criterion,
             coder=coder,
             opts=opts)

        scheduler.step()
Example #21
0
def main_train(args):
    # parse command line arguments
    if args.resume_training is not None:
        if not os.path.isfile(args.resume_training):
            print(f"{args.resume_training} 不是一个合法的文件!")
            return
        else:
            print(f"加载检查点:{args.resume_training}")
    cuda = args.cuda
    resume = args.resume_training
    batch_size = args.batch_size
    milestones = args.milestones
    lr = args.lr
    total_epoch = args.epochs
    resume_checkpoint_filename = args.resume_training
    best_model_name = args.best_model_name
    checkpoint_name = args.best_model_name
    data_path = args.data_path
    start_epoch = 1

    print("加载数据....")
    dataset = ISONetData(data_path=data_path)
    dataset_test = ISONetData(data_path=data_path, train=False)
    data_loader = DataLoader(dataset=dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=6,
                             pin_memory=True)
    data_loader_test = DataLoader(dataset=dataset_test,
                                  batch_size=batch_size,
                                  shuffle=False)
    print("成功加载数据...")
    print(f"训练集数量: {len(dataset)}")
    print(f"验证集数量: {len(dataset_test)}")

    model_path = Path("models")
    checkpoint_path = model_path.joinpath("checkpoint")

    if not model_path.exists():
        model_path.mkdir()
    if not checkpoint_path.exists():
        checkpoint_path.mkdir()

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
    else:
        print("cuda 无效!")
        cuda = False

    net = ISONet()
    criterion = nn.MSELoss(reduction="mean")
    optimizer = optim.Adam(net.parameters(), lr=lr)

    if cuda:
        net = net.to(device=device)
        criterion = criterion.to(device=device)

    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=milestones,
                            gamma=0.1)
    writer = SummaryWriter()

    # resume training
    if resume:
        print("恢复训练中...")
        checkpoint = torch.load(
            checkpoint_path.joinpath(resume_checkpoint_filename))
        net.load_state_dict(checkpoint["net"])
        optimizer.load_state_dict((checkpoint["optimizer"]))
        scheduler.load_state_dict(checkpoint["scheduler"])
        resume_epoch = checkpoint["epoch"]
        best_test_loss = checkpoint["best_test_loss"]

        start_epoch = resume_epoch + 1
        print(f"从第[{start_epoch}]轮开始训练...")
        print(f"上一次的损失为: [{best_test_loss}]...")
    else:
        # initialize weights
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    if not locals().get("best_test_loss"):
        best_test_loss = 0

    record = 0
    for epoch in range(start_epoch, total_epoch):
        print(f"开始第 [{epoch}] 轮训练...")
        net.train()
        writer.add_scalar("Train/Learning Rate",
                          scheduler.get_last_lr()[0], epoch)
        for i, (data, label) in enumerate(data_loader, 0):
            if i == 0:
                start_time = int(time.time())
            if cuda:
                data = data.to(device=device)
                label = label.to(device=device)
            label = label.unsqueeze(1)

            optimizer.zero_grad()

            output = net(data)

            loss = criterion(output, label)

            loss.backward()

            optimizer.step()
            if i % 500 == 499:
                end_time = int(time.time())
                use_time = end_time - start_time

                print(
                    f">>> epoch[{epoch}] loss[{loss:.4f}]  {i * batch_size}/{len(dataset)} lr{scheduler.get_last_lr()} ",
                    end="")
                left_time = ((len(dataset) - i * batch_size) / 500 /
                             batch_size) * use_time
                print(
                    f"elapsed time: [{use_time:.2f}]s, estimated time left: [{left_time:.2f}]s"
                )
                start_time = end_time
            # log to tensorboard
            if i % 128 == 127:
                writer.add_scalar("Train/loss", loss, record)
                record += 1

        # validate
        print("测试模型...")
        net.eval()

        test_loss = 0
        with torch.no_grad():
            loss_t = nn.MSELoss(reduction="mean")
            if cuda:
                loss_t = loss_t.to(device)
            for data, label in data_loader_test:
                if cuda:
                    data = data.to(device)
                    label = label.to(device)
                # expand dim
                label = label.unsqueeze_(1)
                predict = net(data)
                # sum up batch loss
                test_loss += loss_t(predict, label).item()

        test_loss /= len(dataset_test)
        test_loss *= batch_size
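        # test_loss accumulated one batch-averaged MSE per batch; dividing by the
        # dataset size and multiplying by batch_size turns it back into (roughly)
        # the average per-batch loss reported below.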
        print(
            f'\nTest Data: Average batch[{batch_size}] loss: {test_loss:.4f}\n'
        )
        scheduler.step()

        writer.add_scalar("Test/Loss", test_loss, epoch)

        checkpoint = {
            "net": net.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "scheduler": scheduler.state_dict(),
            "best_test_loss": best_test_loss
        }

        if best_test_loss == 0:
            print("保存模型中...")
            torch.save(net.state_dict(), model_path.joinpath(best_model_name))
            best_test_loss = test_loss
        else:
            # save the better model
            if test_loss < best_test_loss:
                print("Found a better model, saving...")
                torch.save(net.state_dict(),
                           model_path.joinpath(best_model_name))
                best_test_loss = test_loss
        # save checkpoint
        if epoch % args.save_every_epochs == 0:
            c_time = time2str()
            torch.save(
                checkpoint,
                checkpoint_path.joinpath(
                    f"{checkpoint_name}_{epoch}_{c_time}.cpth"))
            print(f"保存检查点: [{checkpoint_name}_{epoch}_{c_time}.cpth]...\n")
Example #22
0
class Processor():
    """Processor for Skeleton-based Action Recgnition"""
    def __init__(self, arg):
        self.arg = arg
        self.save_arg()
        if arg.phase == 'train':
            # Added control through the command line
            arg.train_feeder_args[
                'debug'] = arg.train_feeder_args['debug'] or self.arg.debug
            logdir = os.path.join(arg.work_dir, 'trainlogs')
            if not arg.train_feeder_args['debug']:
                # logdir = arg.model_saved_name
                if os.path.isdir(logdir):
                    print(f'log_dir {logdir} already exists')
                    if arg.assume_yes:
                        answer = 'y'
                    else:
                        answer = input('delete it? [y]/n:')
                    if answer.lower() in ('y', ''):
                        shutil.rmtree(logdir)
                        print('Dir removed:', logdir)
                    else:
                        print('Dir not removed:', logdir)

                self.train_writer = SummaryWriter(
                    os.path.join(logdir, 'train'), 'train')
                self.val_writer = SummaryWriter(os.path.join(logdir, 'val'),
                                                'val')
            else:
                self.train_writer = SummaryWriter(
                    os.path.join(logdir, 'debug'), 'debug')

        self.load_model()
        self.load_param_groups()
        self.load_optimizer()
        self.load_lr_scheduler()
        self.load_data()

        self.global_step = 0
        self.lr = self.arg.base_lr
        self.best_acc = 0
        self.best_acc_epoch = 0

        if self.arg.half:
            self.print_log('*************************************')
            self.print_log('*** Using Half Precision Training ***')
            self.print_log('*************************************')
            self.model, self.optimizer = apex.amp.initialize(
                self.model,
                self.optimizer,
                opt_level=f'O{self.arg.amp_opt_level}')
            if self.arg.amp_opt_level != 1:
                self.print_log(
                    '[WARN] nn.DataParallel is not yet supported by amp_opt_level != "O1"'
                )

        if type(self.arg.device) is list:
            if len(self.arg.device) > 1:
                self.print_log(
                    f'{len(self.arg.device)} GPUs available, using DataParallel'
                )
                self.model = nn.DataParallel(self.model,
                                             device_ids=self.arg.device,
                                             output_device=self.output_device)

    def load_model(self):
        output_device = self.arg.device[0] if type(
            self.arg.device) is list else self.arg.device
        self.output_device = output_device
        Model = import_class(self.arg.model)

        # Copy model file and main
        shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
        shutil.copy2(os.path.join('.', __file__), self.arg.work_dir)

        self.model = Model(**self.arg.model_args).cuda(output_device)
        self.loss = nn.CrossEntropyLoss().cuda(output_device)
        self.print_log(
            f'Model total number of params: {count_params(self.model)}')

        if self.arg.weights:
            try:
                self.global_step = int(self.arg.weights[:-3].split('-')[-1])
            except (ValueError, IndexError):
                print('Cannot parse global_step from model weights filename')
                self.global_step = 0

            self.print_log(f'Loading weights from {self.arg.weights}')
            if '.pkl' in self.arg.weights:
                with open(self.arg.weights, 'rb') as f:
                    weights = pickle.load(f)
            else:
                weights = torch.load(self.arg.weights)

            weights = OrderedDict(
                [[k.split('module.')[-1],
                  v.cuda(output_device)] for k, v in weights.items()])

            for w in self.arg.ignore_weights:
                if weights.pop(w, None) is not None:
                    self.print_log(f'Successfully removed weights: {w}')
                else:
                    self.print_log(f'Could not remove weights: {w}')

            try:
                self.model.load_state_dict(weights)
            except RuntimeError:
                state = self.model.state_dict()
                diff = list(set(state.keys()).difference(set(weights.keys())))
                self.print_log('Can not find these weights:')
                for d in diff:
                    self.print_log('  ' + d)
                state.update(weights)
                self.model.load_state_dict(state)

    def load_param_groups(self):
        """
        Template function for setting different learning behaviour
        (e.g. LR, weight decay) of different groups of parameters
        """
        self.param_groups = defaultdict(list)

        for name, params in self.model.named_parameters():
            self.param_groups['other'].append(params)

        self.optim_param_groups = {
            'other': {
                'params': self.param_groups['other']
            }
        }

    def load_optimizer(self):
        params = list(self.optim_param_groups.values())
        if self.arg.optimizer == 'SGD':
            self.optimizer = optim.SGD(params,
                                       lr=self.arg.base_lr,
                                       momentum=0.9,
                                       nesterov=self.arg.nesterov,
                                       weight_decay=self.arg.weight_decay)
        elif self.arg.optimizer == 'Adam':
            self.optimizer = optim.Adam(params,
                                        lr=self.arg.base_lr,
                                        weight_decay=self.arg.weight_decay)
        else:
            raise ValueError('Unsupported optimizer: {}'.format(
                self.arg.optimizer))

        # Load optimizer states if any
        if self.arg.checkpoint is not None:
            self.print_log(
                f'Loading optimizer states from: {self.arg.checkpoint}')
            self.optimizer.load_state_dict(
                torch.load(self.arg.checkpoint)['optimizer_states'])
            current_lr = self.optimizer.param_groups[0]['lr']
            self.print_log(f'Starting LR: {current_lr}')
            self.print_log(
                f'Starting WD1: {self.optimizer.param_groups[0]["weight_decay"]}'
            )
            if len(self.optimizer.param_groups) >= 2:
                self.print_log(
                    f'Starting WD2: {self.optimizer.param_groups[1]["weight_decay"]}'
                )

    def load_lr_scheduler(self):
        self.lr_scheduler = MultiStepLR(self.optimizer,
                                        milestones=self.arg.step,
                                        gamma=0.1)
        if self.arg.checkpoint is not None:
            scheduler_states = torch.load(
                self.arg.checkpoint)['lr_scheduler_states']
            self.print_log(
                f'Loading LR scheduler states from: {self.arg.checkpoint}')
            self.lr_scheduler.load_state_dict(scheduler_states)
            self.print_log(
                f'Starting last epoch: {scheduler_states["last_epoch"]}')
            self.print_log(
                f'Loaded milestones: {scheduler_states["milestones"]}')

    def load_data(self):
        Feeder = import_class(self.arg.feeder)
        self.data_loader = dict()

        def worker_seed_fn(worker_id):
            # give workers different seeds
            return init_seed(self.arg.seed + worker_id + 1)

        if self.arg.phase == 'train':
            self.data_loader['train'] = torch.utils.data.DataLoader(
                dataset=Feeder(**self.arg.train_feeder_args),
                batch_size=self.arg.batch_size,
                shuffle=True,
                num_workers=self.arg.num_worker,
                drop_last=True,
                worker_init_fn=worker_seed_fn)

        self.data_loader['test'] = torch.utils.data.DataLoader(
            dataset=Feeder(**self.arg.test_feeder_args),
            batch_size=self.arg.test_batch_size,
            shuffle=False,
            num_workers=self.arg.num_worker,
            drop_last=False,
            worker_init_fn=worker_seed_fn)

    def save_arg(self):
        # save arg
        arg_dict = vars(self.arg)
        if not os.path.exists(self.arg.work_dir):
            os.makedirs(self.arg.work_dir)
        with open(os.path.join(self.arg.work_dir, 'config.yaml'), 'w') as f:
            yaml.dump(arg_dict, f)

    def print_time(self):
        localtime = time.asctime(time.localtime(time.time()))
        self.print_log(f'Local current time: {localtime}')

    def print_log(self, s, print_time=True):
        if print_time:
            localtime = time.asctime(time.localtime(time.time()))
            s = f'[ {localtime} ] {s}'
        print(s)
        if self.arg.print_log:
            with open(os.path.join(self.arg.work_dir, 'log.txt'), 'a') as f:
                print(s, file=f)

    def record_time(self):
        self.cur_time = time.time()
        return self.cur_time

    def split_time(self):
        split_time = time.time() - self.cur_time
        self.record_time()
        return split_time

    def save_states(self, epoch, states, out_folder, out_name):
        out_folder_path = os.path.join(self.arg.work_dir, out_folder)
        out_path = os.path.join(out_folder_path, out_name)
        os.makedirs(out_folder_path, exist_ok=True)
        torch.save(states, out_path)

    def save_checkpoint(self, epoch, out_folder='checkpoints'):
        state_dict = {
            'epoch': epoch,
            'optimizer_states': self.optimizer.state_dict(),
            'lr_scheduler_states': self.lr_scheduler.state_dict(),
        }

        checkpoint_name = f'checkpoint-{epoch}-fwbz{self.arg.forward_batch_size}-{int(self.global_step)}.pt'
        self.save_states(epoch, state_dict, out_folder, checkpoint_name)

    def save_weights(self, epoch, out_folder='weights'):
        state_dict = self.model.state_dict()
        weights = OrderedDict([[k.split('module.')[-1],
                                v.cpu()] for k, v in state_dict.items()])

        weights_name = f'weights-{epoch}-{int(self.global_step)}.pt'
        self.save_states(epoch, weights, out_folder, weights_name)

    def train(self, epoch, save_model=False):
        self.model.train()
        loader = self.data_loader['train']
        loss_values = []
        self.train_writer.add_scalar('epoch', epoch + 1, self.global_step)
        self.record_time()
        timer = dict(dataloader=0.001, model=0.001, statistics=0.001)
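        # the small non-zero initial values above keep the percentage summary
        # at the end of the epoch from dividing by zero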

        current_lr = self.optimizer.param_groups[0]['lr']
        self.print_log(f'Training epoch: {epoch + 1}, LR: {current_lr:.4f}')

        process = tqdm(loader, dynamic_ncols=True)
        for batch_idx, (data, label, index) in enumerate(process):
            self.global_step += 1
            # get data
            with torch.no_grad():
                data = data.float().cuda(self.output_device)
                label = label.long().cuda(self.output_device)
            timer['dataloader'] += self.split_time()

            # backward
            self.optimizer.zero_grad()

            ############## Gradient Accumulation for Smaller Batches ##############
            real_batch_size = self.arg.forward_batch_size
            splits = len(data) // real_batch_size
            assert len(data) % real_batch_size == 0, \
                'Real batch size should be a factor of arg.batch_size!'

            for i in range(splits):
                left = i * real_batch_size
                right = left + real_batch_size
                batch_data, batch_label = data[left:right], label[left:right]

                # forward
                output = self.model(batch_data)
                if isinstance(output, tuple):
                    output, l1 = output
                    l1 = l1.mean()
                else:
                    l1 = 0

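                # Scale by 1/splits so the gradients accumulated over the
                # sub-batches match the average over the full logical batch.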
                loss = self.loss(output, batch_label) / splits

                if self.arg.half:
                    with apex.amp.scale_loss(loss,
                                             self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                loss_values.append(loss.item())
                timer['model'] += self.split_time()

                # Display loss
                process.set_description(
                    f'(BS {real_batch_size}) loss: {loss.item():.4f}')

                value, predict_label = torch.max(output, 1)
                acc = torch.mean((predict_label == batch_label).float())

                self.train_writer.add_scalar('acc', acc, self.global_step)
                self.train_writer.add_scalar('loss',
                                             loss.item() * splits,
                                             self.global_step)
                self.train_writer.add_scalar('loss_l1', l1, self.global_step)

            #####################################

            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 2)
            self.optimizer.step()

            # statistics
            self.lr = self.optimizer.param_groups[0]['lr']
            self.train_writer.add_scalar('lr', self.lr, self.global_step)
            timer['statistics'] += self.split_time()

            # Delete output/loss after each batch since it may introduce extra mem during scoping
            # https://discuss.pytorch.org/t/gpu-memory-consumption-increases-while-training/2770/3
            del output
            del loss

        # statistics of time consumption and loss
        proportion = {
            k: f'{int(round(v * 100 / sum(timer.values()))):02d}%'
            for k, v in timer.items()
        }

        mean_loss = np.mean(loss_values)
        num_splits = self.arg.batch_size // self.arg.forward_batch_size
        self.print_log(
            f'\tMean training loss: {mean_loss:.4f} (BS {self.arg.batch_size}: {mean_loss * num_splits:.4f}).'
        )
        self.print_log(
            '\tTime consumption: [Data]{dataloader}, [Network]{model}'.format(
                **proportion))

        # PyTorch > 1.2.0: update LR scheduler here with `.step()`
        # and make sure to save the `lr_scheduler.state_dict()` as part of checkpoint
        self.lr_scheduler.step()

        if save_model:
            # save training checkpoint & weights
            self.save_weights(epoch + 1)
            self.save_checkpoint(epoch + 1)

    def eval(self,
             epoch,
             save_score=False,
             loader_name=['test'],
             wrong_file=None,
             result_file=None):
        # Skip evaluation if too early
        if epoch + 1 < self.arg.eval_start:
            return

        if wrong_file is not None:
            f_w = open(wrong_file, 'w')
        if result_file is not None:
            f_r = open(result_file, 'w')
        with torch.no_grad():
            self.model = self.model.cuda(self.output_device)
            self.model.eval()
            self.print_log(f'Eval epoch: {epoch + 1}')
            for ln in loader_name:
                loss_values = []
                score_batches = []
                step = 0
                process = tqdm(self.data_loader[ln], dynamic_ncols=True)
                for batch_idx, (data, label, index) in enumerate(process):
                    data = data.float().cuda(self.output_device)
                    label = label.long().cuda(self.output_device)
                    output = self.model(data)
                    if isinstance(output, tuple):
                        output, l1 = output
                        l1 = l1.mean()
                    else:
                        l1 = 0
                    loss = self.loss(output, label)
                    score_batches.append(output.data.cpu().numpy())
                    loss_values.append(loss.item())

                    _, predict_label = torch.max(output.data, 1)
                    step += 1

                    if wrong_file is not None or result_file is not None:
                        predict = list(predict_label.cpu().numpy())
                        true = list(label.data.cpu().numpy())
                        for i, x in enumerate(predict):
                            if result_file is not None:
                                f_r.write(str(x) + ',' + str(true[i]) + '\n')
                            if x != true[i] and wrong_file is not None:
                                f_w.write(
                                    str(index[i]) + ',' + str(x) + ',' +
                                    str(true[i]) + '\n')

            score = np.concatenate(score_batches)
            loss = np.mean(loss_values)
            accuracy = self.data_loader[ln].dataset.top_k(score, 1)
            if accuracy > self.best_acc:
                self.best_acc = accuracy
                self.best_acc_epoch = epoch + 1

            print('Accuracy: ', accuracy, ' model: ', self.arg.work_dir)
            if self.arg.phase == 'train' and not self.arg.debug:
                self.val_writer.add_scalar('loss', loss, self.global_step)
                self.val_writer.add_scalar('loss_l1', l1, self.global_step)
                self.val_writer.add_scalar('acc', accuracy, self.global_step)

            score_dict = dict(
                zip(self.data_loader[ln].dataset.sample_name, score))
            self.print_log(
                f'\tMean {ln} loss of {len(self.data_loader[ln])} batches: {np.mean(loss_values)}.'
            )
            for k in self.arg.show_topk:
                self.print_log(
                    f'\tTop {k}: {100 * self.data_loader[ln].dataset.top_k(score, k):.2f}%'
                )

            if save_score:
                with open(
                        '{}/epoch{}_{}_score.pkl'.format(
                            self.arg.work_dir, epoch + 1, ln), 'wb') as f:
                    pickle.dump(score_dict, f)

        # Empty cache after evaluation
        torch.cuda.empty_cache()

    def start(self):
        if self.arg.phase == 'train':
            self.print_log(f'Parameters:\n{pprint.pformat(vars(self.arg))}\n')
            self.print_log(
                f'Model total number of params: {count_params(self.model)}')
            self.global_step = self.arg.start_epoch * len(
                self.data_loader['train']) / self.arg.batch_size
            for epoch in range(self.arg.start_epoch, self.arg.num_epoch):
                save_model = ((epoch + 1) % self.arg.save_interval
                              == 0) or (epoch + 1 == self.arg.num_epoch)
                self.train(epoch, save_model=save_model)
                self.eval(epoch,
                          save_score=self.arg.save_score,
                          loader_name=['test'])

            num_params = sum(p.numel() for p in self.model.parameters()
                             if p.requires_grad)
            self.print_log(f'Best accuracy: {self.best_acc}')
            self.print_log(f'Epoch number: {self.best_acc_epoch}')
            self.print_log(f'Model name: {self.arg.work_dir}')
            self.print_log(f'Model total number of params: {num_params}')
            self.print_log(f'Weight decay: {self.arg.weight_decay}')
            self.print_log(f'Base LR: {self.arg.base_lr}')
            self.print_log(f'Batch Size: {self.arg.batch_size}')
            self.print_log(
                f'Forward Batch Size: {self.arg.forward_batch_size}')
            self.print_log(f'Test Batch Size: {self.arg.test_batch_size}')

        elif self.arg.phase == 'test':
            if not self.arg.test_feeder_args['debug']:
                wf = os.path.join(self.arg.work_dir, 'wrong-samples.txt')
                rf = os.path.join(self.arg.work_dir, 'right-samples.txt')
            else:
                wf = rf = None
            if self.arg.weights is None:
                raise ValueError('Please appoint --weights.')

            self.print_log(f'Model:   {self.arg.model}')
            self.print_log(f'Weights: {self.arg.weights}')

            self.eval(epoch=0,
                      save_score=self.arg.save_score,
                      loader_name=['test'],
                      wrong_file=wf,
                      result_file=rf)

            self.print_log('Done.\n')
Example #23
0
    def fit(self, dataset: DLDataset, mode='fit', **kwargs):
        assert self.model is not None

        if self.load_path:
            self.model.load_state_dict(torch.load(self.load_path))

        params = self.model.parameters()

        val_loader = None
        if 'refit' in mode:
            train_loader = DataLoader(
                dataset=dataset.train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=NUM_WORKERS,
                collate_fn=dataset.train_dataset.collate_fn)
            if mode == 'refit_test':
                val_loader = DataLoader(
                    dataset=dataset.test_dataset,
                    batch_size=self.batch_size,
                    shuffle=False,
                    num_workers=NUM_WORKERS,
                    collate_fn=dataset.test_dataset.collate_fn)
        else:
            train_loader = DataLoader(
                dataset=dataset.train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=NUM_WORKERS,
                collate_fn=dataset.train_dataset.collate_fn)
            val_loader = DataLoader(dataset=dataset.val_dataset,
                                    batch_size=self.batch_size,
                                    shuffle=False,
                                    num_workers=NUM_WORKERS,
                                    collate_fn=dataset.val_dataset.collate_fn)
            # else:
            #     train_loader = DataLoader(dataset=dataset.train_dataset, batch_size=self.batch_size,
            #                               sampler=dataset.train_sampler, num_workers=4,
            #                               collate_fn=dataset.train_dataset.collate_fn)
            #     val_loader = DataLoader(dataset=dataset.train_dataset, batch_size=self.batch_size,
            #                             sampler=dataset.val_sampler, num_workers=4,
            #                             collate_fn=dataset.train_dataset.collate_fn)

        if self.optimizer == 'SGD':
            optimizer = SGD(params=params,
                            lr=self.sgd_learning_rate,
                            momentum=self.sgd_momentum)
        elif self.optimizer == 'Adam':
            optimizer = Adam(params=params,
                             lr=self.adam_learning_rate,
                             betas=(self.beta1, 0.999))
        else:
            raise ValueError("Optimizer %s not supported!" % self.optimizer)

        scheduler = MultiStepLR(
            optimizer,
            milestones=[int(self.max_epoch * 0.5),
                        int(self.max_epoch * 0.75)],
            gamma=self.lr_decay)
        early_stop = EarlyStop(patience=5, mode='min')

        if self.load_path:
            checkpoint = torch.load(self.load_path)
            self.model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            self.cur_epoch_num = checkpoint['epoch_num']
            early_stop = checkpoint['early_stop']
            if early_stop.if_early_stop:
                print("Early stop!")
                self.optimizer_ = optimizer
                self.epoch_num = int(self.epoch_num) + int(self.cur_epoch_num)
                self.scheduler = scheduler
                self.early_stop = early_stop
                return self

        profile_iter = kwargs.get('profile_iter', None)
        profile_epoch = kwargs.get('profile_epoch', None)
        assert not (profile_iter and profile_epoch)

        if profile_epoch or profile_iter:  # Profile mode
            self.model.train()
            if profile_epoch:
                for epoch in range(int(profile_epoch)):
                    for i, (_, batch_x, batch_y) in enumerate(train_loader):
                        loss, outputs = self.model(
                            batch_x.float().to(self.device),
                            batch_y.float().to(self.device))
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
            else:
                num_iter = 0
                stop_flag = False
                for epoch in range(int(self.epoch_num)):
                    if stop_flag:
                        break
                    for i, (_, batch_x, batch_y) in enumerate(train_loader):
                        loss, outputs = self.model(
                            batch_x.float().to(self.device),
                            batch_y.float().to(self.device))
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        num_iter += 1
                        if num_iter > profile_iter:
                            stop_flag = True
                            break
            return self

        for epoch in range(int(self.cur_epoch_num),
                           int(self.cur_epoch_num) + int(self.epoch_num)):
            self.model.train()
            # print('Current learning rate: %.5f' % optimizer.state_dict()['param_groups'][0]['lr'])
            epoch_avg_loss = 0
            val_avg_loss = 0
            num_train_samples = 0
            num_val_samples = 0
            for i, (_, batch_x, batch_y) in enumerate(train_loader):
                loss, outputs = self.model(batch_x.float().to(self.device),
                                           batch_y.float().to(self.device))
                optimizer.zero_grad()
                epoch_avg_loss += loss.to('cpu').detach() * len(batch_x)
                num_train_samples += len(batch_x)
                loss.backward()
                optimizer.step()
            epoch_avg_loss /= num_train_samples
            print('Epoch %d: Train loss %.4f' % (epoch, epoch_avg_loss))
            scheduler.step()

            if val_loader is not None:
                self.model.eval()
                with torch.no_grad():
                    for i, (_, batch_x, batch_y) in enumerate(val_loader):
                        loss, outputs = self.model(
                            batch_x.float().to(self.device),
                            batch_y.float().to(self.device))
                        val_avg_loss += loss.to('cpu').detach() * len(batch_x)
                        num_val_samples += len(batch_x)

                    val_avg_loss /= num_val_samples
                    print('Epoch %d: Val loss %.4f' % (epoch, val_avg_loss))

                    # Early stop
                    if 'refit' not in mode:
                        early_stop.update(val_avg_loss)
                        if early_stop.if_early_stop:
                            self.early_stop_flag = True
                            print("Early stop!")
                            break

        self.optimizer_ = optimizer
        self.epoch_num = int(self.epoch_num) + int(self.cur_epoch_num)
        self.scheduler = scheduler

        return self
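The EarlyStop helper used above is not defined in this snippet; a minimal sketch that matches the interface it is called with (patience and mode in the constructor, update(), and the if_early_stop flag) might look like the following. This is a hypothetical implementation, not the original one.

class EarlyStop:
    """Hypothetical minimal early-stopping helper (sketch, not the original)."""

    def __init__(self, patience=5, mode='min'):
        self.patience = patience
        self.mode = mode
        self.best = None
        self.counter = 0
        self.if_early_stop = False

    def update(self, value):
        # improvement means lower value in 'min' mode, higher value otherwise
        improved = (self.best is None or
                    (value < self.best if self.mode == 'min' else value > self.best))
        if improved:
            self.best = value
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.if_early_stop = True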
Example #24
0
def train(train_loop_func, logger, args):
    # Check that GPUs are actually available
    use_cuda = not args.no_cuda

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.N_gpu = torch.distributed.get_world_size()
    else:
        args.N_gpu = 1

    if args.seed is None:
        args.seed = np.random.randint(1e4)

    if args.distributed:
        args.seed = (args.seed + torch.distributed.get_rank()) % 2**32
    print("Using seed = {}".format(args.seed))
    torch.manual_seed(args.seed)
    np.random.seed(seed=args.seed)

    torch.multiprocessing.set_sharing_strategy('file_system')

    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    cocoGt = get_coco_ground_truth(args)
    # 82783 = COCO train2014 size; 118287 = COCO train2017 size (full-set loaders commented out below)
    # train_loader = get_train_loader(args, args.seed - 2**31, 118287)

    # target_loader = get_target_loader(args, args.seed - 2**31, 118287)

    train_loader = get_train_loader(args, args.seed - 2**31, 5000)

    target_loader = get_target_loader(args, args.seed - 2**31, 5000)

    val_dataset = get_val_dataset(args)
    val_dataloader = get_val_dataloader(val_dataset, args)

    ssd300 = DASSD300(backbone=ResNet(args.backbone, args.backbone_path))
    # alternative scaling: args.learning_rate = args.learning_rate * args.N_gpu * ((args.batch_size + args.batch_size // 2) / 32)
    args.learning_rate = args.learning_rate * args.N_gpu * (
        (args.batch_size + args.batch_size) / 32)
    start_epoch = 0
    iteration = 0
    loss_func = DALoss(dboxes)
    da_loss_func = ImageLevelAdaptationLoss()

    if use_cuda:
        ssd300.cuda()
        loss_func.cuda()
        da_loss_func.cuda()

    optimizer = torch.optim.SGD(tencent_trick(ssd300),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=args.multistep,
                            gamma=0.1)
    if args.amp:
        ssd300, optimizer = amp.initialize(ssd300, optimizer, opt_level='O2')

    if args.distributed:
        ssd300 = DDP(ssd300)

    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            load_checkpoint(ssd300.module if args.distributed else ssd300,
                            args.checkpoint)
            checkpoint = torch.load(args.checkpoint,
                                    map_location=lambda storage, loc: storage.
                                    cuda(torch.cuda.current_device()))
            start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            scheduler.load_state_dict(checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('Provided checkpoint is not path to a file')
            return

    inv_map = {v: k for k, v in val_dataset.label_map.items()}

    total_time = 0

    if args.mode == 'evaluation':
        acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args)
        if args.local_rank == 0:
            print('Model precision {} mAP'.format(acc))

        return
    mean, std = generate_mean_std(args)

    meters = {
        'total': AverageValueMeter(),
        'ssd': AverageValueMeter(),
        'da': AverageValueMeter()
    }

    vis = Visualizer(env='da ssd', port=6006)

    for epoch in range(start_epoch, args.epochs):
        start_epoch_time = time.time()
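        # NOTE: the scheduler is stepped at the start of each epoch here
        # (pre-1.1.0 PyTorch convention); recent releases expect
        # scheduler.step() after the epoch's optimizer updates.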
        scheduler.step()
        iteration = train_loop_func(ssd300, loss_func, da_loss_func, epoch,
                                    optimizer, train_loader, target_loader,
                                    encoder, iteration, logger, args, mean,
                                    std, meters, vis)
        end_epoch_time = time.time() - start_epoch_time
        total_time += end_epoch_time

        if args.local_rank == 0:
            logger.update_epoch_time(epoch, end_epoch_time)

        if epoch in args.evaluation:
            acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map,
                           args)

            if args.local_rank == 0:
                logger.update_epoch(epoch, acc)
                vis.log(acc, win='Evaluation')

        if args.save and args.local_rank == 0:
            print("saving model...")
            obj = {
                'epoch': epoch + 1,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'label_map': val_dataset.label_info
            }
            if args.distributed:
                obj['model'] = ssd300.module.state_dict()
            else:
                obj['model'] = ssd300.state_dict()
            torch.save(obj, './models/epoch_{}.pt'.format(epoch))
        train_loader.reset()
        target_loader.reset()

    print('total training time: {}'.format(total_time))
Example #25
0
def train(train_loop_func, logger, args):
    if args.amp:
        amp_handle = amp.init(enabled=args.fp16)
    # Check that GPUs are actually available
    use_cuda = not args.no_cuda

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.N_gpu = torch.distributed.get_world_size()
    else:
        args.N_gpu = 1

    if args.seed is None:
        args.seed = np.random.randint(1e4)

    if args.distributed:
        args.seed = (args.seed + torch.distributed.get_rank()) % 2**32
    print("Using seed = {}".format(args.seed))
    torch.manual_seed(args.seed)
    np.random.seed(seed=args.seed)

    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    cocoGt = get_coco_ground_truth(args)

    train_loader = get_train_loader(args, args.seed - 2**31)

    val_dataset = get_val_dataset(args)
    val_dataloader = get_val_dataloader(val_dataset, args)

    ssd300 = SSD300(backbone=args.backbone)
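    # Linear LR scaling: grow the base LR with the number of GPUs and the
    # per-GPU batch size (reference batch size 32).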
    args.learning_rate = args.learning_rate * args.N_gpu * (args.batch_size /
                                                            32)
    start_epoch = 0
    iteration = 0
    loss_func = Loss(dboxes)

    if use_cuda:
        ssd300.cuda()
        loss_func.cuda()

    if args.fp16 and not args.amp:
        ssd300 = network_to_half(ssd300)

    if args.distributed:
        ssd300 = DDP(ssd300)

    optimizer = torch.optim.SGD(tencent_trick(ssd300),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=args.multistep,
                            gamma=0.1)
    if args.fp16:
        if args.amp:
            optimizer = amp_handle.wrap_optimizer(optimizer)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.)
    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            load_checkpoint(ssd300, args.checkpoint)
            checkpoint = torch.load(args.checkpoint,
                                    map_location=lambda storage, loc: storage.
                                    cuda(torch.cuda.current_device()))
            start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            scheduler.load_state_dict(checkpoint['scheduler'])
            ssd300.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('Provided checkpoint is not path to a file')
            return

    inv_map = {v: k for k, v in val_dataset.label_map.items()}

    total_time = 0

    if args.mode == 'evaluation':
        acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args)
        if args.local_rank == 0:
            print('Model precision {} mAP'.format(acc))

        return
    mean, std = generate_mean_std(args)

    for epoch in range(start_epoch, args.epochs):
        start_epoch_time = time.time()
        scheduler.step()
        iteration = train_loop_func(ssd300, loss_func, epoch, optimizer,
                                    train_loader, val_dataloader, encoder,
                                    iteration, logger, args, mean, std)
        end_epoch_time = time.time() - start_epoch_time
        total_time += end_epoch_time

        if args.local_rank == 0:
            logger.update_epoch_time(epoch, end_epoch_time)

        if epoch in args.evaluation:
            acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map,
                           args)

            if args.local_rank == 0:
                logger.update_epoch(epoch, acc)

        if args.save and args.local_rank == 0:
            print("saving model...")
            obj = {
                'epoch': epoch + 1,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'label_map': val_dataset.label_info
            }
            if args.distributed:
                obj['model'] = ssd300.module.state_dict()
            else:
                obj['model'] = ssd300.state_dict()
            torch.save(obj, './models/epoch_{}.pt'.format(epoch))
        train_loader.reset()
    print('total training time: {}'.format(total_time))
Example #26
0
def main(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)

    train_params = {
        "batch_size": opt.batch_size,
        "shuffle": True,
        "drop_last": False,
        "num_workers": opt.num_workers,
        "collate_fn": collate_fn
    }

    eval_params = {
        "batch_size": opt.batch_size,
        "shuffle": True,
        "drop_last": False,
        "num_workers": opt.num_workers,
        "collate_fn": collate_fn
    }

    dboxes = generate_dboxes()
    model = SSD()
    train_set = OIDataset(SimpleTransformer(dboxes), train=True)
    train_loader = DataLoader(train_set, **train_params)
    val_set = OIDataset(SimpleTransformer(dboxes, eval=True), validation=True)
    val_loader = DataLoader(val_set, **eval_params)

    encoder = Encoder(dboxes)

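    # Scale the base LR linearly with the batch size (reference batch 32).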
    opt.lr = opt.lr * (opt.batch_size / 32)
    criterion = Loss(dboxes)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=opt.lr,
                                momentum=opt.momentum,
                                weight_decay=opt.weight_decay,
                                nesterov=True)
    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=opt.multistep,
                            gamma=0.1)

    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()

    model = torch.nn.DataParallel(model)

    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)

    if not os.path.isdir(opt.save_folder):
        os.makedirs(opt.save_folder)
    checkpoint_path = os.path.join(opt.save_folder, "SSD.pth")

    writer = SummaryWriter(opt.log_path)

    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        first_epoch = checkpoint["epoch"] + 1
        model.module.load_state_dict(checkpoint["model_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        first_epoch = 0

    for epoch in range(first_epoch, opt.epochs):
        train(model, train_loader, epoch, writer, criterion, optimizer,
              scheduler)
        evaluate(model, val_loader, encoder, opt.nms_threshold)

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.module.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict()
        }
        torch.save(checkpoint, checkpoint_path)
Example #27
0
def main():
    # Views the training images and displays the distance on anchor-negative and anchor-positive
    # print the experiment configuration
    print('\33[91mCurrent time is {}\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Classes: {}\n'.format(len(train_dir.speakers)))

    # instantiate
    # model and initialize weights
    model = LSTM_End(input_dim=args.feat_dim,
                     num_class=train_dir.num_spks,
                     batch_size=args.batch_size * args.tuple_size,
                     project_dim=args.embedding_dim,
                     num_lstm=args.num_lstm,
                     dropout_p=0.1)

    if args.cuda:
        model.cuda()

    optimizer = create_optimizer(model.parameters(), args.optimizer,
                                 **opt_kwargs)
    scheduler = MultiStepLR(optimizer, milestones=[60], gamma=0.1)

    start = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start = checkpoint['epoch']
            filtered = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'num_batches_tracked' not in k
            }
            model.load_state_dict(filtered)
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            # criterion.load_state_dict(checkpoint['criterion'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    start += args.start_epoch
    print('Start epoch is : ' + str(start))
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(train_dir,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir,
                                               batch_size=int(args.batch_size *
                                                              args.tuple_size),
                                               shuffle=False,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_part,
        batch_size=int(args.batch_size * args.tuple_size /
                       args.test_input_per_file),
        shuffle=False,
        **kwargs)
    # criterion = nn.CrossEntropyLoss().cuda()
    criterion = [
        nn.CrossEntropyLoss().cuda(),
        TupleLoss(args.batch_size, args.tuple_size).cuda()
    ]
    check_path = '{}/checkpoint_{}.pth'.format(args.check_path, -1)
    torch.save(
        {
            'epoch': -1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, check_path)

    for epoch in range(start, end):
        # pdb.set_trace()
        # compute_dropout(model, optimizer, epoch, end)
        train(train_loader, model, optimizer, criterion, epoch)
        test(valid_loader, test_loader, model, epoch)
        scheduler.step()
        # break
    writer.close()
Example #28
0
class Learner:
    def __init__(self):
        self.args = self.parse_command_line()

        self.checkpoint_dir, self.logfile, self.checkpoint_path_validation, self.checkpoint_path_final \
            = get_log_files(self.args.checkpoint_dir, self.args.resume_from_checkpoint, False)

        print_and_log(self.logfile, "Options: %s\n" % self.args)
        print_and_log(self.logfile,
                      "Checkpoint Directory: %s\n" % self.checkpoint_dir)

        self.writer = SummaryWriter()

        #gpu_device = 'cuda:0'
        gpu_device = 'cuda'
        self.device = torch.device(
            gpu_device if torch.cuda.is_available() else 'cpu')
        self.model = self.init_model()
        self.train_set, self.validation_set, self.test_set = self.init_data()

        self.vd = video_reader.VideoDataset(self.args)
        self.video_loader = torch.utils.data.DataLoader(
            self.vd, batch_size=1, num_workers=self.args.num_workers)

        self.loss = loss
        self.accuracy_fn = aggregate_accuracy

        if self.args.opt == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=self.args.learning_rate)
        elif self.args.opt == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=self.args.learning_rate)
        self.test_accuracies = TestAccuracies(self.test_set)

        self.scheduler = MultiStepLR(self.optimizer,
                                     milestones=self.args.sch,
                                     gamma=0.1)

        self.start_iteration = 0
        if self.args.resume_from_checkpoint:
            self.load_checkpoint()
        self.optimizer.zero_grad()

    def init_model(self):
        model = CNN_TRX(self.args)
        model = model.to(self.device)
        if self.args.num_gpus > 1:
            model.distribute_model()
        return model

    def init_data(self):
        train_set = [self.args.dataset]
        validation_set = [self.args.dataset]
        test_set = [self.args.dataset]

        return train_set, validation_set, test_set

    """
    Command line parser
    """

    def parse_command_line(self):
        parser = argparse.ArgumentParser()

        parser.add_argument("--dataset",
                            choices=["ssv2", "kinetics"],
                            default="ssv2",
                            help="Dataset to use.")
        parser.add_argument("--learning_rate",
                            "-lr",
                            type=float,
                            default=0.001,
                            help="Learning rate.")
        parser.add_argument(
            "--tasks_per_batch",
            type=int,
            default=16,
            help="Number of tasks between parameter optimizations.")
        parser.add_argument("--checkpoint_dir",
                            "-c",
                            default=None,
                            help="Directory to save checkpoint to.")
        parser.add_argument("--test_model_path",
                            "-m",
                            default=None,
                            help="Path to model to load and test.")
        parser.add_argument("--training_iterations",
                            "-i",
                            type=int,
                            default=50020,
                            help="Number of meta-training iterations.")
        parser.add_argument("--resume_from_checkpoint",
                            "-r",
                            dest="resume_from_checkpoint",
                            default=False,
                            action="store_true",
                            help="Restart from latest checkpoint.")
        parser.add_argument("--way",
                            type=int,
                            default=5,
                            help="Way of single dataset task.")
        parser.add_argument(
            "--shot",
            type=int,
            default=1,
            help="Shots per class for context of single dataset task.")
        parser.add_argument("--query_per_class",
                            type=int,
                            default=5,
                            help="Target samples (i.e. queries) per class.")

        parser.add_argument("--seq_len",
                            type=int,
                            default=8,
                            help="Frames per video.")
        parser.add_argument("--num_workers",
                            type=int,
                            default=10,
                            help="Num dataloader workers.")
        parser.add_argument("--method",
                            choices=["resnet18", "resnet34", "resnet50"],
                            default="resnet50",
                            help="method")
        parser.add_argument("--trans_linear_out_dim",
                            type=int,
                            default=1152,
                            help="Transformer linear_out_dim")
        parser.add_argument("--opt",
                            choices=["adam", "sgd"],
                            default="sgd",
                            help="Optimizer")
        parser.add_argument("--trans_dropout",
                            type=float,
                            default=0.1,
                            help="Transformer dropout")
        parser.add_argument(
            "--save_freq",
            type=int,
            default=5000,
            help="Number of iterations between checkpoint saves.")
        parser.add_argument("--img_size",
                            type=int,
                            default=224,
                            help="Input image size to the CNN after cropping.")
        parser.add_argument('--temp_set',
                            nargs='+',
                            type=int,
                            help='cardinalities e.g. 2,3 is pairs and triples',
                            default=[2, 3])

        parser.add_argument("--scratch",
                            choices=["bc", "bp"],
                            default="bp",
                            help="Computer to run on")
        parser.add_argument("--num_gpus",
                            type=int,
                            default=1,
                            help="Number of GPUs to split the ResNet over")
        parser.add_argument("--debug_loader",
                            default=False,
                            action="store_true",
                            help="Load 1 vid per class for debugging")

        parser.add_argument("--split",
                            type=int,
                            default=3,
                            help="Dataset split.")
        parser.add_argument('--sch',
                            nargs='+',
                            type=int,
                            help='iters to drop learning rate',
                            default=[1000000])

        args = parser.parse_args()

        if args.scratch == "bc":
            args.scratch = "/mnt/storage/home/tp8961/scratch"
        elif args.scratch == "bp":
            args.num_gpus = 4
            args.num_workers = 5
            args.scratch = "/work/tp8961"

        if args.checkpoint_dir is None:
            print("need to specify a checkpoint dir")
            exit(1)

        if (args.method == "resnet50") or (args.method == "resnet34"):
            args.img_size = 224
        if args.method == "resnet50":
            args.trans_linear_in_dim = 2048
        else:
            args.trans_linear_in_dim = 512

        if args.dataset == "ssv2":
            args.traintestlist = os.path.join(
                args.scratch,
                "video_datasets/splits/somethingsomethingv2TrainTestlist")
            args.path = os.path.join(
                args.scratch,
                "video_datasets/data/somethingsomethingv2_256x256q5_1.zip")
        elif args.dataset == "kinetics":
            args.traintestlist = os.path.join(
                args.scratch, "video_datasets/splits/kineticsTrainTestlist")
            args.path = os.path.join(
                args.scratch, "video_datasets/data/kinetics_256q5_1.zip")
        return args

    def run(self):
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.compat.v1.Session(config=config) as session:
            train_accuracies = []
            losses = []
            total_iterations = self.args.training_iterations

            iteration = self.start_iteration
            for task_dict in self.video_loader:
                if iteration >= total_iterations:
                    break
                iteration += 1
                torch.set_grad_enabled(True)

                task_loss, task_accuracy = self.train_task(task_dict)
                train_accuracies.append(task_accuracy)
                losses.append(task_loss)

                # optimize
                if ((iteration + 1) % self.args.tasks_per_batch
                        == 0) or (iteration == (total_iterations - 1)):
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                self.scheduler.step()
                if (iteration + 1) % PRINT_FREQUENCY == 0:
                    # print training stats
                    print_and_log(
                        self.logfile,
                        'Task [{}/{}], Train Loss: {:.7f}, Train Accuracy: {:.7f}'
                        .format(iteration + 1, total_iterations,
                                torch.Tensor(losses).mean().item(),
                                torch.Tensor(train_accuracies).mean().item()))
                    train_accuracies = []
                    losses = []

                if ((iteration + 1) % self.args.save_freq
                        == 0) and (iteration + 1) != total_iterations:
                    self.save_checkpoint(iteration + 1)

                if ((iteration + 1)
                        in TEST_ITERS) and (iteration + 1) != total_iterations:
                    accuracy_dict = self.test(session)
                    self.test_accuracies.print(self.logfile, accuracy_dict)

            # save the final model
            torch.save(self.model.state_dict(), self.checkpoint_path_final)

        self.logfile.close()

    def train_task(self, task_dict):
        context_images, target_images, context_labels, target_labels, real_target_labels, batch_class_list = self.prepare_task(
            task_dict)

        model_dict = self.model(context_images, context_labels, target_images)
        target_logits = model_dict['logits']

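        # Dividing by tasks_per_batch averages the gradients accumulated
        # across tasks between optimizer steps (see run()).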
        task_loss = self.loss(target_logits, target_labels,
                              self.device) / self.args.tasks_per_batch
        task_accuracy = self.accuracy_fn(target_logits, target_labels)

        task_loss.backward(retain_graph=False)

        return task_loss, task_accuracy

    def test(self, session):
        self.model.eval()
        with torch.no_grad():

            self.video_loader.dataset.train = False
            accuracy_dict = {}
            accuracies = []
            iteration = 0
            item = self.args.dataset
            for task_dict in self.video_loader:
                if iteration >= NUM_TEST_TASKS:
                    break
                iteration += 1

                context_images, target_images, context_labels, target_labels, real_target_labels, batch_class_list = self.prepare_task(
                    task_dict)
                model_dict = self.model(context_images, context_labels,
                                        target_images)
                target_logits = model_dict['logits']
                accuracy = self.accuracy_fn(target_logits, target_labels)
                accuracies.append(accuracy.item())
                del target_logits

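            # 196.0 = 1.96 * 100: a ~95% confidence interval, expressed in
            # percentage points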
            accuracy = np.array(accuracies).mean() * 100.0
            confidence = (196.0 * np.array(accuracies).std()) / np.sqrt(
                len(accuracies))

            accuracy_dict[item] = {
                "accuracy": accuracy,
                "confidence": confidence
            }
            self.video_loader.dataset.train = True
        self.model.train()

        return accuracy_dict

    def prepare_task(self, task_dict, images_to_device=True):
        context_images, context_labels = task_dict['support_set'][
            0], task_dict['support_labels'][0]
        target_images, target_labels = task_dict['target_set'][0], task_dict[
            'target_labels'][0]
        real_target_labels = task_dict['real_target_labels'][0]
        batch_class_list = task_dict['batch_class_list'][0]

        if images_to_device:
            context_images = context_images.to(self.device)
            target_images = target_images.to(self.device)
        context_labels = context_labels.to(self.device)
        target_labels = target_labels.type(torch.LongTensor).to(self.device)

        return context_images, target_images, context_labels, target_labels, real_target_labels, batch_class_list

    def shuffle(self, images, labels):
        """
        Return shuffled data.
        """
        permutation = np.random.permutation(images.shape[0])
        return images[permutation], labels[permutation]

    def save_checkpoint(self, iteration):
        d = {
            'iteration': iteration,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }

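        # Save both an iteration-stamped snapshot and a rolling 'checkpoint.pt'
        # that load_checkpoint() reads when resuming.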
        torch.save(
            d,
            os.path.join(self.checkpoint_dir,
                         'checkpoint{}.pt'.format(iteration)))
        torch.save(d, os.path.join(self.checkpoint_dir, 'checkpoint.pt'))

    def load_checkpoint(self):
        checkpoint = torch.load(
            os.path.join(self.checkpoint_dir, 'checkpoint.pt'))
        self.start_iteration = checkpoint['iteration']
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
Example #29
0
def main():
    global best_RMSE

    lw = utils_func.LossWise(args.api_key, args.losswise_tag, args.epochs - 1)
    # set logger
    log = logger.setup_logger(os.path.join(args.save_path, 'training.log'))
    for key, value in sorted(vars(args).items()):
        log.info(str(key) + ': ' + str(value))

    # set tensorboard
    writer = SummaryWriter(args.save_path + '/tensorboardx')

    # Data Loader
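    # Three modes: a submission-style loader for depth-map generation,
    # KITTI train/val loaders, or SceneFlow loaders as the fallback.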
    if args.generate_depth_map:
        TrainImgLoader = None
        import dataloader.KITTI_submission_loader as KITTI_submission_loader
        TestImgLoader = torch.utils.data.DataLoader(
            KITTI_submission_loader.SubmiteDataset(args.datapath,
                                                   args.data_list,
                                                   args.dynamic_bs),
            batch_size=args.bval,
            shuffle=False,
            num_workers=args.workers,
            drop_last=False)
    elif args.dataset == 'kitti':
        train_data, val_data = KITTILoader3D.dataloader(
            args.datapath,
            args.split_train,
            args.split_val,
            kitti2015=args.kitti2015)
        TrainImgLoader = torch.utils.data.DataLoader(
            KITTILoader_dataset3d.myImageFloder(train_data,
                                                True,
                                                kitti2015=args.kitti2015,
                                                dynamic_bs=args.dynamic_bs),
            batch_size=args.btrain,
            shuffle=True,
            num_workers=8,
            drop_last=False,
            pin_memory=True)
        TestImgLoader = torch.utils.data.DataLoader(
            KITTILoader_dataset3d.myImageFloder(val_data,
                                                False,
                                                kitti2015=args.kitti2015,
                                                dynamic_bs=args.dynamic_bs),
            batch_size=args.bval,
            shuffle=False,
            num_workers=8,
            drop_last=False,
            pin_memory=True)
    else:
        train_data, val_data = listflowfile.dataloader(args.datapath)
        TrainImgLoader = torch.utils.data.DataLoader(
            SceneFlowLoader.myImageFloder(train_data,
                                          True,
                                          calib=args.calib_value),
            batch_size=args.btrain,
            shuffle=True,
            num_workers=8,
            drop_last=False)
        TestImgLoader = torch.utils.data.DataLoader(
            SceneFlowLoader.myImageFloder(val_data,
                                          False,
                                          calib=args.calib_value),
            batch_size=args.bval,
            shuffle=False,
            num_workers=8,
            drop_last=False)

    # Load Model
    if args.data_type == 'disparity':
        model = disp_models.__dict__[args.arch](maxdisp=args.maxdisp)
    elif args.data_type == 'depth':
        model = models.__dict__[args.arch](maxdepth=args.maxdepth,
                                           maxdisp=args.maxdisp,
                                           down=args.down,
                                           scale=args.scale)
    else:
        log.info('Model is not implemented')
        raise NotImplementedError(
            'Unsupported --data_type: {}'.format(args.data_type))

    # Number of parameters
    log.info('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    model = nn.DataParallel(model).cuda()
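    # cuDNN autotuning: picks the fastest convolution algorithms when input
    # sizes stay fixed across iterations.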
    torch.backends.cudnn.benchmark = True

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_stepsize,
                            gamma=args.lr_gamma)

    if args.pretrain:
        if os.path.isfile(args.pretrain):
            log.info("=> loading pretrain '{}'".format(args.pretrain))
            checkpoint = torch.load(args.pretrain)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
        else:
            log.info('[Attention]: cannot find checkpoint {}'.format(
                args.pretrain))

    if args.resume:
        if os.path.isfile(args.resume):
            log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            args.start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_RMSE = checkpoint['best_RMSE']
            scheduler.load_state_dict(checkpoint['scheduler'])
            log.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.info('[Attention]: cannot find checkpoint {}'.format(
                args.resume))

    if args.generate_depth_map:
        os.makedirs(args.save_path + '/depth_maps/' + args.data_tag,
                    exist_ok=True)

        tqdm_eval_loader = tqdm(TestImgLoader, total=len(TestImgLoader))
        for batch_idx, (imgL_crop, imgR_crop, calib, H, W,
                        filename) in enumerate(tqdm_eval_loader):
            pred_disp = inference(imgL_crop, imgR_crop, calib, model)
            for idx, name in enumerate(filename):
                np.save(
                    args.save_path + '/depth_maps/' + args.data_tag + '/' +
                    name, pred_disp[idx][-H[idx]:, :W[idx]])
        import sys
        sys.exit()

    # evaluation
    if args.evaluate:
        evaluate_metric = utils_func.Metric()
        ## evaluation ##
        for batch_idx, (imgL_crop, imgR_crop, disp_crop_L,
                        calib) in enumerate(TestImgLoader):
            start_time = time.time()
            test(imgL_crop, imgR_crop, disp_crop_L, calib, evaluate_metric,
                 optimizer, model)

            log.info(
                evaluate_metric.print(batch_idx, 'EVALUATE') +
                ' Time:{:.3f}'.format(time.time() - start_time))
        import sys
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs):
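        # Note: stepping the scheduler at the start of each epoch follows the
        # pre-1.1 PyTorch convention; newer releases expect scheduler.step()
        # to be called after optimizer.step().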
        scheduler.step()

        ## training ##
        train_metric = utils_func.Metric()
        tqdm_train_loader = tqdm(TrainImgLoader, total=len(TrainImgLoader))
        for batch_idx, (imgL_crop, imgR_crop, disp_crop_L,
                        calib) in enumerate(tqdm_train_loader):
            # start_time = time.time()
            train(imgL_crop, imgR_crop, disp_crop_L, calib, train_metric,
                  optimizer, model, epoch)
            # log.info(train_metric.print(batch_idx, 'TRAIN') + ' Time:{:.3f}'.format(time.time() - start_time))
        log.info(train_metric.print(0, 'TRAIN Epoch' + str(epoch)))
        train_metric.tensorboard(writer, epoch, token='TRAIN')
        lw.update(train_metric.get_info(), epoch, 'Train')

        ## testing ##
        is_best = False
        if epoch == 0 or ((epoch + 1) % args.eval_interval) == 0:
            test_metric = utils_func.Metric()
            tqdm_test_loader = tqdm(TestImgLoader, total=len(TestImgLoader))
            for batch_idx, (imgL_crop, imgR_crop, disp_crop_L,
                            calib) in enumerate(tqdm_test_loader):
                # start_time = time.time()
                test(imgL_crop, imgR_crop, disp_crop_L, calib, test_metric,
                     optimizer, model)
                # log.info(test_metric.print(batch_idx, 'TEST') + ' Time:{:.3f}'.format(time.time() - start_time))
            log.info(test_metric.print(0, 'TEST Epoch' + str(epoch)))
            test_metric.tensorboard(writer, epoch, token='TEST')
            lw.update(test_metric.get_info(), epoch, 'Test')

            # SAVE
            is_best = test_metric.RMSELIs.avg < best_RMSE
            best_RMSE = min(test_metric.RMSELIs.avg, best_RMSE)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_RMSE': best_RMSE,
                'scheduler': scheduler.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            epoch,
            folder=args.save_path)
    lw.done()
Example #30
0
def main():
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    cudnn.benchmark = True
    cudnn.enabled = True

    logger = get_logger(__name__, Config.log)

    Config.gpus = torch.cuda.device_count()
    logger.info("use {} gpus".format(Config.gpus))
    config = {
        key: value
        for key, value in Config.__dict__.items() if not key.startswith("__")
    }
    logger.info(f"args: {config}")

    start_time = time.time()

    # dataset and dataloader
    logger.info("start loading data")

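    # Standard ImageNet augmentation and normalization statistics.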
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = ImageFolder(Config.train_dataset_path, train_transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.batch_size,
        shuffle=True,
        num_workers=Config.num_workers,
        pin_memory=True,
    )
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    val_dataset = ImageFolder(Config.val_dataset_path, val_transform)
    val_loader = DataLoader(
        val_dataset,
        batch_size=Config.batch_size,
        num_workers=Config.num_workers,
        pin_memory=True,
    )
    logger.info("finish loading data")

    # network
    net = ChannelDistillResNet1834(Config.num_classes, Config.dataset_type)
    net = nn.DataParallel(net).cuda()

    # loss and optimizer
    criterion = []
    for loss_item in Config.loss_list:
        loss_name = loss_item["loss_name"]
        loss_type = loss_item["loss_type"]
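        # Knowledge-distillation losses take a softmax temperature T; all
        # other losses are constructed without extra arguments.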
        if "kd" in loss_type:
            criterion.append(losses.__dict__[loss_name](loss_item["T"]).cuda())
        else:
            criterion.append(losses.__dict__[loss_name]().cuda())

    optimizer = SGD(net.parameters(),
                    lr=Config.lr,
                    momentum=0.9,
                    weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1)

    # only evaluate
    if Config.evaluate:
        # load best model
        if not os.path.isfile(Config.evaluate):
            raise Exception(
                f"{Config.evaluate} is not a file, please check it again")
        logger.info("start evaluating")
        logger.info(f"start resuming model from {Config.evaluate}")
        checkpoint = torch.load(Config.evaluate,
                                map_location=torch.device("cpu"))
        net.load_state_dict(checkpoint["model_state_dict"])
        prec1, prec5 = validate(val_loader, net)
        logger.info(
            f"epoch {checkpoint['epoch']:0>3d}, top1 acc: {prec1:.2f}%, top5 acc: {prec5:.2f}%"
        )
        return

    start_epoch = 1
    # resume training
    if os.path.exists(Config.resume):
        logger.info(f"start resuming model from {Config.resume}")
        checkpoint = torch.load(Config.resume,
                                map_location=torch.device("cpu"))
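        # start_epoch was initialised to 1 above, so training resumes at the
        # epoch after the one stored in the checkpoint.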
        start_epoch += checkpoint["epoch"]
        net.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        logger.info(
            f"finish resuming model from {Config.resume}, epoch {checkpoint['epoch']}, "
            f"loss: {checkpoint['loss']:.3f}, lr: {checkpoint['lr']:.6f}, "
            f"top1_acc: {checkpoint['acc']}%")

    if not os.path.exists(Config.checkpoints):
        os.makedirs(Config.checkpoints)

    logger.info("start training")
    best_acc = 0.
    for epoch in range(start_epoch, Config.epochs + 1):
        prec1, prec5, loss = train(train_loader, net, criterion, optimizer,
                                   scheduler, epoch, logger)
        logger.info(
            f"train: epoch {epoch:0>3d}, top1 acc: {prec1:.2f}%, top5 acc: {prec5:.2f}%"
        )

        prec1, prec5 = validate(val_loader, net)
        logger.info(
            f"val: epoch {epoch:0>3d}, top1 acc: {prec1:.2f}%, top5 acc: {prec5:.2f}%"
        )

        # remember best prec@1 and save checkpoint
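        # Note: on PyTorch >= 1.4, scheduler.get_last_lr()[0] is the
        # recommended way to read the current learning rate stored below.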
        torch.save(
            {
                "epoch": epoch,
                "acc": prec1,
                "loss": loss,
                "lr": scheduler.get_lr()[0],
                "model_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
            }, os.path.join(Config.checkpoints, "latest.pth"))
        if prec1 > best_acc:
            shutil.copyfile(os.path.join(Config.checkpoints, "latest.pth"),
                            os.path.join(Config.checkpoints, "best.pth"))
            best_acc = prec1

    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, best acc: {best_acc:.2f}%, total training time: {training_time:.2f} hours"
    )