# Shared imports for the examples below; McDataset, DistModule, dist_init,
# load_state, load_checkpoint, save_checkpoint, ColorAugmentation, Timer,
# IR18, get_model, and the argparse `parser` are assumed to come from this
# repository's own modules.
import os

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# assuming torch's sampler here; the repo may provide its own drop-in variant
from torch.utils.data.distributed import DistributedSampler


def main():
    global args, best_prec1, min_loss
    args = parser.parse_args()

    rank, world_size = dist_init(args.port)
    print("world_size is: {}".format(world_size))
    assert args.batch_size % world_size == 0
    assert args.workers % world_size == 0
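    # the launcher passes global batch-size/worker counts; give each
    # process (GPU) an equal share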
    args.batch_size = args.batch_size // world_size
    args.workers = args.workers // world_size

    # create model
    print("=> creating model '{}'".format("inceptionv4"))
    print("save_path is: {}".format(args.save_path))

    image_size = 341
    input_size = 299
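    # resize the shorter edge to image_size, then crop input_size patches
    # for the network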
    model = get_model('inceptionv4', pretrained=True)
    # print("model is: {}".format(model))
    model.cuda()
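    # DistModule (repo utility) wraps the model for gradient all-reduce
    # across processes, analogous to DistributedDataParallel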
    model = DistModule(model)

    # optionally resume from a checkpoint
    if args.load_path:
        if args.resume_opt:
            best_prec1, start_epoch = load_state(args.load_path,
                                                 model,
                                                 optimizer=optimizer)
        else:
            # print('load weights from', args.load_path)
            load_state(args.load_path, model)

    cudnn.benchmark = True

    # Data loading code
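    # inceptionv4's pretrained weights expect inputs scaled to [-1, 1],
    # hence mean/std of 0.5 rather than the ImageNet statistics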
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    train_dataset = McDataset(
        args.train_root, args.train_source,
        transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ColorAugmentation(),
            normalize,
        ]))
    val_dataset = McDataset(
        args.val_root, args.val_source,
        transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)
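    # the samplers partition (and typically shuffle) the data per rank,
    # so the loaders below keep shuffle=False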

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=False,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=False,
                            sampler=val_sampler)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    lr = 0
    patience = 0
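    # schedule: epoch 0 trains only the classifier at lr=1e-3; from epoch 1
    # the whole net is fine-tuned at lr=3e-5; whenever val loss stalls for
    # 2 epochs, the best checkpoint is reloaded and lr is divided by 10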
    for epoch in range(args.start_epoch, args.epochs):
        # adjust_learning_rate(optimizer, epoch)
        train_sampler.set_epoch(epoch)

        if epoch == 1:
            # from epoch 1 on, everything is unfrozen at a much smaller lr
            lr = 0.00003
        if patience == 2:
            # val loss has not improved for 2 epochs: reload the best
            # weights and decay the learning rate
            patience = 0
            checkpoint = load_checkpoint(args.save_path + '_best.pth.tar')
            model.load_state_dict(checkpoint['state_dict'])
            print("Loading best checkpoint ...")
            lr = lr / 10.0

        if epoch == 0:
            # warm-up: train only the classifier layers; last_layer_names
            # (defined elsewhere in this file) lists their parameter names
            lr = 0.001
            for name, param in model.named_parameters():
                if name not in last_layer_names:
                    param.requires_grad = False
            optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                   model.parameters()),
                                            lr=lr)
            # optimizer = torch.optim.Adam(
            #     filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
        else:
            # fine-tune the whole network, now with weight decay
            for param in model.parameters():
                param.requires_grad = True
            optimizer = torch.optim.RMSprop(model.parameters(),
                                            lr=lr,
                                            weight_decay=0.0001)
            # optimizer = torch.optim.Adam(
            #     model.parameters(), lr=lr, weight_decay=0.0001)
        print("lr is: {}".format(lr))
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        val_prec1, val_losses = validate(val_loader, model, criterion)
        print("val_losses is: {}".format(val_losses))
        # remember best prec@1 and save checkpoint
        if rank == 0:
            if val_losses < min_loss:
                is_best = True
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': 'inceptionv4',
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                    }, is_best, args.save_path)
                print(
                    'val loss improved from {:.5f} to {:.5f}. Saved!'.format(
                        min_loss, val_losses))
                min_loss = val_losses
                patience = 0
            else:
                patience += 1
        else:
            # non-zero ranks track min_loss/patience locally so that every
            # process follows the same lr schedule; only rank 0 saves
            if val_losses < min_loss:
                min_loss = val_losses
                patience = 0
            else:
                patience += 1
        print("patience is: {}".format(patience))
        print("min_loss is: {}".format(min_loss))
    print("min_loss is: {}".format(min_loss))

# Example 2

def main():
    global args, best_prec1, timer
    args = parser.parse_args()
    rank, world_size = dist_init(args.port)
    assert args.batch_size % world_size == 0
    assert args.workers % world_size == 0
    args.batch_size = args.batch_size // world_size
    args.workers = args.workers // world_size

    # step1: create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch.startswith('inception_v3'):
        print('inception_v3 without aux_logits!')
        image_size = 341
        input_size = 299
        model = models.__dict__[args.arch](aux_logits=False)
    elif args.arch.startswith('ir18'):
        image_size = 640
        input_size = 448
        model = IR18()
    else:
        image_size = 256
        input_size = 224
        model = models.__dict__[args.arch]()

    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        if os.path.isfile(args.pretrained):
            print("=> loading pretrained_model '{}'".format(args.pretrained))
            pretrained_model = torch.load(args.pretrained)
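            # strict=False tolerates missing/unexpected keys, e.g. a
            # replaced classifier head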
            model.load_state_dict(pretrained_model['state_dict'], strict=False)
            print("=> loaded pretrained_model '{}'".format(args.pretrained))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrained))
    model.cuda()
    model = DistModule(model)

    # step2: define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # step3: Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
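    # NOTE: normalize is defined here but left commented out of both
    # transform pipelines below in the original code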

    train_dataset = McDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # ColorAugmentation(),
            # normalize,
        ]))
    val_dataset = McDataset(
        args.val_root,
        args.val_source,
        transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            # normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=False,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=False,
                            sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return
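    # Timer (repo utility) estimates remaining time from the total
    # iteration count per epoch and the number of remaining epochs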
    timer = Timer(
        len(train_loader) + len(val_loader), args.epochs - args.start_epoch)
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.save_path)
            print('* Best Prec 1: {best:.3f}'.format(best=best_prec1))
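

# `adjust_learning_rate` is defined elsewhere in the original file. A minimal
# sketch of the usual ImageNet-style schedule it likely implements (decay the
# base lr 10x every 30 epochs -- the step size is an assumption):
def adjust_learning_rate(optimizer, epoch):
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr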