Example #1
def par_crop(args, ann_base_path):
    """
    Dataset curation, crop data and transform the format of label
    Parameters
    ----------
    ann_base_path: str, Annotations base path
    """
    crop_path = os.path.join(args.download_dir,
                             './crop{:d}'.format(int(args.instance_size)))
    if not os.path.isdir(crop_path):
        makedirs(crop_path)
    sub_sets = sorted({'a', 'b', 'c', 'd', 'e'})
    for sub_set in sub_sets:
        sub_set_base_path = os.path.join(ann_base_path, sub_set)
        videos = sorted(os.listdir(sub_set_base_path))
        n_videos = len(videos)
        with futures.ProcessPoolExecutor(
                max_workers=args.num_threads) as executor:
            fs = [
                executor.submit(crop_video, args, sub_set, video, crop_path,
                                ann_base_path) for video in videos
            ]
            for i, f in enumerate(futures.as_completed(fs)):
                # Write progress to error so that it can be seen
                printProgress(i,
                              n_videos,
                              prefix=sub_set,
                              suffix='Done ',
                              barLength=40)
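
A hypothetical driver for par_crop could look like the sketch below. The Namespace fields are inferred from how the function reads args, and crop_video, printProgress, and makedirs are helpers from the surrounding script rather than standard-library functions; the annotation path is an assumed layout.

import argparse

# Sketch only: field names inferred from par_crop's body.
args = argparse.Namespace(
    download_dir='/data/vid',  # assumed dataset root
    instance_size=511,         # SiamFC-style search-region size
    num_threads=8,             # processes for the ProcessPoolExecutor
)
ann_base_path = '/data/vid/ILSVRC/Annotations/VID/train/'  # assumed layout
par_crop(args, ann_base_path)
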
Example #2
def par_crop(args):
    """
    Dataset curation,crop data and transform the format of a label
    """
    crop_path = os.path.join(args.download_dir,
                             './crop{:d}'.format(args.instance_size))
    if not os.path.isdir(crop_path):
        makedirs(crop_path)
    VID_base_path = os.path.join(args.download_dir, './ILSVRC')
    ann_base_path = os.path.join(VID_base_path, 'Annotations/DET/train/')
    sub_sets = ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i')
    for sub_set in sub_sets:
        sub_set_base_path = os.path.join(ann_base_path, sub_set)
        if 'a' == sub_set:
            xmls = sorted(
                glob.glob(os.path.join(sub_set_base_path, '*', '*.xml')))
        else:
            xmls = sorted(glob.glob(os.path.join(sub_set_base_path, '*.xml')))
        n_imgs = len(xmls)
        sub_set_crop_path = os.path.join(crop_path, sub_set)
        with futures.ProcessPoolExecutor(
                max_workers=args.num_threads) as executor:
            fs = [
                executor.submit(crop_xml, args, xml, sub_set_crop_path,
                                args.instance_size) for xml in xmls
            ]
            for i, f in enumerate(futures.as_completed(fs)):
                printProgress(i,
                              n_imgs,
                              prefix=sub_set,
                              suffix='Done ',
                              barLength=80)
Example #3
def build_rec_process(img_dir, train=False, num_thread=1):
    rec_dir = os.path.abspath(os.path.join(img_dir, '../rec'))
    makedirs(rec_dir)
    prefix = 'train' if train else 'val'
    print('Building ImageRecord file for ' + prefix + ' ...')

    # download lst file and im2rec script
    script_path = os.path.join(rec_dir, 'im2rec.py')
    script_url = 'https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/im2rec.py'
    download(script_url, script_path)

    lst_path = os.path.join(rec_dir, prefix + '.lst')
    lst_url = 'http://data.mxnet.io/models/imagenet/resnet/' + prefix + '.lst'
    download(lst_url, lst_path)

    # execution
    import sys
    cmd = [
        sys.executable,
        script_path,
        rec_dir,
        img_dir,
        '--recursive',
        '--pass-through',
        '--pack-label',
        '--num-thread',
        str(num_thread)
    ]
    subprocess.call(cmd)
    os.remove(script_path)
    os.remove(lst_path)
    print('ImageRecord file for ' + prefix + ' has been built!')
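
As a usage note, a call like the sketch below should leave val.rec and val.idx under the rec directory next to the image folder; im2rec.py builds them from the downloaded .lst file, which the function deletes afterwards. The paths here are hypothetical.

# Sketch: build the validation ImageRecord for an existing image folder.
img_dir = '/data/imagenet/val'  # hypothetical image directory
build_rec_process(img_dir, train=False, num_thread=16)
# Expected outputs: /data/imagenet/rec/val.rec and /data/imagenet/rec/val.idx
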
Example #4
def download_aug(path, overwrite=False):
    _AUG_DOWNLOAD_URLS = [(
        "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
        "7129e0a480c2d6afb02b517bb18ac54283bfaa35",
    )]
    makedirs(path)
    for url, checksum in _AUG_DOWNLOAD_URLS:
        filename = download(url,
                            path=path,
                            overwrite=overwrite,
                            sha1_hash=checksum)
        # extract
        with tarfile.open(filename) as tar:
            tar.extractall(path=path)
            shutil.move(os.path.join(path, "benchmark_RELEASE"),
                        os.path.join(path, "VOCaug"))
            filenames = ["VOCaug/dataset/train.txt", "VOCaug/dataset/val.txt"]
            # generate trainval.txt
            with open(os.path.join(path, "VOCaug/dataset/trainval.txt"),
                      "w") as outfile:
                for fname in filenames:
                    fname = os.path.join(path, fname)
                    with open(fname) as infile:
                        for line in infile:
                            outfile.write(line)
Example #5
def crop_xml(args, xml, sub_set_crop_path, instance_size=511):
    """
    Dataset curation

    Parameters
    ----------
    xml: str , xml
    sub_set_crop_path: str, xml crop path
    instance_size: int, instance_size
    """
    cv2 = try_import_cv2()
    xmltree = ET.parse(xml)
    objects = xmltree.findall('object')

    frame_crop_base_path = os.path.join(sub_set_crop_path, xml.split('/')[-1].split('.')[0])
    if not os.path.isdir(frame_crop_base_path):
        makedirs(frame_crop_base_path)
    img_path = xml.replace('xml', 'JPEG').replace('Annotations', 'Data')
    im = cv2.imread(img_path)
    avg_chans = np.mean(im, axis=(0, 1))

    for object_id, object_iter in enumerate(objects):
        bndbox = object_iter.find('bndbox')
        bbox = [int(bndbox.find('xmin').text), int(bndbox.find('ymin').text),
                int(bndbox.find('xmax').text), int(bndbox.find('ymax').text)]
        z, x = crop_like_SiamFC(im, bbox, instance_size=instance_size, padding=avg_chans)
        cv2.imwrite(os.path.join(args.download_dir, frame_crop_base_path,
                                 '{:06d}.{:02d}.z.jpg'.format(0, object_id)), z)
        cv2.imwrite(os.path.join(args.download_dir, frame_crop_base_path,
                                 '{:06d}.{:02d}.x.jpg'.format(0, object_id)), x)
Example #6
def crop_coco(args):
    """
    Dataset curation,crop data and transform the format of label

    Parameters
    ----------
    ann_base_path: str, Annotations base path
    """
    crop_path = os.path.join(args.download_dir, './crop{:d}'.format(int(args.instance_size)))
    if not os.path.isdir(crop_path):
        makedirs(crop_path)

    for dataType in ['val2017', 'train2017']:
        set_crop_base_path = os.path.join(crop_path, dataType)
        set_img_base_path = os.path.join(args.download_dir, dataType)

        annFile = '{}/annotations/instances_{}.json'.format(args.download_dir, dataType)
        coco = COCO(annFile)
        n_imgs = len(coco.imgs)
        with futures.ProcessPoolExecutor(max_workers=args.num_threads) as executor:
            fs = [executor.submit(crop_img, coco.loadImgs(img_id)[0],
                                  coco.loadAnns(coco.getAnnIds(imgIds=img_id, iscrowd=None)),
                                  set_crop_base_path, set_img_base_path, args.instance_size)
                  for img_id in coco.imgs]
            for i, f in enumerate(futures.as_completed(fs)):
                # Write progress to error so that it can be seen
                printProgress(i, n_imgs, prefix=dataType, suffix='Done ', barLength=40)
    print('done')
Example #7
def build_rec_process(img_dir, train=False, num_thread=1):
    rec_dir = os.path.abspath(os.path.join(img_dir, '../rec'))
    makedirs(rec_dir)
    prefix = 'train' if train else 'val'
    print('Building ImageRecord file for ' + prefix + ' ...')

    # download lst file and im2rec script
    script_path = os.path.join(rec_dir, 'im2rec.py')
    script_url = 'https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/im2rec.py'
    download(script_url, script_path)

    lst_path = os.path.join(rec_dir, prefix + '.lst')
    lst_url = 'http://data.mxnet.io/models/imagenet/resnet/' + prefix + '.lst'
    download(lst_url, lst_path)

    # execution
    import sys
    cmd = [
        sys.executable,
        script_path,
        rec_dir,
        img_dir,
        '--recursive',
        '--pass-through',
        '--pack-label',
        '--num-thread',
        str(num_thread),
        '--resize',
        '512'
    ]
    subprocess.call(cmd)
    os.remove(script_path)
    os.remove(lst_path)
    print('ImageRecord file for ' + prefix + ' has been built!')
Example #8
def crop_video(args, sub_set, video, crop_path, ann_base_path):
    """
    Dataset curation

    Parameters
    ----------
    sub_set: str , sub_set
    video: str, video number
    crop_path: str, crop_path
    ann_base_path: str, Annotations base path
    """
    cv2 = try_import_cv2()
    video_crop_base_path = os.path.join(crop_path, sub_set, video)
    if not os.path.isdir(video_crop_base_path):
        makedirs(video_crop_base_path)
    sub_set_base_path = os.path.join(ann_base_path, sub_set)
    xmls = sorted(glob.glob(os.path.join(sub_set_base_path, video, '*.xml')))
    for xml in xmls:
        xmltree = ET.parse(xml)
        objects = xmltree.findall('object')
        filename = xmltree.findall('filename')[0].text
        im = cv2.imread(xml.replace('xml', 'JPEG').replace('Annotations', 'Data'))
        avg_chans = np.mean(im, axis=(0, 1))
        for object_iter in objects:
            trackid = int(object_iter.find('trackid').text)
            bndbox = object_iter.find('bndbox')
            bbox = [int(bndbox.find('xmin').text), int(bndbox.find('ymin').text),
                    int(bndbox.find('xmax').text), int(bndbox.find('ymax').text)]
            z, x = crop_like_SiamFC(im, bbox, instance_size=args.instance_size, padding=avg_chans)
            cv2.imwrite(os.path.join(args.download_dir, video_crop_base_path,
                                     '{:06d}.{:02d}.z.jpg'.format(int(filename), trackid)), z)
            cv2.imwrite(os.path.join(args.download_dir, video_crop_base_path,
                                     '{:06d}.{:02d}.x.jpg'.format(int(filename), trackid)), x)
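
crop_like_SiamFC itself is not shown in these examples. The sketch below is one common formulation of the SiamFC-style crop (a context margin of half the summed width and height, a 127-pixel exemplar), offered as an assumption about the helper's contract rather than its actual source:

import cv2  # the surrounding snippets obtain this via try_import_cv2()
import numpy as np

def crop_like_siamfc_sketch(im, bbox, exemplar_size=127, context_amount=0.5,
                            instance_size=511, padding=(0, 0, 0)):
    """Hypothetical stand-in for crop_like_SiamFC: returns (z, x) crops.

    padding is the per-channel border fill value, e.g. tuple(avg_chans).
    """
    x1, y1, x2, y2 = bbox
    cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
    w, h = x2 - x1, y2 - y1
    # Pad the target by a context margin before taking the square crop.
    wc = w + context_amount * (w + h)
    hc = h + context_amount * (w + h)
    s_z = np.sqrt(wc * hc)                      # exemplar region side
    scale_z = exemplar_size / s_z
    d_search = (instance_size - exemplar_size) / 2.0
    s_x = s_z + 2 * d_search / scale_z          # search region side

    def crop_square(side, out_size):
        # Affine map from the square region centred on (cx, cy) to an
        # out_size x out_size patch; borders are filled with `padding`.
        a = out_size / side
        m = np.array([[a, 0, a * (side / 2.0 - cx)],
                      [0, a, a * (side / 2.0 - cy)]], dtype=np.float32)
        return cv2.warpAffine(im, m, (out_size, out_size),
                              borderMode=cv2.BORDER_CONSTANT,
                              borderValue=padding)

    return crop_square(s_z, exemplar_size), crop_square(s_x, instance_size)
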
Example #9
def download_ade(path, overwrite=False):
    _AUG_DOWNLOAD_URLS = [
        ('http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip', '219e1696abb36c8ba3a3afe7fb2f4b4606a897c7'),
        ('http://data.csail.mit.edu/places/ADEchallenge/release_test.zip', 'e05747892219d10e9243933371a497e905a4860c'),]
    download_dir = os.path.join(path, 'downloads')
    makedirs(download_dir)
    for url, checksum in _AUG_DOWNLOAD_URLS:
        filename = download(url, path=download_dir, overwrite=overwrite, sha1_hash=checksum)
        # extract
        with zipfile.ZipFile(filename,"r") as zip_ref:
            zip_ref.extractall(path=path)
Example #11
def download_VID(args, overwrite=False):
    """download VID dataset and Unzip to download_dir"""
    _DOWNLOAD_URLS = [
        ('http://bvisionweb1.cs.unc.edu/ilsvrc2015/ILSVRC2015_VID.tar.gz',
         '077dbdea4dff1853edd81b04fa98e19392287ca3'),
    ]
    if not os.path.isdir(args.download_dir):
        makedirs(args.download_dir)
    for url, checksum in _DOWNLOAD_URLS:
        filename = download(url, path=args.download_dir, overwrite=overwrite, sha1_hash=checksum)
        print('Extracting dataset...')
        with tarfile.open(filename) as tar:
            tar.extractall(path=args.download_dir)
Example #13
def download_voc(path, overwrite=False):
    _DOWNLOAD_URLS = [
        ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
         '34ed68851bce2a36e2a223fa52c661d592c66b3c'),
        ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
         '41a8d6e12baa5ab18ee7f8f8029b9e11805b4ef1'),
        ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
         '4e443f8a2eca6b1dac8a6c57641b67dd40621a49')]
    makedirs(path)
    for url, checksum in _DOWNLOAD_URLS:
        filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
        # extract
        with tarfile.open(filename) as tar:
            tar.extractall(path=path)
Example #14
def download_det(args, overwrite=False):
    """download DET dataset and Unzip to download_dir"""
    _DOWNLOAD_URLS = [
        ('http://image-net.org/image/ILSVRC2015/ILSVRC2015_DET.tar.gz',
         'cbf602d89f2877fa8843392a1ffde03450a18d38'),
    ]
    if not os.path.isdir(args.download_dir):
        makedirs(args.download_dir)
    for url, checksum in _DOWNLOAD_URLS:
        filename = download(url, path=args.download_dir, overwrite=overwrite, sha1_hash=checksum)
        print('Dataset downloaded, extracting...')
        with tarfile.open(filename) as tar:
            tar.extractall(path=args.download_dir)
    if os.path.isdir(os.path.join(args.download_dir, 'ILSVRC2015')):
        os.rename(os.path.join(args.download_dir, 'ILSVRC2015'), os.path.join(args.download_dir, 'ILSVRC'))
Example #15
def download_coco(args, overwrite=False):
    """download COCO dataset and Unzip to download_dir"""
    _DOWNLOAD_URLS = [
        ('http://images.cocodataset.org/zips/train2017.zip',
         '10ad623668ab00c62c096f0ed636d6aff41faca5'),
        ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
         '8551ee4bb5860311e79dace7e79cb91e432e78b3'),
        ('http://images.cocodataset.org/zips/val2017.zip',
         '4950dc9d00dbe1c933ee0170f5797584351d2a41'),
    ]
    if not os.path.isdir(args.download_dir):
        makedirs(args.download_dir)
    for url, checksum in _DOWNLOAD_URLS:
        filename = download(url, path=args.download_dir, overwrite=overwrite, sha1_hash=checksum)
        with zipfile.ZipFile(filename) as zf:
            zf.extractall(path=args.download_dir)
Example #16
def download_coco(path, overwrite=False):
    _DOWNLOAD_URLS = [
        ('http://images.cocodataset.org/zips/val2017.zip',
         '4950dc9d00dbe1c933ee0170f5797584351d2a41'),
        ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
         '8551ee4bb5860311e79dace7e79cb91e432e78b3')
    ]
    makedirs(path)
    for url, checksum in _DOWNLOAD_URLS:
        filename = download(url,
                            path=path,
                            overwrite=overwrite,
                            sha1_hash=checksum)
        # extract
        with zipfile.ZipFile(filename) as zf:
            zf.extractall(path=path)
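
All of these download_* functions lean on a download(url, path=..., overwrite=..., sha1_hash=...) helper that is not reproduced here. A minimal standard-library sketch of the same contract, for orientation only (the real helper in these scripts is richer), might be:

import hashlib
import os
import urllib.request

def download_sketch(url, path='.', overwrite=False, sha1_hash=None):
    """Hypothetical minimal stand-in for the `download` helper."""
    fname = os.path.join(path, url.split('/')[-1])
    if overwrite or not os.path.exists(fname):
        os.makedirs(path, exist_ok=True)
        urllib.request.urlretrieve(url, fname)
    if sha1_hash is not None:
        sha1 = hashlib.sha1()
        with open(fname, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                sha1.update(chunk)
        if sha1.hexdigest() != sha1_hash:
            raise UserWarning('Checksum mismatch for {}'.format(fname))
    return fname
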
Example #17
def download_city(path, overwrite=False):
    _CITY_DOWNLOAD_URLS = [
        ('gtFine_trainvaltest.zip', '99f532cb1af174f5fcc4c5bc8feea8c66246ddbc'),
        ('leftImg8bit_trainvaltest.zip', '2c0b77ce9933cc635adda307fbba5566f5d9d404')]
    download_dir = os.path.join(path, 'downloads')
    makedirs(download_dir)
    # Cityscapes requires registration; the two zip files must be fetched
    # manually and placed in the working directory before running this.
    for filename, checksum in _CITY_DOWNLOAD_URLS:
        if not check_sha1(filename, checksum):
            raise UserWarning('File {} is downloaded but the content hash does not match. ' \
                              'The repo may be outdated or download may be incomplete. ' \
                              'If the "repo_url" is overridden, consider switching to ' \
                              'the default repo.'.format(filename))
        # extract
        with zipfile.ZipFile(filename,"r") as zip_ref:
            zip_ref.extractall(path=path)
        print("Extracted", filename)
Example #20
def download_mhp_v1(path, overwrite=False):
    try_import_html5lib()
    gdf = try_import_gdfDownloader()
    downloader = gdf.googleDriveFileDownloader()

    file_link = 'https://drive.google.com/uc?id=1hTS8QJBuGdcppFAr_bvW2tsD9hW_ptr5&export=download'
    download_dir = os.path.join(path, 'downloads')
    makedirs(download_dir)
    filename = os.path.join(download_dir, 'LV-MHP-v1.zip')

    # download the MHP_v1 zip file with Google-Drive-File-Downloader
    downloader.downloadFile(file_link)

    # move zip file to download_dir
    shutil.move('./LV-MHP-v1.zip', filename)

    # extract
    with zipfile.ZipFile(filename, "r") as zip_ref:
        zip_ref.extractall(path=path)
Example #21
def download_coco(path, overwrite=False):
    _DOWNLOAD_URLS = [
        ('http://images.cocodataset.org/zips/train2017.zip',
         '10ad623668ab00c62c096f0ed636d6aff41faca5'),
        ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
         '8551ee4bb5860311e79dace7e79cb91e432e78b3'),
        ('http://images.cocodataset.org/zips/val2017.zip',
         '4950dc9d00dbe1c933ee0170f5797584351d2a41'),
        # ('http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip',
        #  '46cdcf715b6b4f67e980b529534e79c2edffe084'),
        # test2017.zip, for those who want to attend the competition.
        # ('http://images.cocodataset.org/zips/test2017.zip',
        #  '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'),
    ]
    makedirs(path)
    for url, checksum in _DOWNLOAD_URLS:
        filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
        # extract
        with zipfile.ZipFile(filename) as zf:
            zf.extractall(path=path)
Example #23
def download_otb(args, overwrite=False):
    """Download the OTB2015 dataset and unzip it to download_dir"""
    _DOWNLOAD_URL = 'http://cvlab.hanyang.ac.kr/tracker_benchmark/seq/'
    if not os.path.isdir(args.download_dir):
        makedirs(args.download_dir)
    for per_otb50 in otb50:
        url = _DOWNLOAD_URL + per_otb50 + '.zip'
        filename = download(url, path=args.download_dir, overwrite=overwrite)
        with zipfile.ZipFile(filename) as zf:
            zf.extractall(path=args.download_dir)
    for per_otb100 in otb100:
        url = _DOWNLOAD_URL + per_otb100 + '.zip'
        filename = download(url, path=args.download_dir, overwrite=overwrite)
        with zipfile.ZipFile(filename) as zf:
            zf.extractall(path=args.download_dir)

    # Jogging and Skating2 each ship two annotated targets in one folder,
    # so the folder is duplicated before renaming; Human4 keeps only target 2.
    shutil.copytree(os.path.join(args.download_dir, 'Jogging'),
                    os.path.join(args.download_dir, 'Jogging-2'))
    os.rename(os.path.join(args.download_dir, 'Jogging'),
              os.path.join(args.download_dir, 'Jogging-1'))
    shutil.copytree(os.path.join(args.download_dir, 'Skating2'),
                    os.path.join(args.download_dir, 'Skating2-2'))
    os.rename(os.path.join(args.download_dir, 'Skating2'),
              os.path.join(args.download_dir, 'Skating2-1'))
    os.rename(os.path.join(args.download_dir, 'Human4'),
              os.path.join(args.download_dir, 'Human4-2'))
Example #24
def download_wider(path, overwrite=False):
    _WIDER_DOWNLOAD_URLS = [
        ('WIDER_train.zip', 'ea80d8614a81ffaf8b3830a2a6807676ca666846'),
        ('WIDER_val.zip', '3643b3045a491b402b46a22e5ccfe1fdcf3d6c68'),
        ('wider_face_split.zip', 'd4949bbb444f2852e84373b0390f6ba6241be931'),
        ('eval_tools.zip', 'bcb6abdc19dac0f853f75b5d03396d5120aef3dc'),
        # ('WIDER_test.zip', 'f7fa64455c1262150b0dc75985b03a94bf655d92'),
        # ('Submission_example.zip', 'eb124c3a3e90ea03cbc60c28b189ba632dc95444'),
    ]
    download_dir = os.path.join(path, 'downloads')
    makedirs(download_dir)
    # The WIDER FACE archives are hosted on Google Drive and must be placed
    # in download_dir manually before running this.
    for filename, checksum in _WIDER_DOWNLOAD_URLS:
        filename = os.path.join(download_dir, filename)
        if not check_sha1(filename, checksum):
            raise UserWarning('File {} is downloaded but the content hash does not match. ' \
                              'The repo may be outdated or download may be incomplete. ' \
                              'If the "repo_url" is overridden, consider switching to ' \
                              'the default repo.'.format(filename))
        # extract
        with zipfile.ZipFile(filename, "r") as zip_ref:
            zip_ref.extractall(path=path)
        print("Extracted", filename)
Example #25
def main():
    args = parse_args()
    name = "Market-1501-v15.09.15"
    url = "http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/{name}.zip".format(
        name=name)
    root = osp.expanduser(args.download_dir)
    makedirs(root)
    fpath = osp.join(root, name + '.zip')
    exdir = osp.join(root, name)
    if not osp.exists(fpath) and not osp.isdir(exdir) and args.no_download:
        raise ValueError(
            ('{} dataset archive not found, make sure it is present,'
             ' or drop "--no-download" so the script can fetch it'.format(
                 fpath)))
    # Download by default
    if not args.no_download:
        print('Downloading dataset')
        download(url, fpath, overwrite=False)
        print('Dataset downloaded')
    # Extract dataset if fresh copy downloaded or existing archive is yet to be extracted
    if not args.no_download or not osp.isdir(exdir):
        extract(fpath, root)
        make_list(exdir)
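
extract and make_list are defined elsewhere in the same script. extract presumably just unpacks the archive into root, along the lines of the sketch below; make_list, which generates the image lists, is omitted because its format is not visible here.

import zipfile

def extract_sketch(fpath, root):
    """Hypothetical stand-in for the `extract` helper used in main()."""
    with zipfile.ZipFile(fpath) as zf:
        zf.extractall(path=root)
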
Example #26
def crop_img(img, anns, set_crop_base_path, set_img_base_path, instance_size=511):
    """
    Dataset curation

    Parameters
    ----------
    img: dic, img
    anns: str, video number
    set_crop_base_path: str, crop result path
    set_img_base_path: str, ori image path
    """
    frame_crop_base_path = os.path.join(set_crop_base_path, img['file_name'].split('/')[-1].split('.')[0])
    if not os.path.isdir(frame_crop_base_path):
        makedirs(frame_crop_base_path)
    cv2 = try_import_cv2()
    im = cv2.imread('{}/{}'.format(set_img_base_path, img['file_name']))
    avg_chans = np.mean(im, axis=(0, 1))
    for trackid, ann in enumerate(anns):
        rect = ann['bbox']
        bbox = [rect[0], rect[1], rect[0] + rect[2], rect[1] + rect[3]]
        if rect[2] <= 0 or rect[3] <= 0:
            continue
        z, x = crop_like_SiamFC(im, bbox, instance_size=instance_size, padding=avg_chans)
        # `args` is not a parameter here; it is presumably set at module level
        # by the __main__ block of the surrounding script.
        cv2.imwrite(os.path.join(args.download_dir, frame_crop_base_path,
                                 '{:06d}.{:02d}.z.jpg'.format(0, trackid)), z)
        cv2.imwrite(os.path.join(args.download_dir, frame_crop_base_path,
                                 '{:06d}.{:02d}.x.jpg'.format(0, trackid)), x)
Example #27
def download_vg(path, overwrite=False):
    _DOWNLOAD_URLS = [
        ('https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip',
         'a055367f675dd5476220e9b93e4ca9957b024b94'),
        ('https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip',
         '2add3aab77623549e92b7f15cda0308f50b64ecf'),
    ]
    makedirs(path)
    for url, checksum in _DOWNLOAD_URLS:
        filename = download(url,
                            path=path,
                            overwrite=overwrite,
                            sha1_hash=checksum)
        # extract
        if filename.endswith('zip'):
            with zipfile.ZipFile(filename) as zf:
                zf.extractall(path=path)
    # move all images into folder `VG_100K`
    vg_100k_path = os.path.join(path, 'VG_100K')
    vg_100k_2_path = os.path.join(path, 'VG_100K_2')
    files_2 = os.listdir(vg_100k_2_path)
    for fl in files_2:
        shutil.move(os.path.join(vg_100k_2_path, fl),
                    os.path.join(vg_100k_path, fl))
Example #28
def download_coco(path, overwrite=False):
    _DOWNLOAD_URLS = [
        ('http://images.cocodataset.org/zips/train2017.zip',
         '10ad623668ab00c62c096f0ed636d6aff41faca5'),
        ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
         '8551ee4bb5860311e79dace7e79cb91e432e78b3'),
        ('http://images.cocodataset.org/zips/val2017.zip',
         '4950dc9d00dbe1c933ee0170f5797584351d2a41'),
        # ('http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip',
        #  '46cdcf715b6b4f67e980b529534e79c2edffe084'),
        # test2017.zip, for those who want to attend the competition.
        # ('http://images.cocodataset.org/zips/test2017.zip',
        #  '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'),
    ]
    makedirs(path)
    for url, checksum in _DOWNLOAD_URLS:
        filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
        # extract
        with zipfile.ZipFile(filename) as zf:
            zf.extractall(path=path)

if __name__ == '__main__':
    args = parse_args()
    path = os.path.expanduser(args.download_dir)
    if not os.path.isdir(path) or not os.path.isdir(os.path.join(path, 'train2017')) \
        or not os.path.isdir(os.path.join(path, 'val2017')) \
        or not os.path.isdir(os.path.join(path, 'annotations')):
        if args.no_download:
            raise ValueError(('{} is not a valid directory, make sure it is present,'
                              ' or drop "--no-download" so the script can fetch it'.format(path)))
        else:
            download_coco(path, overwrite=args.overwrite)

    # make symlink
    makedirs(os.path.expanduser('~/.mxnet/datasets'))
    if os.path.isdir(_TARGET_DIR):
        os.remove(_TARGET_DIR)
    os.symlink(path, _TARGET_DIR)
    try_import_pycocotools()
Example #29
        self.use_train_patterns = False
        self.freeze_patterns = ''
        self.freeze_lr_mult = 0.01  # frozen base-layer lr = self.lr * self.freeze_lr_mult
        self.use_mult = False
        self.clip_grad = 40
        self.log_interval = 10
        self.lr_mode = 'step'
        self.resume_epoch = 0
        self.resume_params = ''  # e.g. 'logs/param_rgb_resnet18_v1b_k400_ucf101/0.8620-ucf101-resnet18_v1b_k400_ucf101-082-best.params'
        self.resume_states = ''  # e.g. 'logs/param_rgb_resnet18_v1b_k400_ucf101/0.8620-ucf101-resnet18_v1b_k400_ucf101-082-best.states'
        self.reshape_type = 'tsn'  # options: mxc3d, c3d, tsn, tsn_newlength


opt = config()

makedirs(opt.save_dir)

filehandler = logging.FileHandler(os.path.join(opt.save_dir, opt.logging_file))
streamhandler = logging.StreamHandler()
logger = logging.getLogger('')
logger.setLevel(logging.INFO)
logger.addHandler(filehandler)
logger.addHandler(streamhandler)
logger.info(opt)

# number of GPUs to use
num_gpus = opt.num_gpus
ctx = [mx.gpu(i) for i in range(num_gpus)]
#ctx = [mx.gpu(1)]

# Get the model 
Example #30
def parse_args():
    """Training Options for Segmentation Experiments"""
    parser = argparse.ArgumentParser(description='MXNet Gluon Segmentation')
    parser.add_argument('--host', type=str, default='xxx',
                        help='xxx is a place holder')
    parser.add_argument('--model', type=str, default='ResFPN',
                        help='model name: ResNetFPN, ResUNet')
    parser.add_argument('--fuse-mode', type=str, default='AsymBi',
                        help='DirectAdd, Concat, SK, BiLocal, BiGlobal, AsymBi, '
                             'TopDownGlobal, TopDownLocal')
    parser.add_argument('--tiny', action='store_true', default=False,
                        help='use the tiny version of the model')
    parser.add_argument('--blocks', type=int, default=3,
                        help='block num in each stage')
    parser.add_argument('--channel-times', type=int, default=1,
                        help='times of channel width')
    parser.add_argument('--dataset', type=str, default='DENTIST',
                        help='dataset name: DENTIST, Iceberg, StopSign')
    parser.add_argument('--workers', type=int, default=48,
                        metavar='N', help='dataloader threads')
    parser.add_argument('--base-size', type=int, default=512,
                        help='base image size')
    parser.add_argument('--iou-thresh', type=float, default=0.5,
                        help='iou-thresh')
    parser.add_argument('--crop-size', type=int, default=480,
                        help='crop image size')
    parser.add_argument('--train-split', type=str, default='trainval',
                        help='dataset train split (default: trainval)')
    parser.add_argument('--val-split', type=str, default='test',
                        help='dataset val split (default: test)')
    # training hyper params
    parser.add_argument('--epochs', type=int, default=300, metavar='N',
                        help='number of epochs to train (default: 300)')
    parser.add_argument('--start_epoch', type=int, default=0,
                        metavar='N', help='start epochs (default: 0)')
    parser.add_argument('--batch-size', type=int, default=8,
                        metavar='N', help='input batch size for \
                        training (default: 8)')
    parser.add_argument('--test-batch-size', type=int, default=8,
                        metavar='N', help='input batch size for \
                        testing (default: 8)')
    parser.add_argument('--optimizer', type=str, default='adagrad',
                        help='sgd, adam, adagrad')
    parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
                        help='learning rate (default: 0.05)')
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='decay rate of learning rate. default is 0.1.')
    parser.add_argument('--gamma', type=int, default=2,
                        help='gamma for Focal Soft IoU Loss')
    parser.add_argument('--momentum', type=float, default=0.9,
                        metavar='M', help='momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=1e-4,
                        metavar='M', help='w-decay (default: 1e-4)')
    parser.add_argument('--no-wd', action='store_true',
                        help='whether to remove weight decay on bias, \
                        and beta/gamma for batchnorm layers.')
    parser.add_argument('--score-thresh', type=float, default=0.5,
                        help='score-thresh')
    parser.add_argument('--warmup-epochs', type=int, default=0,
                        help='number of warmup epochs.')
    # cuda and logging
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--gpus', type=str, default='0',
                        help='GPUs to train on, e.g. "1,3"')
    parser.add_argument('--kvstore', type=str, default='device',
                        help='kvstore to use for trainer/module.')
    parser.add_argument('--dtype', type=str, default='float32',
                        help='data type for training. default is float32')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='weight decay rate. default is 0.0001.')
    parser.add_argument('--log-interval', type=int, default=50,
                        help='Number of batches to wait before logging.')
    # checking point
    parser.add_argument('--resume', type=str, default=None,
                        help='put the path to resuming file if needed')
    parser.add_argument('--colab', action='store_true', default=False,
                        help='whether the script runs on Google Colab')
    parser.add_argument('--save-dir', type=str, default=None,
                        help='directory of saved models')
    # evaluation only
    parser.add_argument('--eval', action='store_true', default=False,
                        help='evaluation only')
    parser.add_argument('--no-val', action='store_true', default=False,
                        help='skip validation during training')
    parser.add_argument('--metric', type=str, default='mAP',
                        help='F1, IoU, mAP')
    parser.add_argument('--logging-file', type=str, default='train.log',
                        help='name of training log file')
    parser.add_argument('--summary', action='store_true',
                        help='print parameters')
    # synchronized Batch Normalization
    parser.add_argument('--syncbn', action='store_true', default=False,
                        help='using Synchronized Cross-GPU BatchNorm')
    # the parser
    args = parser.parse_args()
    # handle contexts
    if args.no_cuda or (len(mx.test_utils.list_gpus()) == 0):
        print('Using CPU')
        args.kvstore = 'local'
        args.ctx = [mx.cpu(0)]
    else:
        args.ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
        print('Number of GPUs:', len(args.ctx))

    # logging and checkpoint saving
    if args.save_dir is None:
        args.save_dir = "runs/%s/%s/" % (args.dataset, args.model)
    makedirs(args.save_dir)

    # Synchronized BatchNorm
    args.norm_layer = mx.gluon.contrib.nn.SyncBatchNorm if args.syncbn \
        else mx.gluon.nn.BatchNorm
    args.norm_kwargs = {'num_devices': len(args.ctx)} if args.syncbn else {}
    print(args)
    return args
Example #31
def parse_args():
    """Training Options for Semantic Segmentation Experiments"""
    parser = argparse.ArgumentParser(
        description='MXNet Gluon Semantic Segmentation')
    # model and dataset
    parser.add_argument('--model',
                        type=str,
                        default='fcn',
                        help='model name (default: fcn)')
    parser.add_argument('--model-zoo',
                        type=str,
                        default=None,
                        help='evaluating on model zoo model')
    parser.add_argument('--pretrained',
                        action="store_true",
                        help='whether to use pretrained params')
    parser.add_argument('--backbone',
                        type=str,
                        default='resnet50',
                        help='backbone name (default: resnet50)')
    parser.add_argument('--dataset',
                        type=str,
                        default='pascal',
                        help='dataset name (default: pascal)')
    parser.add_argument('--workers',
                        type=int,
                        default=16,
                        metavar='N',
                        help='dataloader threads')
    parser.add_argument('--base-size',
                        type=int,
                        default=520,
                        help='base image size')
    parser.add_argument('--crop-size',
                        type=int,
                        default=480,
                        help='crop image size')
    parser.add_argument('--train-split',
                        type=str,
                        default='train',
                        help='dataset train split (default: train)')
    # training hyper params
    parser.add_argument('--aux',
                        action='store_true',
                        default=False,
                        help='Auxiliary loss')
    parser.add_argument('--aux-weight',
                        type=float,
                        default=0.5,
                        help='auxiliary loss weight')
    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--start_epoch',
                        type=int,
                        default=0,
                        metavar='N',
                        help='start epochs (default:0)')
    parser.add_argument('--batch-size',
                        type=int,
                        default=16,
                        metavar='N',
                        help='input batch size for \
                        training (default: 16)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=16,
                        metavar='N',
                        help='input batch size for \
                        testing (default: 16)')
    parser.add_argument('--optimizer',
                        type=str,
                        default='sgd',
                        help='optimizer (default: sgd)')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3,
                        metavar='LR',
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--warmup-epochs',
                        type=int,
                        default=0,
                        help='number of warmup epochs.')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='momentum (default: 0.9)')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=1e-4,
                        metavar='M',
                        help='w-decay (default: 1e-4)')
    parser.add_argument('--no-wd',
                        action='store_true',
                        help='whether to remove weight decay on bias, \
                        and beta/gamma for batchnorm layers.')
    parser.add_argument('--mode',
                        type=str,
                        default=None,
                        help='whether to turn on model hybridization')
    # cuda and distribute
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--ngpus',
                        type=int,
                        default=len(mx.test_utils.list_gpus()),
                        help='number of GPUs (default: all available)')
    parser.add_argument('--kvstore',
                        type=str,
                        default='device',
                        help='kvstore to use for trainer/module.')
    parser.add_argument('--dtype',
                        type=str,
                        default='float32',
                        help='data type for training. default is float32')
    # checking point
    parser.add_argument('--resume',
                        type=str,
                        default=None,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkname',
                        type=str,
                        default='default',
                        help='set the checkpoint name')
    parser.add_argument('--save-dir',
                        type=str,
                        default=None,
                        help='directory of saved models')
    parser.add_argument('--log-interval',
                        type=int,
                        default=20,
                        help='Number of batches to wait before logging.')
    parser.add_argument('--logging-file',
                        type=str,
                        default='train.log',
                        help='name of training log file')
    # evaluation only
    parser.add_argument('--eval',
                        action='store_true',
                        default=False,
                        help='evaluation only')
    parser.add_argument('--no-val',
                        action='store_true',
                        default=False,
                        help='skip validation during training')
    # synchronized Batch Normalization
    parser.add_argument('--syncbn',
                        action='store_true',
                        default=False,
                        help='using Synchronized Cross-GPU BatchNorm')
    # the parser
    args = parser.parse_args()

    # handle contexts
    if args.no_cuda:
        print('Using CPU')
        args.kvstore = 'local'
        args.ctx = [mx.cpu(0)]
    else:
        print('Number of GPUs:', args.ngpus)
        assert args.ngpus > 0, 'No GPUs found, please enable --no-cuda for CPU mode.'
        args.ctx = [mx.gpu(i) for i in range(args.ngpus)]

    if 'psp' in args.model or 'deeplab' in args.model:
        assert args.crop_size % 8 == 0, (
            'For PSPNet and DeepLabV3 model families, '
            'we only support input crop size as multiples of 8.')

    # logging and checkpoint saving
    if args.save_dir is None:
        args.save_dir = "runs/%s/%s/%s/" % (args.dataset, args.model,
                                            args.backbone)
    makedirs(args.save_dir)

    # Synchronized BatchNorm
    args.norm_layer = mx.gluon.contrib.nn.SyncBatchNorm if args.syncbn \
        else mx.gluon.nn.BatchNorm
    args.norm_kwargs = {'num_devices': args.ngpus} if args.syncbn else {}
    return args
Example #32
model_name = opt.model
if model_name.startswith('cifar_wideresnet'):
    kwargs = {'classes': classes,
              'drop_rate': opt.drop_rate}
else:
    kwargs = {'classes': classes}
net = get_model(model_name, **kwargs)
model_name += '_mixup'
if opt.resume_from:
    net.load_parameters(opt.resume_from, ctx=context)
optimizer = 'nag'

save_period = opt.save_period
if opt.save_dir and save_period:
    save_dir = opt.save_dir
    makedirs(save_dir)
else:
    save_dir = ''
    save_period = 0

plot_name = opt.save_plot_dir

logging_handlers = [logging.StreamHandler()]
if opt.logging_dir:
    logging_dir = opt.logging_dir
    makedirs(logging_dir)
    logging_handlers.append(logging.FileHandler('%s/train_cifar10_%s.log'%(logging_dir, model_name)))

logging.basicConfig(level=logging.INFO, handlers=logging_handlers)
logging.info(opt)
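
The '_mixup' suffix on model_name suggests the surrounding script trains with mixup. A minimal sketch of mixup on one batch, under the assumption that labels arrive as integer class indices, is:

import numpy as np
import mxnet as mx

def mixup_batch(data, label, num_classes, alpha=1.0):
    """Sketch: convex-combine a batch with a shuffled copy of itself."""
    lam = float(np.random.beta(alpha, alpha))
    index = mx.nd.array(np.random.permutation(data.shape[0]), dtype='int32')
    mixed_data = lam * data + (1 - lam) * data.take(index)
    one_hot = mx.nd.one_hot(label, num_classes)
    mixed_label = lam * one_hot + (1 - lam) * one_hot.take(index)
    return mixed_data, mixed_label
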
Example #33
train_images_file = os.path.join(path, 'labels/train' + str(split) + '.txt')
with open(train_images_file, 'r') as f:
    train_images = f.readlines()

val_images_file = os.path.join(path, 'labels/validate' + str(split) + '.txt')
with open(val_images_file, 'r') as f:
    val_images = f.readlines()

test_images_file = os.path.join(path, 'labels/test' + str(split) + '.txt')
with open(test_images_file, 'r') as f:
    test_images = f.readlines()

# Create directories
src_path = os.path.join(path, 'images')
train_path = os.path.join(path, 'train')
val_path = os.path.join(path, 'val')
test_path = os.path.join(path, 'test')
makedirs(train_path)
makedirs(val_path)
makedirs(test_path)

labels = sorted(os.listdir(src_path))

for l in labels:
    makedirs(os.path.join(train_path, l))
    makedirs(os.path.join(val_path, l))
    makedirs(os.path.join(test_path, l))

# Copy files to corresponding directory
for im in train_images:
    im_path = im.replace('images/', '').strip('\n')
    shutil.copy(os.path.join(src_path, im_path),
                os.path.join(train_path, im_path))
Example #34
def main():
    opt = parse_args()

    bps.init()

    gpu_name = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv'])
    gpu_name = gpu_name.decode('utf8').split('\n')[-2]
    gpu_name = '-'.join(gpu_name.split())
    filename = "cifar100-%d-%s-%s.log" % (bps.size(), gpu_name,
                                          opt.logging_file)
    filehandler = logging.FileHandler(filename)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 100

    num_gpus = opt.num_gpus
    # batch_size *= max(1, num_gpus)
    context = mx.gpu(bps.local_rank()) if num_gpus > 0 else mx.cpu(
        bps.local_rank())
    num_workers = opt.num_workers
    nworker = bps.size()
    rank = bps.rank()

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]

    num_batches = 50000 // (opt.batch_size * nworker)
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=opt.warmup_lr,
                    target_lr=opt.lr * nworker / bps.local_size(),
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler('step',
                    base_lr=opt.lr * nworker / bps.local_size(),
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes, 'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)

    if opt.compressor:
        optimizer = 'sgd'
    else:
        optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    # from https://github.com/weiaicunzai/pytorch-cifar/blob/master/conf/global_settings.py
    CIFAR100_TRAIN_MEAN = [
        0.5070751592371323, 0.48654887331495095, 0.4409178433670343
    ]
    CIFAR100_TRAIN_STD = [
        0.2673342858792401, 0.2564384629170883, 0.27615047132568404
    ]

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=True).shard(nworker, rank).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR100(
            train=False).shard(nworker, rank).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        params = net.collect_params()

        compression_params = {
            "compressor": opt.compressor,
            "ef": opt.ef,
            "momentum": opt.compress_momentum,
            "scaling": opt.onebit_scaling,
            "k": opt.k,
            "fp16": opt.fp16_pushpull
        }

        optimizer_params = {
            'lr_scheduler': lr_scheduler,
            'wd': opt.wd,
            'momentum': opt.momentum
        }

        trainer = bps.DistributedTrainer(params,
                                         optimizer,
                                         optimizer_params,
                                         compression_params=compression_params)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        iteration = 0
        best_val_score = 0
        bps.byteps_declare_tensor("acc")
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, train_acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, train_acc = train_metric.get()
            throughput = int(batch_size * nworker * i / (time.time() - tic))

            logger.info(
                '[Epoch %d] speed: %d samples/sec\ttime cost: %f lr=%f' %
                (epoch, throughput, time.time() - tic, trainer.learning_rate))

            name, val_acc = test(ctx, val_data)
            acc = mx.nd.array([train_acc, val_acc], ctx=ctx[0])
            bps.byteps_push_pull(acc, name="acc", is_average=False)
            acc /= bps.size()
            train_acc, val_acc = acc[0].asscalar(), acc[1].asscalar()
            if bps.rank() == 0:
                logger.info('[Epoch %d] training: %s=%f' %
                            (epoch, name, train_acc))
                logger.info('[Epoch %d] validation: %s=%f' %
                            (epoch, name, val_acc))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-cifar-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar100-%s-%d.params' %
                                    (save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar100-%s-%d.params' %
                                (save_dir, model_name, epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
Example #35
# %% define parameters
epoch = opt.epoch
epoch_start = 0
batch_size = opt.batch
lr = opt.lr
nz = opt.nz
imageSize = opt.imageSize
should_save_checkpoint = opt.save_checkpoint
save_per_epoch = opt.save_per_epoch
save_dir = opt.save_dir
pred_per_gen = opt.pred_per_gen
should_use_val = opt.validation
dataset = opt.dataset
dataset_loader = getattr(gan_datasets, 'load_{}'.format(dataset))
fix_noise_dir = 'saved/fixednoise'
makedirs(fix_noise_dir)

CTX = mx.gpu() if opt.cuda else mx.cpu()
logger.info('Will use {}'.format(CTX))

# %% define dataloader
logger.info("Prepare data")
# noinspection PyTypeChecker
tfs_train = gluon.data.vision.transforms.Compose([
    gluon.data.vision.transforms.Resize(size=(imageSize, imageSize),
                                        interpolation=2),
    # gluon.data.vision.transforms.RandomFlipLeftRight(),
    gluon.data.vision.transforms.RandomSaturation(0.001),
    gluon.data.vision.transforms.ToTensor(),
    gluon.data.vision.transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                           std=(0.5, 0.5, 0.5))
Example #36
def main():
    opt = parse_args()

    makedirs(opt.save_dir)

    filehandler = logging.FileHandler(
        os.path.join(opt.save_dir, opt.logging_file))
    streamhandler = logging.StreamHandler()
    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)
    logger.info(opt)

    sw = SummaryWriter(logdir=opt.save_dir, flush_secs=5, verbose=False)

    if opt.kvstore is not None:
        kv = mx.kvstore.create(opt.kvstore)
        logger.info(
            'Distributed training with %d workers and current rank is %d' %
            (kv.num_workers, kv.rank))
    if opt.use_amp:
        amp.init()

    batch_size = opt.batch_size
    classes = opt.num_classes

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    logger.info('Total batch size is set to %d on %d GPUs' %
                (batch_size, num_gpus))
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]

    optimizer = 'sgd'
    if opt.clip_grad > 0:
        optimizer_params = {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum,
            'clip_gradient': opt.clip_grad
        }
    else:
        optimizer_params = {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum
        }

    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    model_name = opt.model
    net = get_model(name=model_name,
                    nclass=classes,
                    pretrained=opt.use_pretrained,
                    use_tsn=opt.use_tsn,
                    num_segments=opt.num_segments,
                    partial_bn=opt.partial_bn)
    net.cast(opt.dtype)
    net.collect_params().reset_ctx(context)
    logger.info(net)

    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    if opt.kvstore is not None:
        train_data, val_data, batch_fn = get_data_loader(
            opt, batch_size, num_workers, logger, kv)
    else:
        train_data, val_data, batch_fn = get_data_loader(
            opt, batch_size, num_workers, logger)

    num_batches = len(train_data)
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])
    optimizer_params['lr_scheduler'] = lr_scheduler

    train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    def test(ctx, val_data, kvstore=None):
        acc_top1.reset()
        acc_top5.reset()
        L = gluon.loss.SoftmaxCrossEntropyLoss()
        num_test_iter = len(val_data)
        val_loss_epoch = 0
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = []
            for _, X in enumerate(data):
                X = X.reshape((-1, ) + X.shape[2:])
                pred = net(X.astype(opt.dtype, copy=False))
                outputs.append(pred)

            loss = [
                L(yhat, y.astype(opt.dtype, copy=False))
                for yhat, y in zip(outputs, label)
            ]

            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

            val_loss_epoch += sum([l.mean().asscalar()
                                   for l in loss]) / len(loss)

            if opt.log_interval and not (i + 1) % opt.log_interval:
                logger.info('Batch [%04d]/[%04d]: evaluated' %
                            (i, num_test_iter))

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        val_loss = val_loss_epoch / num_test_iter

        if kvstore is not None:
            top1_nd = nd.zeros(1)
            top5_nd = nd.zeros(1)
            val_loss_nd = nd.zeros(1)
            kvstore.push(111111, nd.array(np.array([top1])))
            kvstore.pull(111111, out=top1_nd)
            kvstore.push(555555, nd.array(np.array([top5])))
            kvstore.pull(555555, out=top5_nd)
            kvstore.push(999999, nd.array(np.array([val_loss])))
            kvstore.pull(999999, out=val_loss_nd)
            top1 = top1_nd.asnumpy() / kvstore.num_workers
            top5 = top5_nd.asnumpy() / kvstore.num_workers
            val_loss = val_loss_nd.asnumpy() / kvstore.num_workers

        return (top1, top5, val_loss)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        if opt.partial_bn:
            train_patterns = None
            if 'inceptionv3' in opt.model:
                train_patterns = '.*weight|.*bias|inception30_batchnorm0_gamma|inception30_batchnorm0_beta|inception30_batchnorm0_running_mean|inception30_batchnorm0_running_var'
            else:
                logger.info(
                    'Current model does not support partial batch normalization.'
                )

            if opt.kvstore is not None:
                trainer = gluon.Trainer(net.collect_params(train_patterns),
                                        optimizer,
                                        optimizer_params,
                                        kvstore=kv,
                                        update_on_kvstore=False)
            else:
                trainer = gluon.Trainer(net.collect_params(train_patterns),
                                        optimizer,
                                        optimizer_params,
                                        update_on_kvstore=False)
        else:
            if opt.kvstore is not None:
                trainer = gluon.Trainer(net.collect_params(),
                                        optimizer,
                                        optimizer_params,
                                        kvstore=kv,
                                        update_on_kvstore=False)
            else:
                trainer = gluon.Trainer(net.collect_params(),
                                        optimizer,
                                        optimizer_params,
                                        update_on_kvstore=False)

        if opt.accumulate > 1:
            params = [
                p for p in net.collect_params().values()
                if p.grad_req != 'null'
            ]
            for p in params:
                p.grad_req = 'add'

        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.use_amp:
            amp.init_trainer(trainer)

        L = gluon.loss.SoftmaxCrossEntropyLoss()

        best_val_score = 0
        lr_decay_count = 0

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            train_metric.reset()
            btic = time.time()
            num_train_iter = len(train_data)
            train_loss_epoch = 0
            train_loss_iter = 0

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                with ag.record():
                    outputs = []
                    for _, X in enumerate(data):
                        X = X.reshape((-1, ) + X.shape[2:])
                        pred = net(X.astype(opt.dtype, copy=False))
                        outputs.append(pred)
                    loss = [
                        L(yhat, y.astype(opt.dtype, copy=False))
                        for yhat, y in zip(outputs, label)
                    ]

                    if opt.use_amp:
                        with amp.scale_loss(loss, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                    else:
                        ag.backward(loss)

                if opt.accumulate > 1:
                    # With gradient accumulation, only step on accumulation
                    # boundaries and clear the summed gradients afterwards.
                    if (i + 1) % opt.accumulate == 0:
                        if opt.kvstore is not None:
                            trainer.step(batch_size * kv.num_workers *
                                         opt.accumulate)
                        else:
                            trainer.step(batch_size * opt.accumulate)
                        net.collect_params().zero_grad()
                else:
                    if opt.kvstore is not None:
                        trainer.step(batch_size * kv.num_workers)
                    else:
                        trainer.step(batch_size)

                train_metric.update(label, outputs)
                train_loss_iter = sum([l.mean().asscalar()
                                       for l in loss]) / len(loss)
                train_loss_epoch += train_loss_iter

                train_metric_name, train_metric_score = train_metric.get()
                sw.add_scalar(tag='train_acc_top1_iter',
                              value=train_metric_score * 100,
                              global_step=epoch * num_train_iter + i)
                sw.add_scalar(tag='train_loss_iter',
                              value=train_loss_iter,
                              global_step=epoch * num_train_iter + i)
                sw.add_scalar(tag='learning_rate_iter',
                              value=trainer.learning_rate,
                              global_step=epoch * num_train_iter + i)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    logger.info(
                        'Epoch[%03d] Batch [%04d]/[%04d]\tSpeed: %f samples/sec\t %s=%f\t loss=%f\t lr=%f'
                        % (epoch, i, num_train_iter,
                           batch_size * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score * 100, train_loss_epoch /
                           (i + 1), trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i / (time.time() - tic))
            mx.ndarray.waitall()

            if opt.kvstore is not None and epoch == opt.resume_epoch:
                kv.init(111111, nd.zeros(1))
                kv.init(555555, nd.zeros(1))
                kv.init(999999, nd.zeros(1))

            if opt.kvstore is not None:
                acc_top1_val, acc_top5_val, loss_val = test(ctx, val_data, kv)
            else:
                acc_top1_val, acc_top5_val, loss_val = test(ctx, val_data)

            logger.info('[Epoch %03d] training: %s=%f\t loss=%f' %
                        (epoch, train_metric_name, train_metric_score * 100,
                         train_loss_epoch / num_train_iter))
            logger.info('[Epoch %03d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            logger.info(
                '[Epoch %03d] validation: acc-top1=%f acc-top5=%f loss=%f' %
                (epoch, acc_top1_val * 100, acc_top5_val * 100, loss_val))

            sw.add_scalar(tag='train_loss_epoch',
                          value=train_loss_epoch / num_train_iter,
                          global_step=epoch)
            sw.add_scalar(tag='val_loss_epoch',
                          value=loss_val,
                          global_step=epoch)
            sw.add_scalar(tag='val_acc_top1_epoch',
                          value=acc_top1_val * 100,
                          global_step=epoch)

            if acc_top1_val > best_val_score:
                best_val_score = acc_top1_val
                net.save_parameters('%s/%.4f-%s-%s-%03d-best.params' %
                                    (opt.save_dir, best_val_score, opt.dataset,
                                     model_name, epoch))
                trainer.save_states('%s/%.4f-%s-%s-%03d-best.states' %
                                    (opt.save_dir, best_val_score, opt.dataset,
                                     model_name, epoch))
            else:
                if opt.save_frequency and opt.save_dir and (
                        epoch + 1) % opt.save_frequency == 0:
                    net.save_parameters(
                        '%s/%s-%s-%03d.params' %
                        (opt.save_dir, opt.dataset, model_name, epoch))
                    trainer.save_states(
                        '%s/%s-%s-%03d.states' %
                        (opt.save_dir, opt.dataset, model_name, epoch))

        # save the last model
        net.save_parameters(
            '%s/%s-%s-%03d.params' %
            (opt.save_dir, opt.dataset, model_name, opt.num_epochs - 1))
        trainer.save_states(
            '%s/%s-%s-%03d.states' %
            (opt.save_dir, opt.dataset, model_name, opt.num_epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)

    train(context)
    sw.close()
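# A minimal sketch of the push/pull pattern used by test() in the example
# above, assuming only MXNet. With a distributed kvstore, values pushed under
# one key by all workers are summed on the server, which is why the code
# divides the pulled result by num_workers to recover the mean; the integer
# keys (111111, 555555, 999999) are just arbitrary ids. Shown here on a
# single-process 'local' store:
import mxnet as mx
from mxnet import nd

demo_kv = mx.kvstore.create('local')
demo_kv.init(111111, nd.zeros(1))
demo_kv.push(111111, nd.array([0.8]))   # e.g. this worker's top-1 accuracy
demo_out = nd.zeros(1)
demo_kv.pull(111111, out=demo_out)
print(demo_out.asscalar() / demo_kv.num_workers)  # 0.8 with a single worker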
Example #37
model_name = opt.model
if model_name.startswith('cifar_wideresnet'):
    kwargs = {'classes': classes,
              'drop_rate': opt.drop_rate}
else:
    kwargs = {'classes': classes}
net = get_model(model_name, **kwargs)
if opt.resume_from:
    net.load_parameters(opt.resume_from, ctx = context)
optimizer = 'nag'

save_period = opt.save_period
if opt.save_dir and save_period:
    save_dir = opt.save_dir
    makedirs(save_dir)
else:
    save_dir = ''
    save_period = 0

plot_path = opt.save_plot_dir

logging.basicConfig(level=logging.INFO)
logging.info(opt)

transform_train = transforms.Compose([
    gcv_transforms.RandomCrop(32, pad=4),
    transforms.RandomFlipLeftRight(),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])
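# A minimal sketch, assuming the surrounding script's imports (mxnet, gluoncv)
# are in place: the mean/std above are the commonly quoted per-channel CIFAR-10
# statistics, and applying the pipeline to one image yields a normalized
# 3x32x32 float tensor (the dataset is downloaded on first use).
from mxnet import gluon

demo_ds = gluon.data.vision.CIFAR10(train=True).transform_first(transform_train)
demo_img, demo_label = demo_ds[0]
print(demo_img.shape, demo_img.dtype, demo_label)  # (3, 32, 32) float32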
Example #38
    parser = argparse.ArgumentParser(
        description='Initialize Cycle Gan dataset.',
        epilog='Example: python download_dataset.py --download-dir ./',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--download-dir', type=str, default='./', help='dataset directory on disk')
    parser.add_argument('--overwrite', action='store_true', help='overwrite downloaded files if set, in case they are corrupted')
    parser.add_argument('--file', type=str, default='horse2zebra',
                        choices=['apple2orange', 'summer2winter_yosemite',
                                 'horse2zebra', 'monet2photo', 'cezanne2photo',
                                 'ukiyoe2photo', 'vangogh2photo', 'maps',
                                 'cityscapes', 'facades', 'iphone2dslr_flower',
                                 'ae_photos'],
                        help='Available datasets are: apple2orange, '
                             'summer2winter_yosemite, horse2zebra, monet2photo, '
                             'cezanne2photo, ukiyoe2photo, vangogh2photo, maps, '
                             'cityscapes, facades, iphone2dslr_flower, ae_photos')
    args = parser.parse_args()
    return args

#####################################################################################
# Download and extract CycleGAN datasets into ``path``

def download_data(path, file, overwrite=False):
    _DOWNLOAD_URL = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/'
    filename = download(_DOWNLOAD_URL + file, path=path, overwrite=overwrite)
    # extract; avoid shadowing the built-in zip
    with zipfile.ZipFile(filename, 'r') as zf:
        zf.extractall(path=path)



if __name__ == '__main__':
    args = parse_args()
    args.file = args.file + '.zip'
    path = os.path.expanduser(args.download_dir)
    if not os.path.isdir(path):
        makedirs(path)
    download_data(path, args.file, overwrite=args.overwrite)
Example #39
def main():
    opt = parse_args()

    hvd.init()

    logging_file = 'train_imagenet_%s.log' % (opt.trainer)

    filehandler = logging.FileHandler(logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    if hvd.rank() == 0:
        logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    context = [mx.gpu(hvd.local_rank())]
    num_workers = opt.num_workers

    optimizer = opt.optimizer

    warmup_epochs = opt.warmup_epochs

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // (batch_size * hvd.size())

    lr_scheduler = LRSequential([
        LRScheduler('linear', base_lr=opt.warmup_lr, target_lr=opt.lr,
                    nepochs=warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
                    nepochs=opt.num_epochs - warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay, power=2)
    ])


    model_name = opt.model

    kwargs = {'ctx': context, 'pretrained': opt.use_pretrained, 'classes': classes}
    if opt.use_gn:
        kwargs['norm_layer'] = gcv.nn.GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True

    optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    net = get_model(model_name, **kwargs)
    net.cast(opt.dtype)
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx = context)

    # teacher model for distillation training
    if opt.teacher is not None and opt.hard_weight < 1.0:
        teacher_name = opt.teacher
        teacher = get_model(teacher_name, pretrained=True, classes=classes, ctx=context)
        teacher.cast(opt.dtype)
        distillation = True
    else:
        distillation = False

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_train,
            path_imgidx         = rec_train_idx,
            preprocess_threads  = num_workers,
            shuffle             = True,
            batch_size          = batch_size,
            round_batch         = False,
            data_shape          = (3, input_size, input_size),
            mean_r              = mean_rgb[0],
            mean_g              = mean_rgb[1],
            mean_b              = mean_rgb[2],
            std_r               = std_rgb[0],
            std_g               = std_rgb[1],
            std_b               = std_rgb[2],
            rand_mirror         = True,
            random_resized_crop = True,
            max_aspect_ratio    = 4. / 3.,
            min_aspect_ratio    = 3. / 4.,
            max_random_area     = 1,
            min_random_area     = 0.08,
            brightness          = jitter_param,
            saturation          = jitter_param,
            contrast            = jitter_param,
            pca_noise           = lighting_param,
            num_parts           = hvd.size(),
            part_index          = hvd.rank(),
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_val,
            path_imgidx         = rec_val_idx,
            preprocess_threads  = num_workers,
            shuffle             = False,
            batch_size          = batch_size,

            resize              = resize,
            data_shape          = (3, input_size, input_size),
            mean_r              = mean_rgb[0],
            mean_g              = mean_rgb[1],
            mean_b              = mean_rgb[2],
            std_r               = std_rgb[0],
            std_g               = std_rgb[1],
            std_b               = std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                        saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(),
            normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize
        ])

        train_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
        val_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train, opt.rec_train_idx,
                                                    opt.rec_val, opt.rec_val_idx,
                                                    batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size, num_workers)

    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            y2 = l[::-1].one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            res.append(lam*y1 + (1-lam)*y2)
        return res

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            smoothed.append(res)
        return smoothed

    def test(ctx, val_data, val=True):
        if opt.use_rec:
            if val:
                val_data.reset()
            else:
                train_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (1-top1, 1-top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        hvd.broadcast_parameters(net.collect_params(), root_rank=0)

        # trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)

        # trainer = hvd.DistributedTrainer(
        #     net.collect_params(),  
        #     optimizer,
        #     optimizer_params)

        if opt.trainer == 'sgd':
            trainer = SGDTrainer(
                net.collect_params(),  
                optimizer, optimizer_params)
        elif opt.trainer == 'efsgd':
            trainer = EFSGDTrainerV1(
                net.collect_params(),  
                'EFSGDV1', optimizer_params, 
                input_sparse_ratio=1./opt.input_sparse_1, 
                output_sparse_ratio=1./opt.output_sparse_1, 
                layer_sparse_ratio=1./opt.layer_sparse_1)
        elif opt.trainer == 'qsparselocalsgd':
            trainer = QSparseLocalSGDTrainerV1(
                net.collect_params(),  
                optimizer, optimizer_params, 
                input_sparse_ratio=1./opt.input_sparse_1, 
                output_sparse_ratio=1./opt.output_sparse_1, 
                layer_sparse_ratio=1./opt.layer_sparse_1,
                local_sgd_interval=opt.local_sgd_interval)
        elif opt.trainer == 'ersgd':
            trainer = ERSGDTrainerV2(
                net.collect_params(),  
                optimizer, optimizer_params, 
                input_sparse_ratio=1./opt.input_sparse_1, 
                output_sparse_ratio=1./opt.output_sparse_1, 
                layer_sparse_ratio=1./opt.layer_sparse_1)
        elif opt.trainer == 'partiallocalsgd':
            trainer = PartialLocalSGDTrainerV1(
                net.collect_params(),  
                optimizer, optimizer_params, 
                input_sparse_ratio=1./opt.input_sparse_1, 
                output_sparse_ratio=1./opt.output_sparse_1, 
                layer_sparse_ratio=1./opt.layer_sparse_1,
                local_sgd_interval=opt.local_sgd_interval)
        elif opt.trainer == 'ersgd2':
            trainer = ERSGD2TrainerV2(
                net.collect_params(),  
                optimizer, optimizer_params, 
                input_sparse_ratio_1=1./opt.input_sparse_1, 
                output_sparse_ratio_1=1./opt.output_sparse_1, 
                layer_sparse_ratio_1=1./opt.layer_sparse_1,
                input_sparse_ratio_2=1./opt.input_sparse_2, 
                output_sparse_ratio_2=1./opt.output_sparse_2, 
                layer_sparse_ratio_2=1./opt.layer_sparse_2,
                local_sgd_interval=opt.local_sgd_interval)
        else:
            trainer = SGDTrainer(
                net.collect_params(),  
                optimizer, optimizer_params)

        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(temperature=opt.temperature,
                                                                 hard_weight=opt.hard_weight,
                                                                 sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)

        best_val_score = 1

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            # train_metric.reset()
            train_loss = 0
            btic = time.time()

            # test speed
            if opt.test_speed > 0:
                n_repeats = opt.test_speed
            elif opt.test_speed == 0:
                n_repeats = 1
            else:
                n_repeats = 0

            for i, batch in enumerate(train_data):
                
                # test speed
                if n_repeats == 0 and not (i+1)%opt.log_interval:
                    print('[Epoch %d] # batch: %d'%(epoch, i))
                    continue

                data, label = batch_fn(batch, ctx)

                for j in range(n_repeats):

                    if opt.mixup:
                        lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                        if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                            lam = 1
                        data = [lam*X + (1-lam)*X[::-1] for X in data]

                        if opt.label_smoothing:
                            eta = 0.1
                        else:
                            eta = 0.0
                        label = mixup_transform(label, classes, lam, eta)

                    elif opt.label_smoothing:
                        hard_label = label
                        label = smooth(label, classes)

                    if distillation:
                        teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
                                        for X in data]

                    with ag.record():
                        outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                        if distillation:
                            loss = [L(yhat.astype('float32', copy=False),
                                    y.astype('float32', copy=False),
                                    p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                        else:
                            loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
                    for l in loss:
                        l.backward()
                    trainer.step(batch_size)

                    # if opt.mixup:
                    #     output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                    #                     for out in outputs]
                    #     train_metric.update(label, output_softmax)
                    # else:
                    #     if opt.label_smoothing:
                    #         train_metric.update(hard_label, outputs)
                    #     else:
                    #         train_metric.update(label, outputs)

                    step_loss = sum([l.sum().asscalar() for l in loss])

                    train_loss += step_loss

                    if opt.log_interval and not (i+j+1)%opt.log_interval:
                        # train_metric_name, train_metric_score = train_metric.get()
                        if hvd.rank() == 0:
                            # logger.info('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%(
                            #             epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic),
                            #             train_metric_name, train_metric_score, trainer.learning_rate, trainer._comm_counter/1e6))
                            # print('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%(
                            #             epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic),
                            #             train_metric_name, train_metric_score, trainer.learning_rate, trainer._comm_counter/1e6))
                            print('Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f lr=%f comm=%f'%(
                                        epoch, i, batch_size*hvd.size()*opt.log_interval/(time.time()-btic),
                                        'loss', step_loss/batch_size, trainer.learning_rate, trainer._comm_counter/1e6))
                        btic = time.time()

            mx.nd.waitall()
            toc = time.time()

            if n_repeats == 0:
                allreduce_array_nd = mx.nd.array([i])
                hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True)
                mx.nd.waitall()
                print('[Epoch %d] # total batch: %d'%(epoch, i))
                continue

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i /(toc - tic) * hvd.size())

            train_loss /= (batch_size * i)

            if opt.trainer in ('ersgd', 'qsparselocalsgd', 'ersgd2',
                               'partiallocalsgd'):
                allreduce_for_val = True
            else:
                allreduce_for_val = False

            if allreduce_for_val:
                trainer.pre_test()
            # err_train_tic = time.time()
            # err_top1_train, err_top5_train = test(ctx, train_data, val=False)
            err_val_tic = time.time()
            err_top1_val, err_top5_val = test(ctx, val_data, val=True)
            err_val_toc = time.time()
            if allreduce_for_val:
                trainer.post_test()

            mx.nd.waitall()

            # allreduce the results
            allreduce_array_nd = mx.nd.array([train_loss, err_top1_val, err_top5_val])
            hvd.allreduce_(allreduce_array_nd, name='allreduce_array', average=True)
            allreduce_array_np = allreduce_array_nd.asnumpy()
            train_loss = np.asscalar(allreduce_array_np[0])
            err_top1_val = np.asscalar(allreduce_array_np[1])
            err_top5_val = np.asscalar(allreduce_array_np[2])

            if hvd.rank() == 0:
                # logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
                logger.info('[Epoch %d] training: loss=%f'%(epoch, train_loss))
                logger.info('[Epoch %d] speed: %d samples/sec training-time: %f comm: %f'%(epoch, throughput, toc-tic, trainer._comm_counter/1e6))
                logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f err-time=%f'%(epoch, err_top1_val, err_top5_val, err_val_toc - err_val_tic))
                trainer._comm_counter = 0

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                # if hvd.local_rank() == 0:
                #     net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
                #     trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
                if hvd.local_rank() == 0:
                    net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
                    trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))

        # if save_frequency and save_dir:
        #     if hvd.local_rank() == 0:
        #         net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
        #         trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))


    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
        if distillation:
            teacher.hybridize(static_alloc=True, static_shape=True)
    train(context)
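# A minimal sketch of the warmup-plus-decay schedule assembled with
# LRSequential above, assuming GluonCV is installed: the composed scheduler
# is callable with a global iteration index and returns the learning rate at
# that point (the values here are illustrative, not the script's defaults).
from gluoncv.utils import LRScheduler, LRSequential

demo_iters = 100  # iterations per epoch
demo_sched = LRSequential([
    LRScheduler('linear', base_lr=0, target_lr=0.1, nepochs=5,
                iters_per_epoch=demo_iters),
    LRScheduler('cosine', base_lr=0.1, target_lr=0, nepochs=95,
                iters_per_epoch=demo_iters),
])
for it in (0, 250, 500, 5000, 9999):
    print(it, demo_sched(it))  # ramps 0 -> 0.1 over 5 epochs, then decays to 0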
Example #40
def main():
    opt = parse_args()

    batch_size = opt.batch_size
    classes = 10

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes,
                'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    model_name += '_mixup'
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx = context)
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_name = opt.save_plot_dir

    logging_handlers = [logging.StreamHandler()]
    if opt.logging_dir:
        logging_dir = opt.logging_dir
        makedirs(logging_dir)
        logging_handlers.append(logging.FileHandler('%s/train_cifar10_%s.log'%(logging_dir, model_name)))

    logging.basicConfig(level=logging.INFO, handlers = logging_handlers)
    logging.info(opt)

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    def label_transform(label, classes):
        ind = label.astype('int')
        res = nd.zeros((ind.shape[0], classes), ctx = label.context)
        res[nd.arange(ind.shape[0], ctx = label.context), ind] = 1
        return res

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum})
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate*lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                lam = np.random.beta(alpha, alpha)
                if epoch >= epochs - 20:
                    lam = 1

                data_1 = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

                data = [lam*X + (1-lam)*X[::-1] for X in data_1]
                label = []
                for Y in label_1:
                    y1 = label_transform(Y, classes)
                    y2 = label_transform(Y[::-1], classes)
                    label.append(lam*y1 + (1-lam)*y2)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                output_softmax = [nd.SoftmaxActivation(out) for out in output]
                train_metric.update(label, output_softmax)
                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([acc, 1-val_acc])
            train_history.plot(save_path='%s/%s_history.png'%(plot_name, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                (epoch, acc, val_acc, train_loss, time.time()-tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
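# A small worked sketch of the mixup used above, assuming only MXNet: each
# batch is mixed with its own reverse, and label_transform's one-hot rows are
# mixed with the same lam. With lam=0.7 and labels [0, 2] over 3 classes
# (one_hot shown here as a shortcut for label_transform):
from mxnet import nd

demo_y = nd.array([0, 2])
y1 = demo_y.one_hot(3)          # [[1, 0, 0], [0, 0, 1]]
y2 = demo_y[::-1].one_hot(3)    # reversed batch: [[0, 0, 1], [1, 0, 0]]
print(0.7 * y1 + 0.3 * y2)      # [[0.7, 0, 0.3], [0.3, 0, 0.7]]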
Example #41
if opt.use_rec:
    train_data, val_data, batch_fn = get_data_rec(opt.rec_train,
                                                  opt.rec_train_idx,
                                                  opt.rec_val, opt.rec_val_idx,
                                                  batch_size, num_workers)
else:
    train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size,
                                                     num_workers)

acc_top1 = mx.metric.Accuracy()
acc_top5 = mx.metric.TopKAccuracy(5)
acc_top1_aux = mx.metric.Accuracy()
acc_top5_aux = mx.metric.TopKAccuracy(5)

save_frequency = opt.save_frequency
if opt.save_dir and save_frequency:
    save_dir = opt.save_dir
    makedirs(save_dir)
else:
    save_dir = ''
    save_frequency = 0


def smooth(label, classes, eta=0.1):
    if isinstance(label, nd.NDArray):
        label = [label]
    smoothed = []
    for l in label:
        ind = l.astype('int')
        res = nd.zeros((ind.shape[0], classes), ctx=l.context)
        res += eta / classes
        res[nd.arange(ind.shape[0], ctx=l.context),
            ind] = 1 - eta + eta / classes
        smoothed.append(res)
    return smoothed
Example #42
def main():
    opt = parse_args()

    filehandler = logging.FileHandler(opt.logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
    num_batches = num_training_samples // batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear', base_lr=0, target_lr=opt.lr,
                    nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay, power=2)
    ])

    model_name = opt.model

    kwargs = {'ctx': context, 'pretrained': opt.use_pretrained, 'classes': classes}
    if opt.use_gn:
        from gluoncv.nn import GroupNorm
        kwargs['norm_layer'] = GroupNorm
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    if opt.last_gamma:
        kwargs['last_gamma'] = True

    optimizer = 'nag'
    optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    net = get_model(model_name, **kwargs)
    net.cast(opt.dtype)
    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx = context)

    # teacher model for distillation training
    if opt.teacher is not None and opt.hard_weight < 1.0:
        teacher_name = opt.teacher
        teacher = get_model(teacher_name, pretrained=True, classes=classes, ctx=context)
        teacher.cast(opt.dtype)
        distillation = True
    else:
        distillation = False

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_train,
            path_imgidx         = rec_train_idx,
            preprocess_threads  = num_workers,
            shuffle             = True,
            batch_size          = batch_size,

            data_shape          = (3, input_size, input_size),
            mean_r              = mean_rgb[0],
            mean_g              = mean_rgb[1],
            mean_b              = mean_rgb[2],
            std_r               = std_rgb[0],
            std_g               = std_rgb[1],
            std_b               = std_rgb[2],
            rand_mirror         = True,
            random_resized_crop = True,
            max_aspect_ratio    = 4. / 3.,
            min_aspect_ratio    = 3. / 4.,
            max_random_area     = 1,
            min_random_area     = 0.08,
            brightness          = jitter_param,
            saturation          = jitter_param,
            contrast            = jitter_param,
            pca_noise           = lighting_param,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_val,
            path_imgidx         = rec_val_idx,
            preprocess_threads  = num_workers,
            shuffle             = False,
            batch_size          = batch_size,

            resize              = resize,
            data_shape          = (3, input_size, input_size),
            mean_r              = mean_rgb[0],
            mean_g              = mean_rgb[1],
            mean_b              = mean_rgb[2],
            std_r               = std_rgb[0],
            std_g               = std_rgb[1],
            std_b               = std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
        resize = int(math.ceil(input_size / crop_ratio))

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                        saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(),
            normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(resize, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize
        ])

        train_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
        val_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train, opt.rec_train_idx,
                                                    opt.rec_val, opt.rec_val_idx,
                                                    batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size, num_workers)

    if opt.mixup:
        train_metric = mx.metric.RMSE()
    else:
        train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def mixup_transform(label, classes, lam=1, eta=0.0):
        if isinstance(label, nd.NDArray):
            label = [label]
        res = []
        for l in label:
            y1 = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            y2 = l[::-1].one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            res.append(lam*y1 + (1-lam)*y2)
        return res

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            res = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
            smoothed.append(res)
        return smoothed

    def test(ctx, val_data):
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (1-top1, 1-top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        if opt.resume_params == '':
            net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.label_smoothing or opt.mixup:
            sparse_label_loss = False
        else:
            sparse_label_loss = True
        if distillation:
            L = gcv.loss.DistillationSoftmaxCrossEntropyLoss(temperature=opt.temperature,
                                                                 hard_weight=opt.hard_weight,
                                                                 sparse_label=sparse_label_loss)
        else:
            L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)

        best_val_score = 1

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            train_metric.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                if opt.mixup:
                    lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                    if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                        lam = 1
                    data = [lam*X + (1-lam)*X[::-1] for X in data]

                    if opt.label_smoothing:
                        eta = 0.1
                    else:
                        eta = 0.0
                    label = mixup_transform(label, classes, lam, eta)

                elif opt.label_smoothing:
                    hard_label = label
                    label = smooth(label, classes)

                if distillation:
                    teacher_prob = [nd.softmax(teacher(X.astype(opt.dtype, copy=False)) / opt.temperature) \
                                    for X in data]

                with ag.record():
                    outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                    if distillation:
                        loss = [L(yhat.astype('float32', copy=False),
                                  y.astype('float32', copy=False),
                                  p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                    else:
                        loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)

                if opt.mixup:
                    output_softmax = [nd.SoftmaxActivation(out.astype('float32', copy=False)) \
                                    for out in outputs]
                    train_metric.update(label, output_softmax)
                else:
                    if opt.label_smoothing:
                        train_metric.update(hard_label, outputs)
                    else:
                        train_metric.update(label, outputs)

                if opt.log_interval and not (i+1)%opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%(
                                epoch, i, batch_size*opt.log_interval/(time.time()-btic),
                                train_metric_name, train_metric_score, trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i /(time.time() - tic))

            err_top1_val, err_top5_val = test(ctx, val_data)

            logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))

            if err_top1_val < best_val_score:
                best_val_score = err_top1_val
                net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
                trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
                trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
            trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))


    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
        if distillation:
            teacher.hybridize(static_alloc=True, static_shape=True)
    train(context)
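# A quick numeric sketch of smooth() above, assuming only MXNet: with eta=0.1
# and 10 classes the true class gets 1 - eta + eta/classes = 0.91 and every
# other class gets eta/classes = 0.01.
from mxnet import nd

demo_label = nd.array([3])
print(demo_label.one_hot(10, on_value=0.91, off_value=0.01))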
Example #43
def main():
    opt = parse_args()

    batch_size = opt.batch_size
    classes = 10

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes, 'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_path = opt.save_plot_dir

    logging.basicConfig(level=logging.INFO)
    logging.info(opt)

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=ctx,
                                               batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

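        # .take(256) / .take(64) below truncate the train and test sets,
        # presumably to keep this example fast rather than to fully train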
        train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
            train=True).take(256).transform_first(transform_train),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           last_batch='discard',
                                           num_workers=num_workers)

        val_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(
            train=False).take(64).transform_first(transform_test),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer, {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum
        })
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0],
                                                  ctx_list=ctx,
                                                  batch_axis=0)
                label = gluon.utils.split_and_load(batch[1],
                                                   ctx_list=ctx,
                                                   batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-cifar-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params' %
                                    (save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params' %
                                (save_dir, model_name, epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)
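
Both scripts shard each batch across devices with gluon.utils.split_and_load before the forward pass. A minimal CPU-only illustration (toy shapes; mx.cpu(0)/mx.cpu(1) stand in for GPUs):

import mxnet as mx
from mxnet import gluon, nd

ctx = [mx.cpu(0), mx.cpu(1)]           # stand-ins for multiple GPUs
batch = nd.arange(24).reshape((8, 3))  # toy batch of 8 samples
parts = gluon.utils.split_and_load(batch, ctx_list=ctx, batch_axis=0)
for part in parts:
    print(part.shape, part.context)    # (4, 3) on cpu(0) and (4, 3) on cpu(1)

Each device runs forward/backward on its own slice; trainer.step(batch_size) then rescales the accumulated gradients by the full batch size.
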
def main():
    opt = parse_args()

    filehandler = logging.FileHandler(opt.logging_file)
    streamhandler = logging.StreamHandler()

    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)

    logger.info(opt)

    batch_size = opt.batch_size
    classes = 1000
    num_training_samples = 1281167

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]
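    # shift decay milestones left by the warmup length: the decay phase of the
    # scheduler below starts counting epochs only after warmup ends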
    num_batches = num_training_samples // batch_size

    lr_scheduler = LRSequential([
        LRScheduler('linear', base_lr=0, target_lr=opt.lr,
                    nepochs=opt.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode, base_lr=opt.lr, target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay, power=2)
    ])
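    # net effect: linear warmup from 0 to opt.lr over warmup_epochs, then an
    # opt.lr_mode decay (step/cosine/poly, with power=2 for poly) from opt.lr
    # toward 0 over the remaining epochs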

    model_name = opt.model

    kwargs = {'ctx': context, 'pretrained': opt.use_pretrained, 'classes': classes}
    if model_name.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm
    elif model_name.startswith('resnext'):
        kwargs['use_se'] = opt.use_se

    optimizer = 'nag'
    optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum, 'lr_scheduler': lr_scheduler}
    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    net = get_model(model_name, **kwargs)
    net.cast(opt.dtype)

    # Two functions for reading data from record file or raw images
    def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
        rec_train = os.path.expanduser(rec_train)
        rec_train_idx = os.path.expanduser(rec_train_idx)
        rec_val = os.path.expanduser(rec_val)
        rec_val_idx = os.path.expanduser(rec_val_idx)
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size
        mean_rgb = [123.68, 116.779, 103.939]
        std_rgb = [58.393, 57.12, 57.375]
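        # mean/std are on the 0-255 pixel scale; they correspond (approximately)
        # to the [0.485, 0.456, 0.406] / [0.229, 0.224, 0.225] values the
        # transforms-based loader below applies on the 0-1 scale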

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            return data, label

        train_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_train,
            path_imgidx         = rec_train_idx,
            preprocess_threads  = num_workers,
            shuffle             = True,
            batch_size          = batch_size,

            data_shape          = (3, input_size, input_size),
            mean_r              = mean_rgb[0],
            mean_g              = mean_rgb[1],
            mean_b              = mean_rgb[2],
            std_r               = std_rgb[0],
            std_g               = std_rgb[1],
            std_b               = std_rgb[2],
            rand_mirror         = True,
            random_resized_crop = True,
            max_aspect_ratio    = 4. / 3.,
            min_aspect_ratio    = 3. / 4.,
            max_random_area     = 1,
            min_random_area     = 0.08,
            brightness          = jitter_param,
            saturation          = jitter_param,
            contrast            = jitter_param,
            pca_noise           = lighting_param,
        )
        val_data = mx.io.ImageRecordIter(
            path_imgrec         = rec_val,
            path_imgidx         = rec_val_idx,
            preprocess_threads  = num_workers,
            shuffle             = False,
            batch_size          = batch_size,

            resize              = 256,
            data_shape          = (3, input_size, input_size),
            mean_r              = mean_rgb[0],
            mean_g              = mean_rgb[1],
            mean_b              = mean_rgb[2],
            std_r               = std_rgb[0],
            std_g               = std_rgb[1],
            std_b               = std_rgb[2],
        )
        return train_data, val_data, batch_fn

    def get_data_loader(data_dir, batch_size, num_workers):
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        jitter_param = 0.4
        lighting_param = 0.1
        input_size = opt.input_size

        def batch_fn(batch, ctx):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            return data, label

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                         saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(),
            normalize
        ])
        transform_test = transforms.Compose([
            transforms.Resize(256, keep_ratio=True),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize
        ])

        train_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)
        val_data = gluon.data.DataLoader(
            imagenet.classification.ImageNet(data_dir, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        return train_data, val_data, batch_fn

    if opt.use_rec:
        train_data, val_data, batch_fn = get_data_rec(opt.rec_train, opt.rec_train_idx,
                                                      opt.rec_val, opt.rec_val_idx,
                                                      batch_size, num_workers)
    else:
        train_data, val_data, batch_fn = get_data_loader(opt.data_dir, batch_size, num_workers)

    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    acc_top1_aux = mx.metric.Accuracy()
    acc_top5_aux = mx.metric.TopKAccuracy(5)

    save_frequency = opt.save_frequency
    if opt.save_dir and save_frequency:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_frequency = 0

    def smooth(label, classes, eta=0.1):
        if isinstance(label, nd.NDArray):
            label = [label]
        smoothed = []
        for l in label:
            ind = l.astype('int')
            res = nd.zeros((ind.shape[0], classes), ctx=l.context)
            res += eta / classes
            res[nd.arange(ind.shape[0], ctx=l.context), ind] = 1 - eta + eta / classes
            smoothed.append(res)
        return smoothed
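    # e.g. with classes=5 and eta=0.1, a label of 2 is smoothed to
    # [0.02, 0.02, 0.92, 0.02, 0.02]: every class receives eta/classes and the
    # true class 1 - eta + eta/classes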

    def test(ctx, val_data):
        if opt.use_rec:
            val_data.reset()
        acc_top1.reset()
        acc_top5.reset()
        acc_top1_aux.reset()
        acc_top5_aux.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, [o[0] for o in outputs])
            acc_top5.update(label, [o[0] for o in outputs])
            acc_top1_aux.update(label, [o[1] for o in outputs])
            acc_top5_aux.update(label, [o[1] for o in outputs])

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        _, top1_aux = acc_top1_aux.get()
        _, top5_aux = acc_top5_aux.get()
        return (1-top1, 1-top5, 1-top1_aux, 1-top5_aux)
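    # the network returns two outputs per image here: o[0] from the main
    # classifier and o[1] from an auxiliary head (Inception-style), hence the
    # separate *_aux metrics and the aux_weight=0.4 in the loss below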

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
        if opt.label_smoothing:
            L = MixSoftmaxCrossEntropyLoss(sparse_label=False, aux_weight=0.4)
        else:
            L = MixSoftmaxCrossEntropyLoss(aux_weight=0.4)

        best_val_score = 1

        for epoch in range(opt.num_epochs):
            tic = time.time()
            if opt.use_rec:
                train_data.reset()
            acc_top1.reset()
            acc_top5.reset()
            acc_top1_aux.reset()
            acc_top5_aux.reset()
            btic = time.time()

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)
                if opt.label_smoothing:
                    label_smooth = smooth(label, classes)
                else:
                    label_smooth = label
                with ag.record():
                    outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                    loss = [L(yhat[0], yhat[1], y) for yhat, y in zip(outputs, label_smooth)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)

                acc_top1.update(label, [o[0] for o in outputs])
                acc_top5.update(label, [o[0] for o in outputs])
                acc_top1_aux.update(label, [o[1] for o in outputs])
                acc_top5_aux.update(label, [o[1] for o in outputs])
                if opt.log_interval and not (i + 1) % opt.log_interval:
                    _, top1 = acc_top1.get()
                    _, top5 = acc_top5.get()
                    _, top1_aux = acc_top1_aux.get()
                    _, top5_aux = acc_top5_aux.get()
                    err_top1, err_top5, err_top1_aux, err_top5_aux = (1-top1, 1-top5, 1-top1_aux, 1-top5_aux)
                    logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t'
                                'top1-err=%f\ttop5-err=%f\ttop1-err-aux=%f\ttop5-err-aux=%f'%(
                                epoch, i, batch_size*opt.log_interval/(time.time()-btic),
                                err_top1, err_top5, err_top1_aux, err_top5_aux))
                    btic = time.time()

            _, top1 = acc_top1.get()
            _, top5 = acc_top5.get()
            _, top1_aux = acc_top1_aux.get()
            _, top5_aux = acc_top5_aux.get()
            err_top1, err_top5, err_top1_aux, err_top5_aux = (1-top1, 1-top5, 1-top1_aux, 1-top5_aux)

            err_top1_val, err_top5_val, err_top1_val_aux, err_top5_val_aux = test(ctx, val_data)

            logger.info('[Epoch %d] training: err-top1=%f err-top5=%f err-top1_aux=%f err-top5_aux=%f'%
                (epoch, err_top1, err_top5, err_top1_aux, err_top5_aux))
            logger.info('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))
            logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f err-top1_aux=%f err-top5_aux=%f'%
                (epoch, err_top1_val, err_top5_val, err_top1_val_aux, err_top5_val_aux))

            if err_top1_val < best_val_score and epoch > 50:
                best_val_score = err_top1_val
                net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))

            if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
                net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))

        if save_frequency and save_dir:
            net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))


    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
    train(context)