Example 1
0
def main():
    """Train CANNet on crowd-counting data, tracking the best validation MAE.

    Reads train/val image lists from the JSON files given on the command
    line, trains for ``args.epochs`` epochs, and saves a checkpoint after
    every epoch (flagged as best when the validation MAE improves).
    """
    global args, best_prec1

    # Best (lowest) validation MAE seen so far; start with a high sentinel.
    best_prec1 = 1e6

    args = parser.parse_args()
    # Hyper-parameters are hard-coded here, overriding any parsed values.
    args.lr = 1e-4
    args.batch_size = 26
    args.decay = 5e-4  # L2 weight decay for Adam
    args.start_epoch = 0
    args.epochs = 1000
    args.workers = 4
    args.seed = int(time.time())
    args.print_freq = 4

    # NOTE: renamed the misleading `outfile` — these files are read, not written.
    with open(args.train_json, 'r') as infile:
        train_list = json.load(infile)
    with open(args.val_json, 'r') as infile:
        val_list = json.load(infile)

    torch.cuda.manual_seed(args.seed)

    model = CANNet().cuda()

    # Sum (not average) the squared errors. `size_average=False` is
    # deprecated; `reduction='sum'` is the documented equivalent.
    criterion = nn.MSELoss(reduction='sum').cuda()

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 weight_decay=args.decay)

    for epoch in range(args.start_epoch, args.epochs):
        train(train_list, model, criterion, optimizer, epoch)
        prec1 = validate(val_list, model, criterion)

        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(' * best MAE {mae:.3f} '.format(mae=best_prec1))
        save_checkpoint({
            'state_dict': model.state_dict(),
        }, is_best)
Example 2
0
def main():
    """Train CANNet, selecting GPU when available and tracking best val MAE.

    CPU/GPU-agnostic variant: all computation is placed on `device`, which is
    ``cuda:0`` when CUDA is available and ``cpu`` otherwise.
    """
    global args, best_prec1

    # Best (lowest) validation MAE seen so far; start with a high sentinel.
    best_prec1 = 1e6

    args = parser.parse_args()
    # Hyper-parameters are hard-coded here, overriding any parsed values.
    args.lr = 1e-4
    args.batch_size = 26
    args.decay = 5e-4  # L2 weight decay for Adam
    args.start_epoch = 0
    args.epochs = 1000
    args.workers = 4
    args.seed = int(time.time())
    args.print_freq = 4

    with open(args.train_json, 'r') as infile:
        train_list = json.load(infile)
    with open(args.val_json, 'r') as infile:
        val_list = json.load(infile)

    # Only seed the CUDA RNG when a GPU actually exists.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Device", device)

    # BUG FIX: `device` was computed but never used — the model and the loss
    # stayed on the CPU even when CUDA was available. Move both to `device`.
    model = CANNet().to(device)

    # Sum (not average) the squared errors. `size_average=False` is
    # deprecated; `reduction='sum'` is the documented equivalent.
    criterion = nn.MSELoss(reduction='sum').to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 weight_decay=args.decay)

    for epoch in range(args.start_epoch, args.epochs):
        train(train_list, model, criterion, optimizer, epoch)
        prec1 = validate(val_list, model, criterion)

        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(' * best MAE {mae:.3f} '.format(mae=best_prec1))
        save_checkpoint({
            'state_dict': model.state_dict(),
        }, is_best)
Example 3
0
        print('\t', key, '-->', args_map[key])
    # add one more empty line for better output
    print()


if __name__ == "__main__":
    # Parse and echo the command-line configuration.
    args = make_args_parser()
    print_args(args)

    # Initialize model, loss function and optimizer.
    # BUG FIX: torch's manual_seed requires an int; time.time() returns a
    # float, so truncate it (consistent with the other training scripts).
    seed = int(time.time())
    device = torch.device(args.device)
    # Seed the CUDA RNG only when a GPU is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    model = CANNet().to(device)
    # Sum (not average) the squared errors. `size_average=False` is
    # deprecated; `reduction='sum'` is the documented equivalent.
    criterion = nn.MSELoss(reduction='sum').to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=0)
    print("Model loaded")
    print(model)

    # Load train dataset with ImageNet mean/std normalization and 8x
    # density-map downsampling.
    train_root = os.path.join(args.root, 'train_data', 'images')
    train_loader = torch.utils.data.DataLoader(ShanghaiTechPartA(
        train_root,
        shuffle=True,
        transform=transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                       std=[0.229, 0.224, 0.225]),
        downsample=8),
                                               batch_size=args.batch_size)
Example 4
0
def main():
    """Train CANNet with optional best-model evaluation and checkpoint resume.

    Command-line flags (via the module-level ``parser``):
      * ``--best``    — evaluate the saved best model before training.
      * ``--pre``     — resume full training state from a checkpoint file.
      * ``--initial`` — override the starting epoch.
    """
    global args, best_prec1

    # Best (lowest) validation MAE seen so far; start with a high sentinel.
    best_prec1 = 1e6

    args = parser.parse_args()
    # Hyper-parameters are hard-coded here, overriding any parsed values.
    args.lr = 1e-4
    args.batch_size = 12  # 26
    args.decay = 5e-4  # L2 weight decay for Adam
    args.start_epoch = 0
    args.epochs = 1000
    args.workers = 4
    args.seed = int(time.time())
    args.print_freq = 4

    with open(args.train_json, 'r') as infile:
        train_list = json.load(infile)
    # BUG FIX: `print len(train_list)` was Python 2 statement syntax — a
    # SyntaxError under the Python 3 the rest of this function targets.
    print(len(train_list))

    with open(args.val_json, 'r') as infile:
        val_list = json.load(infile)
    print(len(val_list))

    torch.cuda.manual_seed(args.seed)

    model = CANNet().cuda()

    # Sum (not average) the squared errors. `size_average=False` is
    # deprecated; `reduction='sum'` is the documented equivalent.
    criterion = nn.MSELoss(reduction='sum').cuda()

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 weight_decay=args.decay)

    # Optionally evaluate the stored best model before training starts.
    if args.best:
        print("=> loading best checkpoint '{}'".format(args.best))
        checkpoint = torch.load(os.path.join(args.output,
                                             'model_best.pth.tar'))
        model.load_state_dict(checkpoint['state_dict'])
        best_prec1 = validate(val_list, model, criterion)
        print(' * best MAE {mae:.3f} '.format(mae=best_prec1))

    # Optionally resume full training state (epoch, weights, optimizer).
    if args.pre:
        if os.path.isfile(args.pre):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = torch.load(args.pre)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.pre, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))

    if args.initial:
        args.start_epoch = args.initial
        print(args.initial)
        # BUG FIX: removed `print(x)` — `x` was never defined, so this path
        # raised NameError whenever --initial was supplied.

    for epoch in range(args.start_epoch, args.epochs):
        train(train_list, model, criterion, optimizer, epoch)
        prec1 = validate(val_list, model, criterion)

        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(' * best MAE {mae:.3f} '.format(mae=best_prec1))

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            args.output,
            filename='checkpoint.pth.tar')

    print('Train process finished')