Exemplo n.º 1
0
                                weight_decay=opt.weight_decay)

        milestones = [80, 120]

    else:
        """
                For other datasets
        """
        raise NotImplementedError

    dict_best_top1 = {'Epoch': 0, 'Top1': 100.}
    dict_best_top5 = {'Epoch': 0, 'Top5': 100.}

    if opt.resume:
        state_dict = torch.load(opt.path_model)
        model.load_state_dict(state_dict['state_dict'])
        optim.load_state_dict(state_dict['optimizer'])

        dict_best_top1.update({'Epoch': opt.epoch_top1, 'Top1': opt.top1})
        dict_best_top5.update({'Epoch': opt.epoch_top5, 'Top5': opt.top5})

    st = datetime.now()
    iter_total = 0
    top1_hist = list(100 for i in range(100))
    top5_hist = list(100 for i in range(100))  # to see 100 latest top5 error
    for epoch in range(opt.epoch_recent, opt.epochs):
        adjust_lr(optim, epoch, opt.lr, milestones=milestones, gamma=0.1)
        list_loss = list()
        model.train()
        for input, label in tqdm(data_loader):
            iter_total += 1
Exemplo n.º 2
0
def main():
    """Parse command-line options, build a model, then train or evaluate it.

    Options (all optional):
        --mode         "train" runs ``n_epochs`` epochs of ``train``;
                       "test" runs ``test`` once on the test loader.
        --model        mobilenet | mobilenet_v2 | shufflenet |
                       shufflenet_v2 | squeezenet.
        --dataset      dataset name forwarded to ``get_loaders``.
        --dataroot     directory forwarded to ``get_loaders`` as ``root``.
        --batch_size   batch size forwarded to ``get_loaders``.
        --n_epochs     number of training epochs.
        --lr           Adam learning rate.
        --n_gpus       when > 1, the model is wrapped in ``nn.DataParallel``
                       over device ids 0..n_gpus-1.
        --checkpoint   path the state_dict is saved to (and loaded from when
                       --pretrained is set).
        --save_every   save a checkpoint every this many epochs.
        --pretrained   when truthy, weights are loaded from --checkpoint.

    Raises:
        NotImplementedError: if --model or --mode is not one of the
            supported values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, default="train")
    parser.add_argument("--model", type=str, default="mobilenet_v2")
    parser.add_argument("--dataset", type=str, default="cifar10")
    parser.add_argument("--dataroot", type=str, default="/tmp/data")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--n_epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--n_gpus", type=int, default=1)
    parser.add_argument("--checkpoint", type=str, default="/tmp/chkpt.pth.tar")
    parser.add_argument("--save_every", type=int, default=10)
    parser.add_argument("--pretrained", type=str, default=None)
    args = parser.parse_args()
    print(args)

    if torch.cuda.is_available():
        print("cuda is available, use cuda")
        device = torch.device("cuda")
    else:
        print("cuda is not available, use cpu")
        device = torch.device("cpu")

    print("download dataset: {}".format(args.dataset))
    # get_loaders is defined elsewhere; it supplies both loaders and the
    # number of classes used to size the model's output.
    train_loader, test_loader, n_classes = get_loaders(
        dataset=args.dataset, root=args.dataroot, batch_size=args.batch_size)

    print("build model: {}".format(args.model))
    # Imports are local so only the selected architecture is loaded.
    if args.model == "mobilenet":
        from models import MobileNet
        model = MobileNet(n_classes=n_classes)
    elif args.model == "mobilenet_v2":
        from models import MobileNet_v2
        model = MobileNet_v2(n_classes=n_classes)
    elif args.model == "shufflenet":
        from models import ShuffleNet
        model = ShuffleNet(n_classes=n_classes)
    elif args.model == "shufflenet_v2":
        from models import ShuffleNet_v2
        model = ShuffleNet_v2(n_classes=n_classes)
    elif args.model == "squeezenet":
        from models import SqueezeNet
        model = SqueezeNet(n_classes=n_classes)
    else:
        raise NotImplementedError

    model = model.to(device)
    if args.pretrained:
        # NOTE(review): the value of --pretrained is only used as a flag; the
        # weights come from --checkpoint. Confirm this is intended.
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))

    if args.n_gpus > 1:
        model = nn.DataParallel(model, device_ids=list(range(args.n_gpus)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    if args.mode == "train":
        for epoch in range(args.n_epochs):
            train(epoch, model, optimizer, criterion, train_loader, device)
            if (epoch + 1) % args.save_every == 0:
                print("saving model...")
                # Save the unwrapped network: a DataParallel state_dict has
                # "module."-prefixed keys, which the bare-model load above
                # (done before wrapping) could not consume.
                to_save = model.module if isinstance(model, nn.DataParallel) else model
                torch.save(to_save.state_dict(), args.checkpoint)
    elif args.mode == "test":
        test(model, criterion, test_loader, device)
    else:
        raise NotImplementedError
Exemplo n.º 3
0
        transforms.Normalize(mean=[0.5],std=[0.5]),
    ])
# Argv: positional command-line arguments
#   sys.argv[1] - test data file, read by MyDataset
#   sys.argv[2] - saved model state_dict
#   sys.argv[3] - CSV file the predictions are written to
test_fpath = sys.argv[1]
model_fpath = sys.argv[2]
output_fpath = sys.argv[3]
print('# [Info] Argv')
print('    - Test   : {}'.format(test_fpath))
print('    - Model  : {}'.format(model_fpath))
print('    = Output : {}'.format(output_fpath))
# Make data loader; shuffle=False keeps predictions in dataset order.
test_dataset = MyDataset(filename=test_fpath, is_train=False, transform=test_transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Load model; map_location lets a checkpoint saved on another device load here.
model = MobileNet()
model.load_state_dict(torch.load(model_fpath, map_location=device))
model.to(device)
# Model prediction. eval() puts layers such as dropout/batch-norm into
# inference mode; no_grad() stops autograd from recording a graph for every
# batch, which would only waste memory during pure inference.
model.eval()
prediction = []
with torch.no_grad():
    for data in test_loader:
        data_device = data.to(device)
        output = model(data_device)
        # Index of the highest score along dim 1 = predicted class per sample.
        labels = torch.max(output, 1)[1]
        # Iterating a tensor yields its 0-d elements, same as the per-label append.
        prediction.extend(labels)
# Output prediction
print('# [Info] Output prediction: {}'.format(output_fpath))
with open(output_fpath, 'w') as f:
    f.write('id,label\n')
    for i, v in enumerate(prediction):