Code Example #1
File: evaluate_seq_ce.py  Project: RobinWu218/mcr2
# NOTE: this excerpt begins mid-script. The imports and the argument
# definitions other than --data_dir are reconstructed from how the values
# are used below, so their types and defaults are assumptions.
import argparse

import numpy as np

import train_func as tf
import utils

parser = argparse.ArgumentParser()
parser.add_argument('--data_dir',
                    default='./data/',
                    help='path to dataset')
# Arguments referenced later in this excerpt (definitions assumed):
parser.add_argument('--model_dir', help='directory of the trained model')
parser.add_argument('--epoch', type=int, default=None,
                    help='checkpoint epoch to load')
parser.add_argument('--cpb', type=int, help='classes per label batch')
parser.add_argument('--label_batch', type=int,
                    help='label batch to evaluate')
args = parser.parse_args()

print("evaluate using label_batch: {}".format(args.label_batch))

params = utils.load_params(args.model_dir)
# get train features and labels
train_transforms = tf.load_transforms('test')
trainset = tf.load_trainset(params['data'],
                            train_transforms,
                            train=True,
                            path=args.data_dir)
if 'lcr' in params.keys():  # supervised label-corruption case
    trainset = tf.corrupt_labels(trainset, params['lcr'], params['lcs'])
new_labels = trainset.targets
assert (trainset.num_classes %
        args.cpb == 0), "Number of classes not divisible by cpb"
## load model
net, epoch = tf.load_checkpoint_ce(args.model_dir,
                                   trainset.num_classes,
                                   args.epoch,
                                   eval_=True,
                                   label_batch_id=args.label_batch)
net = net.cuda().eval()

# split the sorted class labels into contiguous batches of size cpb
classes = np.unique(trainset.targets)
class_batch_num = trainset.num_classes // args.cpb
class_batch_list = classes.reshape(class_batch_num, args.cpb)
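
# The excerpt stops once the label batches are formed. A minimal sketch of
# how they might be consumed follows (per-batch sample counts only); the
# masking logic is an illustration, not part of the original file.
for batch_id, class_batch in enumerate(class_batch_list):
    mask = np.isin(new_labels, class_batch)  # samples in this label batch
    print("label batch {}: {} classes, {} samples".format(
        batch_id, len(class_batch), int(mask.sum())))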
Code Example #2
# NOTE: this excerpt also begins mid-file. The imports and the head of
# lr_schedule are reconstructed; the 400-epoch milestone is an assumption
# consistent with the surviving 200-epoch branch. `args` and `model_dir`
# are defined by argparse setup earlier in the file (not shown here).
from torch.optim import SGD
from torch.utils.data import DataLoader

import train_func as tf
import utils
from loss import MaximalCodingRateReduction  # module path assumed


def lr_schedule(epoch, optimizer):
    """Decay the learning rate at fixed epoch milestones."""
    lr = args.lr
    if epoch >= 400:  # milestone assumed; the excerpt starts inside this branch
        lr = args.lr * 0.01
    elif epoch >= 200:
        lr = args.lr * 0.1
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


## Prepare for Training
if args.pretrain_dir is not None:
    net, _ = tf.load_checkpoint(args.pretrain_dir, args.pretrain_epo)
    utils.update_params(model_dir, args.pretrain_dir)
else:
    net = tf.load_architectures(args.arch, args.fd)
transforms = tf.load_transforms(args.transform)
trainset = tf.load_trainset(args.data, transforms, path=args.data_dir)
trainset = tf.corrupt_labels(trainset, args.lcr, args.lcs)
trainloader = DataLoader(trainset,
                         batch_size=args.bs,
                         drop_last=True,
                         num_workers=4)
criterion = MaximalCodingRateReduction(gam1=args.gam1,
                                       gam2=args.gam2,
                                       eps=args.eps)
optimizer = SGD(net.parameters(),
                lr=args.lr,
                momentum=args.mom,
                weight_decay=args.wd)

## Training
for epoch in range(args.epo):
    lr_schedule(epoch, optimizer)
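    # The excerpt ends after the learning-rate update. The rest of the loop
    # body is sketched below as an assumption of how MCR^2 supervised
    # training is typically wired; the criterion call signature and its
    # three return values are not confirmed by the excerpt.
    for step, (batch_imgs, batch_lbls) in enumerate(trainloader):
        features = net(batch_imgs.cuda())
        loss, loss_empi, loss_theo = criterion(features, batch_lbls,
                                               num_classes=trainset.num_classes)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()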