Пример #1
0
def main(args):
    """Train, or only evaluate, a pose-estimation model configured by ``args``.

    Builds the model named by ``args.arch`` (output channels = the dataset's
    ``njoints``), a ``JointsMSELoss`` criterion and the optimizer selected by
    ``args.solver``, then creates train/val loaders for ``args.dataset``.
    If ``args.resume`` names an existing checkpoint, model/optimizer state,
    ``args.start_epoch`` and ``best_acc`` are restored from it.

    With ``args.evaluate`` a single validation pass is run, its predictions
    are saved with ``save_pred`` and the function returns.  Otherwise it
    trains for epochs ``args.start_epoch .. args.epochs - 1``, appending one
    row per epoch to ``<args.checkpoint>/log.txt``, saving a checkpoint every
    epoch and finally plotting the accuracy curves to ``log.eps``.

    Side effects: assigns the module-level globals ``idx`` (joint indices
    used to compute accuracy) and ``best_acc``.

    Raises:
        ValueError: if ``args.dataset`` or ``args.solver`` is unknown.
    """
    global best_acc
    global idx

    # idx is the index of joints used to compute accuracy
    if args.dataset in ['mpii', 'lsp']:
        idx = [1, 2, 3, 4, 5, 6, 11, 12, 15, 16]
    elif args.dataset == 'coco':
        idx = list(range(1, 18))
    else:
        # An explicit raise, unlike ``assert``, survives ``python -O``.
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = datasets.__dict__[args.dataset].njoints

    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)

    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion = losses.JointsMSELoss().to(device)

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise ValueError('Unknown solver: {}'.format(args.solver))

    # optionally resume from a checkpoint
    title = args.dataset + ' ' + args.arch
    log_path = join(args.checkpoint, 'log.txt')
    logger = None
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(log_path, title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if logger is None:
        # Fresh run, or --resume pointed at a missing file.  In the latter
        # case ``logger`` used to stay unbound and training crashed at the
        # first ``logger.append``.
        logger = Logger(log_path, title=title)
        logger.set_names(
            ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # create data loader
    train_dataset = datasets.__dict__[args.dataset](is_train=True,
                                                    **vars(args))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # evaluation only
    if args.evaluate:
        print('\nEvaluation only')
        loss, acc, predictions = validate(val_loader, model, criterion,
                                          njoints, args.debug, args.flip)
        save_pred(predictions, checkpoint=args.checkpoint)
        logger.close()  # do not leak the open log file on early return
        return

    # train and eval
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, args.debug, args.flip)

        # evaluate on validation set
        valid_loss, valid_acc, predictions = validate(val_loader, model,
                                                      criterion, njoints,
                                                      args.debug, args.flip)

        # append logger file
        logger.append(
            [epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            predictions,
            is_best,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
Пример #2
0
def main(args):
    """Train, or only evaluate, a model on the MPII dataset.

    Builds the model named by ``args.arch`` with ``args.num_classes`` output
    channels, an MSE criterion and an RMSprop optimizer.  Loaders read the
    annotations/images under ``data/mpii``.  If ``args.resume`` names an
    existing checkpoint, model/optimizer state, ``args.start_epoch`` and
    ``best_acc`` are restored from it.

    With ``args.evaluate`` a single validation pass is run, its predictions
    are saved with ``save_pred`` and the function returns.  Otherwise it
    trains for epochs ``args.start_epoch .. args.epochs - 1``, appending one
    row per epoch to ``<args.checkpoint>/log.txt``, saving a checkpoint every
    epoch and finally plotting the accuracy curves to ``log.eps``.

    Side effects: assigns the module-level global ``best_acc``.
    """
    global best_acc

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    print("==> creating model '{}', stacks={}, blocks={}".format(args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks, num_blocks=args.blocks, num_classes=args.num_classes)

    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer.
    # reduction='mean' is the non-deprecated spelling of size_average=True.
    criterion = torch.nn.MSELoss(reduction='mean').cuda()

    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    title = 'mpii-' + args.arch
    log_path = join(args.checkpoint, 'log.txt')
    logger = None
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            logger = Logger(log_path, title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if logger is None:
        # Fresh run, or --resume pointed at a missing file.  In the latter
        # case ``logger`` used to stay unbound and training crashed at the
        # first ``logger.append``.
        logger = Logger(log_path, title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    cudnn.benchmark = True
    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Data loading code
    anno_path = 'data/mpii/mpii_annotations.json'
    img_dir = 'data/mpii/images'
    train_loader = torch.utils.data.DataLoader(
        datasets.Mpii(anno_path, img_dir,
                      sigma=args.sigma, label_type=args.label_type),
        batch_size=args.train_batch, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.Mpii(anno_path, img_dir,
                      sigma=args.sigma, label_type=args.label_type, train=False),
        batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        print('\nEvaluation only')
        loss, acc, predictions = validate(val_loader, model, criterion, args.num_classes, args.debug, args.flip)
        save_pred(predictions, checkpoint=args.checkpoint)
        logger.close()  # do not leak the open log file on early return
        return

    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_loss, train_acc = train(train_loader, model, criterion, optimizer, args.debug, args.flip)

        # evaluate on validation set
        valid_loss, valid_acc, predictions = validate(val_loader, model, criterion, args.num_classes,
                                                      args.debug, args.flip)

        # append logger file
        logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, predictions, is_best, checkpoint=args.checkpoint)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
Пример #3
0
def main(args):
    """Train, or only evaluate, an MPII model using ``models.loss.UniLoss``.

    The number of output channels is ``len(args.index_classes)``.  If
    ``args.resume`` names an existing checkpoint, model/optimizer state,
    ``args.start_epoch`` and ``best_acc`` are restored.  The restored learning
    rate is then overridden with ``args.lr``, and a fresh ``log.txt`` is started.

    With ``args.evaluate`` a single validation pass is run, its predictions
    are saved with ``save_pred`` and the function returns.  Otherwise it
    trains for epochs ``args.start_epoch .. args.epochs - 1``, logging and
    checkpointing every epoch.

    When ``args.schedule[0] == -1`` the schedule is adaptive.  Epochs are
    appended to ``args.schedule`` whenever validation accuracy plateaus (see
    the loop body).

    Side effects: assigns the module-level global ``best_acc``; may mutate
    ``args.start_epoch`` and ``args.schedule``.
    """
    global best_acc
    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)
    _logger = log.get_logger(__name__, args)
    _logger.info(print_args(args))

    num_classes = len(args.index_classes)

    # create model
    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=num_classes)

    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = models.loss.UniLoss(a_points=args.a_points)
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    title = 'mpii-' + args.arch
    log_path = join(args.checkpoint, 'log.txt')
    log_names = ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc']
    logger = None
    if args.resume:
        if isfile(args.resume):
            _logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Override the learning rate restored from the checkpoint.
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr
            _logger.info('=> learning rate reset to {}'.format(args.lr))
            _logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            # A new log is started even when resuming (resume=False),
            # presumably because the LR was reset above.
            logger = Logger(log_path, title=title, resume=False)
            logger.set_names(log_names)
        else:
            _logger.info("=> no checkpoint found at '{}'".format(args.resume))
    if logger is None:
        # Fresh run, or --resume pointed at a missing file.  In the latter
        # case ``logger`` used to stay unbound and training crashed at the
        # first ``logger.append``.
        logger = Logger(log_path, title=title)
        logger.set_names(log_names)

    cudnn.benchmark = True
    _logger.info('    Total params: %.2fM' %
                 (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Data loading code
    anno_path = 'data/mpii/mpii_annotations.json'
    img_dir = 'data/mpii/images'
    train_loader = torch.utils.data.DataLoader(
        datasets.Mpii(anno_path,
                      img_dir,
                      sigma=args.sigma,
                      label_type=args.label_type,
                      _idx=args.index_classes,
                      direct=True,
                      n_points=args.n_points),
        batch_size=args.train_batch,
        shuffle=True,
        collate_fn=datasets.mpii.mycollate,
        num_workers=args.workers,
        pin_memory=False)

    val_loader = torch.utils.data.DataLoader(
        datasets.Mpii(anno_path,
                      img_dir,
                      sigma=args.sigma,
                      label_type=args.label_type,
                      _idx=args.index_classes,
                      train=False,
                      direct=True),
        batch_size=args.test_batch,
        shuffle=False,
        collate_fn=datasets.mpii.mycollate,
        num_workers=args.workers,
        pin_memory=False)

    if args.evaluate:
        _logger.warning('\nEvaluation only')
        loss, acc, predictions = validate(val_loader,
                                          model,
                                          criterion,
                                          num_classes,
                                          False,
                                          args.flip,
                                          _logger,
                                          evaluate_only=True)
        save_pred(predictions, checkpoint=args.checkpoint)
        logger.close()  # do not leak the open log file on early return
        return

    # multi-thread
    inqueues = []
    outqueues = []
    valid_accs = []
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        _logger.warning('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay
        # train for one epoch
        train_loss, train_acc = train(inqueues, outqueues, train_loader, model,
                                      criterion, optimizer, args.debug,
                                      args.flip, args.clip, _logger)
        # evaluate on validation set
        with torch.no_grad():
            valid_loss, valid_acc, predictions = validate(
                val_loader, model, criterion, num_classes,
                args.debug, args.flip, _logger)
        # append logger file
        logger.append(
            [epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])
        valid_accs.append(valid_acc)
        # Adaptive schedule: once more than 8 accuracies are collected,
        # compare the mean of the last 4 (discounted by 1%) with the mean of
        # the 4 before.  If it is lower, append epoch + 1 to args.schedule
        # (presumably consumed by adjust_learning_rate) and restart the window.
        # The emptiness guard avoids an IndexError on an empty schedule.
        if args.schedule and args.schedule[0] == -1 and len(valid_accs) > 8:
            recent_mean = sum(valid_accs[-4:]) / 4
            previous_mean = sum(valid_accs[-8:-4]) / 4
            if recent_mean * 0.99 < previous_mean:
                args.schedule.append(epoch + 1)
                valid_accs = []
        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            predictions,
            is_best,
            checkpoint=args.checkpoint,
            snapshot=1)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
Пример #4
0
def main(args):
    """Train, or only evaluate, a keypoint model on one or more ``Arm`` datasets.

    The per-dataset options ``training_set_percentage``, ``meta_dir``,
    ``anno_type`` and ``ratio`` may be given once and are then broadcast to
    every entry of ``args.data_dir``.  The lists on ``args`` are extended in
    place.  An empty ``meta_dir`` defaults to ``args.data_dir``.

    With ``args.evaluate`` the model is run once over the validation set.
    Outputs go under ``args.save_result_dir``.  If ``args.compute_3d`` is set,
    the 3-D angle error is also printed.  Otherwise the model is trained for
    epochs ``args.start_epoch .. args.epochs - 1``, logging and checkpointing
    every epoch.

    Side effects: assigns the module-level global ``best_acc``; may mutate
    ``args.meta_dir``, ``args.test_batch`` and the per-dataset lists.

    Raises:
        ValueError: if a per-dataset option has neither 1 nor
            ``len(args.data_dir)`` entries.
    """
    global best_acc

    num_datasets = len(args.data_dir)  # number of datasets

    # If not specified, assume meta info is stored in the data dir.  This has
    # to happen before the broadcast below.  An empty meta_dir would otherwise
    # fail the length check there, so this fallback could never be reached.
    if not args.meta_dir:
        args.meta_dir = list(args.data_dir)

    for item in [
            args.training_set_percentage, args.meta_dir, args.anno_type,
            args.ratio
    ]:
        if len(item) == 1:
            item.extend([item[0]] * (num_datasets - 1))
        if len(item) != num_datasets:
            raise ValueError(
                'expected 1 or {} per-dataset values, got {}'.format(
                    num_datasets, len(item)))

    scales = [0.7, 0.85, 1, 1.3, 1.6]

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Create the log file if it does not exist.  Mode 'a' does not truncate,
    # unlike the previous 'w+', so an existing log survives for
    # Logger(..., resume=True).
    with open(join(args.checkpoint, 'log.txt'), 'a'):
        pass

    if args.evaluate:  # create output paths for evaluation
        if not isdir(args.save_result_dir):
            mkdir_p(args.save_result_dir)

        folders_to_create = ['preds', 'visualization']
        if args.save_heatmap:
            folders_to_create.append('heatmaps')
        for folder_name in folders_to_create:
            folder = os.path.join(args.save_result_dir, folder_name)
            if not os.path.isdir(folder):
                print('creating path: ' + folder)
                os.mkdir(folder)

    idx = range(args.num_classes)

    cams = ['FusionCameraActor3_2']

    # create model
    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=args.num_classes)

    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer.
    # reduction='mean' is the non-deprecated spelling of size_average=True.
    criterion = torch.nn.MSELoss(reduction='mean').cuda()

    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    title = 'arm-' + args.arch
    log_path = join(args.checkpoint, 'log.txt')
    logger = None
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(log_path, title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if logger is None:
        # Fresh run, or --resume pointed at a missing file.  In the latter
        # case ``logger`` used to stay unbound and training crashed at the
        # first ``logger.append``.
        logger = Logger(log_path, title=title)
        logger.set_names(
            ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    train_set_list = []
    val_set_list = []

    for i in range(num_datasets):
        train_set_list.append(
            datasets.Arm(
                args.data_dir[i],
                args.meta_dir[i],
                args.random_bg_dir,
                cams[0],
                args.anno_type[i],
                train=True,
                training_set_percentage=args.training_set_percentage[i],
                replace_bg=args.replace_bg))

        val_set_list.append(
            datasets.Arm(
                args.data_dir[i],
                args.meta_dir[i],
                args.random_bg_dir,
                cams[0],
                args.anno_type[i],
                train=False,
                training_set_percentage=args.training_set_percentage[i],
                scales=scales,
                multi_scale=args.multi_scale,
                ignore_invis_pts=args.ignore_invis_pts))

    # Data loading code
    if not args.evaluate:
        train_loader = torch.utils.data.DataLoader(
            datasets.Concat(datasets=train_set_list, ratio=args.ratio),
            batch_size=args.train_batch,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True)

        print("No. minibatches in training set:{}".format(len(train_loader)))

    if args.multi_scale:  # multi-scale testing: one batch slot per scale
        args.test_batch = args.test_batch * len(scales)

    val_loader = torch.utils.data.DataLoader(
        datasets.Concat(datasets=val_set_list, ratio=None),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    print("No. minibatches in validation set:{}".format(len(val_loader)))

    if args.evaluate:
        print('\nEvaluation only')
        loss, acc = validate(val_loader, model, criterion, args.num_classes,
                             idx, args.save_result_dir, args.meta_dir,
                             args.anno_type, args.flip, args.evaluate, scales,
                             args.multi_scale, args.save_heatmap)

        if args.compute_3d:
            # The return value is not used.  The reconstruction is read back from
            # d3_pred.json, presumably written by this call -- confirm.
            d2tod3(data_dir=args.save_result_dir,
                   meta_dir=args.meta_dir[0],
                   cam_type=args.camera_type,
                   pred_from_heatmap=False,
                   em_test=False)

            # validate the 3d reconstruction accuracy
            with open(os.path.join(args.save_result_dir, 'd3_pred.json'),
                      'r') as f:
                obj = json.load(f)
            d3_pred, file_name_list = obj['d3_pred'], obj['file_name_list']

            preds = []
            gts = []
            for file_name in file_name_list:
                preds.append(d3_pred[file_name]['preds'])  # predicted x
                with open(os.path.join(args.data_dir[0], 'angles', file_name),
                          'r') as f:
                    gts.append(json.load(f))

            print('average error in angle: [base, elbow, ankle, wrist]:{}'.
                  format(d3_acc(preds, gts)))

        logger.close()  # do not leak the open log file on early return
        return

    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, idx, args.flip)

        # evaluate on validation set
        valid_loss, valid_acc = validate(val_loader, model, criterion,
                                         args.num_classes, idx,
                                         args.save_result_dir, args.meta_dir,
                                         args.anno_type, args.flip,
                                         args.evaluate)

        # The concatenated datasets are re-randomised after each epoch.
        train_loader.dataset.reset()
        val_loader.dataset.reset()

        # append logger file
        logger.append(
            [epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.checkpoint)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
Пример #5
0
def main(args):
    """Fine-tune a 16-output pose model to predict 70 keypoints, then train it.

    The network named by ``args.arch`` is built with 16 output channels.
    ``score[0]`` of the wrapped network is then replaced by a fresh
    256 -> 70 1x1 convolution.  ``args.resume``, if it names an existing
    file, is loaded before that swap, presumably a 16-joint checkpoint.

    With ``args.evaluate`` a single validation pass is run, its predictions
    are saved with ``save_pred`` and the function returns.  Otherwise it
    trains for ``args.epochs`` epochs starting at ``args.start_epoch``.  Each
    epoch it writes the per-epoch losses to ``./loss.csv`` and logs to
    ``<args.checkpoint>/log.txt``.  ``is_best`` for checkpointing means the
    lowest validation epoch loss so far.

    Side effects: assigns the module-level globals ``idx`` and, on resume,
    ``best_acc``.

    Raises:
        ValueError: if ``args.dataset`` or ``args.solver`` is unknown.
    """
    global best_acc
    global idx

    # idx is the index of joints used to compute accuracy
    if args.dataset in ['mpii', 'lsp']:
        idx = [1, 2, 3, 4, 5, 6, 11, 12, 15, 16]
    elif args.dataset == 'coco':
        idx = list(range(1, 18))
    elif args.dataset == 'homes':
        idx = list(range(1, 71))
    else:
        # An explicit raise, unlike ``assert``, survives ``python -O``.
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = datasets.__dict__[args.dataset].njoints

    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](
        num_stacks=args.stacks,
        num_blocks=args.blocks,
        num_classes=16,  # built as 16, then modified to 70 below
        resnet_layers=args.resnet_layers)

    model = torch.nn.DataParallel(model).to(device)
    # define loss function (criterion) and optimizer
    criterion = torch.nn.MSELoss().to(device)
    print('==> loss function: %s' % str(criterion))

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise ValueError('Unknown solver: {}'.format(args.solver))

    # optionally resume from a checkpoint
    title = args.dataset + ' ' + args.arch
    log_path = join(args.checkpoint, 'log.txt')
    logger = None
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(log_path, title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if logger is None:
        # Fresh run, or --resume pointed at a missing file.  In the latter
        # case ``logger`` used to stay unbound and training crashed at the
        # first ``logger.append``.
        logger = Logger(log_path, title=title)
        logger.set_names(
            ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Update the final score layer for fine-tuning: 70 output channels.
    model_net = list(model.children())
    new_score = nn.Conv2d(256, 70, kernel_size=(1, 1),
                          stride=(1, 1)).to(device)
    model_net[-1].score[0] = new_score
    # The optimizer was built from the *old* parameters.  Register the new
    # head explicitly; otherwise optimizer.step() never updates it.
    optimizer.add_param_group({'params': new_score.parameters()})
    # NOTE(review): wrapping the bare network in nn.Sequential drops the
    # DataParallel wrapper and renames state_dict keys from 'module.*' to
    # '0.*'.  Checkpoints saved below will therefore not match the model the
    # resume path above loads into.  If ``score`` holds one head per stack,
    # only the first is replaced.  Confirm both are intended.
    model = nn.Sequential(*model_net)

    # create data loader
    train_dataset = datasets.__dict__[args.dataset](**vars(args))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = datasets.__dict__[args.dataset](is_train=False,
                                                  is_valid=True,
                                                  **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # evaluation only
    if args.evaluate:
        print('\nEvaluation only')
        # validate() yields four values (see the training loop below).
        # Unpacking only three raised "too many values to unpack".
        loss, acc, predictions, _ = validate(val_loader, model, criterion,
                                             njoints, args.debug, args.flip)
        save_pred(predictions, checkpoint=args.checkpoint)
        logger.close()  # do not leak the open log file on early return
        return

    # train and eval
    train_epo_loss, val_epo_loss = [], []
    lr = args.lr
    # Builtin float: the np.float alias was removed in NumPy 1.24.
    best_vel = float('inf')
    # Runs args.epochs further epochs, starting at start_epoch.
    for epoch in range(args.start_epoch, args.epochs + args.start_epoch):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_loss, train_acc, tel = train(train_loader, model, criterion,
                                           optimizer, args.debug, args.flip)

        # evaluate on validation set
        valid_loss, valid_acc, predictions, vel = validate(
            val_loader, model, criterion, njoints, args.debug, args.flip)
        # Save the epoch-loss history to csv.  This is rewritten every epoch,
        # relative to the current working directory.
        train_epo_loss.append(tel)
        val_epo_loss.append(vel)
        pd.DataFrame({'train': train_epo_loss,
                      'val': val_epo_loss}).to_csv('./loss.csv')

        # append logger file
        logger.append([
            epoch + 1, lr, train_loss, valid_loss,
            train_acc.item(),
            valid_acc.item()
        ])

        # "Best" is the lowest validation epoch loss.  best_acc is only
        # carried through, unchanged, from the resumed checkpoint or global.
        is_best = vel < best_vel
        best_vel = min(best_vel, vel)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            predictions,
            is_best,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()