Example #1
File: train.py Project: liuguoyou/moco
def main(args):
    train_loader = get_loader(args)
    n_data = len(train_loader.dataset)
    logger.info(f"length of training dataset: {n_data}")

    model, model_ema = build_model(args)
    contrast = MemoryMoCo(128, args.nce_k, args.nce_t).cuda()
    criterion = NCESoftmaxLoss().cuda()
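    # linear scaling rule: effective lr = base_lr * global_batch_size / 256 (args.batch_size is per GPU)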
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.batch_size * dist.get_world_size() / 256 * args.base_learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = get_scheduler(optimizer, len(train_loader), args)

    if args.amp_opt_level != "O0":
        if amp is None:
            logger.warning(f"apex is not installed but amp_opt_level is set to {args.amp_opt_level}, ignoring.\n"
                           "you should install apex from https://github.com/NVIDIA/apex#quick-start first")
            args.amp_opt_level = "O0"
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level)
            model_ema = amp.initialize(model_ema, opt_level=args.amp_opt_level)

    model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False)

    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume)
        load_checkpoint(args, model, model_ema, contrast, optimizer, scheduler)

    # tensorboard
    if dist.get_rank() == 0:
        summary_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        summary_writer = None

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):
        train_loader.sampler.set_epoch(epoch)

        tic = time.time()
        loss, prob = train_moco(epoch, train_loader, model, model_ema, contrast, criterion, optimizer, scheduler, args)

        logger.info('epoch {}, total time {:.2f}'.format(epoch, time.time() - tic))

        if summary_writer is not None:
            # tensorboard logger
            summary_writer.add_scalar('ins_loss', loss, epoch)
            summary_writer.add_scalar('ins_prob', prob, epoch)
            summary_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if dist.get_rank() == 0:
            # save model
            save_checkpoint(args, epoch, model, model_ema, contrast, scheduler, optimizer)
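The driver in Example #1 wires the pieces together but leaves two details to helpers it calls: the learning rate passed to SGD follows the linear scaling rule, and model_ema is the momentum (key) encoder that train_moco is expected to keep in sync with model. A minimal sketch of both follows; the helper names scale_lr/momentum_update and the momentum value 0.999 are illustrative assumptions, not this repository's code.

import torch

def scale_lr(base_lr, batch_size_per_gpu, world_size):
    # linear scaling rule used in the SGD constructor above:
    # the learning rate grows with the global batch size, normalized to 256
    return base_lr * batch_size_per_gpu * world_size / 256

@torch.no_grad()
def momentum_update(model, model_ema, m=0.999):
    # exponential moving average that keeps the key encoder (model_ema)
    # trailing the query encoder; m=0.999 is the MoCo paper value and an
    # assumption here, not necessarily this repo's setting
    for p, p_ema in zip(model.parameters(), model_ema.parameters()):
        p_ema.data.mul_(m).add_(p.data, alpha=1.0 - m)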
Example #2
File: train.py Project: atch841/moco
def main(args):
    train_loader = get_loader(args, 'train5/')
    train_loader_seg = get_loader(args, 'train5_1p_half/')
    n_data = len(train_loader.dataset)
    n_data_seg = len(train_loader_seg.dataset)
    logger.info(f"length of training dataset: {n_data} {n_data_seg}")

    model, model_ema = build_model(args)
    if args.model == 'resnet50':
        contrast = MemoryMoCo(128, 300, args.nce_t).cuda()
    elif args.model == 'vit':
        contrast = MemoryMoCo(128, 200, args.nce_t, s=64, c=768).cuda()
    elif args.model == 'resnet101':
        contrast = MemoryMoCo(128, 300, args.nce_t).cuda()

    criterion = NCESoftmaxLoss().cuda()
    # optimizer = torch.optim.SGD(model.parameters(),
    # optimizer = torch.optim.SGD([{'params': model.backbone.parameters(), 'lr': args.batch_size / 256 * args.base_learning_rate * 0.8},
    optimizer = torch.optim.SGD([{
        'params': model.backbone.parameters()
    }, {
        'params': model.mlp.parameters()
    }, {
        'params': model.mlp2.parameters()
    }],
                                lr=args.batch_size / 256 *
                                args.base_learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = get_scheduler(optimizer, len(train_loader), args)
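    # second optimizer/scheduler pair: the supervised segmentation branch
    # (shared backbone + decoder + segmentation_head) is trained alongside the contrastive objective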
    optimizer_seg = torch.optim.SGD(
        [{
            'params': model.backbone.parameters()
        }, {
            'params': model.decoder.parameters()
        }, {
            'params': model.segmentation_head.parameters()
        }],
        lr=args.batch_size / 256 * args.base_learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    scheduler_seg = get_scheduler(optimizer_seg, len(train_loader_seg), args)

    if args.amp_opt_level != "O0":
        if amp is None:
            logger.warning(
                f"apex is not installed but amp_opt_level is set to {args.amp_opt_level}, ignoring.\n"
                "you should install apex from https://github.com/NVIDIA/apex#quick-start first"
            )
            args.amp_opt_level = "O0"
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.amp_opt_level)
            # model_ema = amp.initialize(model_ema, opt_level=args.amp_opt_level)

    # model = DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False)

    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume)
        load_checkpoint(args, model, model_ema, contrast, optimizer, scheduler)

    # tensorboard
    summary_writer = SummaryWriter(log_dir=args.output_dir)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):
        # train_loader.sampler.set_epoch(epoch)

        tic = time.time()
        loss, loss_g, loss_d = train_moco(epoch, train_loader, model,
                                          model_ema, contrast, criterion,
                                          optimizer, scheduler, args)
        loss_seg = train_seg(epoch, train_loader_seg, model, optimizer_seg,
                             scheduler_seg, args)

        logger.info('epoch {}, total time {:.2f}'.format(
            epoch,
            time.time() - tic))

        if summary_writer is not None:
            # tensorboard logger
            summary_writer.add_scalar('ins_loss', loss, epoch)
            summary_writer.add_scalar('ins_loss_g', loss_g, epoch)
            summary_writer.add_scalar('ins_loss_d', loss_d, epoch)
            summary_writer.add_scalar('ins_loss_seg', loss_seg, epoch)
            summary_writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'], epoch)

        # validate
        if epoch % args.eval_freq == 0:
            dsc = inference(model)
            summary_writer.add_scalar('dice_percase', dsc, epoch)
            logger.info(f'validate result {epoch}: {dsc}')

        # save model
        save_checkpoint(args, epoch, model, model_ema, contrast, optimizer,
                        scheduler)
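Example #2 builds its optimizers from explicit parameter groups so that the contrastive head and the segmentation branch can be stepped by separate optimizer/scheduler pairs, and the commented-out line suggests an experiment where the backbone runs at a reduced learning rate. Below is a minimal sketch of that per-group learning-rate pattern; the stand-in model, batch size 32, base LR 0.03, momentum 0.9 and weight decay 1e-4 are illustrative assumptions, while the 0.8 factor is taken from the comment above.

import torch
import torch.nn as nn

# hypothetical stand-in for the real model, which exposes
# backbone / mlp / mlp2 / decoder / segmentation_head submodules
model = nn.ModuleDict({'backbone': nn.Linear(8, 8), 'mlp': nn.Linear(8, 8)})

base_lr = 32 / 256 * 0.03  # illustrative per-GPU batch size and base learning rate
optimizer = torch.optim.SGD(
    [{'params': model['backbone'].parameters(), 'lr': base_lr * 0.8},  # reduced LR for the backbone
     {'params': model['mlp'].parameters()}],                           # this group inherits the default lr below
    lr=base_lr, momentum=0.9, weight_decay=1e-4)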
Example #3
def main(args):
    global best_acc1

    train_loader, val_loader = get_loader(args)
    logger.info(f"length of training dataset: {len(train_loader.dataset)}")

    model, classifier = build_model(args,
                                    num_class=len(
                                        train_loader.dataset.classes))
    criterion = torch.nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(classifier.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = get_scheduler(optimizer, len(train_loader), args)

    if args.amp_opt_level != "O0":
        if amp is None:
            logger.warning(
                f"apex is not installed but amp_opt_level is set to {args.amp_opt_level}, ignoring.\n"
                "you should install apex from https://github.com/NVIDIA/apex#quick-start first"
            )
            args.amp_opt_level = "O0"
        else:
            model = amp.initialize(model, opt_level=args.amp_opt_level)
            classifier, optimizer = amp.initialize(
                classifier, optimizer, opt_level=args.amp_opt_level)

    classifier = DistributedDataParallel(classifier,
                                         device_ids=[args.local_rank],
                                         broadcast_buffers=False)

    model.eval()

    load_pretrained(args, model)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(
            args.resume), f"no checkpoint found at '{args.resume}'"
        load_checkpoint(args, classifier, optimizer, scheduler)

    if args.eval:
        logger.info("==> testing...")
        validate(val_loader, model, classifier, criterion, args)
        return

    # tensorboard
    if dist.get_rank() == 0:
        summary_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        summary_writer = None

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):
        if isinstance(train_loader.sampler, DistributedSampler):
            train_loader.sampler.set_epoch(epoch)

        tic = time.time()
        train(epoch, train_loader, model, classifier, criterion, optimizer,
              scheduler, args)
        logger.info(f'epoch {epoch}, total time {time.time() - tic:.2f}')

        logger.info("==> testing...")
        test_acc = validate(val_loader, model, classifier, criterion, args)

        if dist.get_rank() == 0 and epoch % args.save_freq == 0:
            logger.info('==> Saving...')
            save_checkpoint(args, epoch, classifier, test_acc, optimizer,
                            scheduler)

        if summary_writer is not None:
            # tensorboard logger
            summary_writer.add_scalar('test_acc', test_acc, epoch)
            summary_writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'], epoch)
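The last example is the standard linear-evaluation protocol: the pretrained encoder is loaded with load_pretrained, frozen via model.eval(), and only the DDP-wrapped classifier is optimized. A minimal sketch of what one step inside train presumably does under that protocol; the function name, the feature-shape comment, and the variable names are assumptions.

import torch

def linear_eval_step(model, classifier, criterion, optimizer, images, labels):
    # backbone stays frozen: no gradients flow into it, and eval() keeps
    # batch-norm statistics and dropout fixed
    with torch.no_grad():
        feat = model(images)        # e.g. pooled features of shape (N, 2048)
    logits = classifier(feat)       # only the classifier weights receive gradients
    loss = criterion(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()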