Example 1
def main():
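    # Parse command-line arguments into Params with paramparse and convert a
    # single dataset; the commented-out loop below would instead iterate over
    # every sub-sequence in params.sub_seq_dict.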
    params = Params()
    paramparse.process(params)

    # for _sub_seq in params.sub_seq_dict:
    #     params.sub_seq = _sub_seq
    #     _convert_dataset(params)

    _convert_dataset(params)
Example 2
def main():
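    # Fetch a log file from each configured server over SCP, optionally drop
    # the duplicated header line from all but the first log, concatenate
    # everything into consolidated_<timestamp>.log, and open the result with
    # the Windows 'start' command.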
    params = Params()

    paramparse.process(params, allow_unknown=1)

    params.scp.read_auth()

    if params.working_dir:
        os.chdir(params.working_dir)

    timestamp = datetime.now().strftime("%y%m%d_%H%M%S_%f")

    log_data_all = []
    log_data_dict = {}
    out_dir = linux_path(params.log_dir, 'consolidate_log')
    os.makedirs(out_dir, exist_ok=True)

    for server_id, server_name in enumerate(params.servers):
        log_path = run_scp(params.scp,
                           server_name,
                           params.log_dir,
                           params.log_fname,
                           out_dir,
                           is_file=1,
                           timestamp=timestamp)

        if log_path is None:
            print('log data for server {} not found'.format(server_name))
            continue

        with open(log_path, 'r') as log_fid:
            log_data = log_fid.readlines()

        if server_id > 0 and params.remove_header:
            log_data = log_data[1:]

        log_data_all += log_data

        log_data_dict[server_name] = log_data

    # out_fname = add_suffix(params.log_fname, '{}'.format(timestamp))
    out_fname = 'consolidated_{}.log'.format(timestamp)
    out_path = linux_path(out_dir, out_fname)

    print('writing consolidated log to {}'.format(out_path))
    os.makedirs(params.log_dir, exist_ok=True)
    with open(out_path, 'w') as fid:
        fid.write(''.join(log_data_all))

    open_cmd = "start {}".format(out_path)
    os.system(open_cmd)

    _ = input('press enter to exit')
Example 3
def main():
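    # Repeatedly invoke run_scp with the configured SCP parameters, printing
    # the current timestamp before each run.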
    params = Params()

    paramparse.process(params, allow_unknown=1)

    params.scp.read_auth()

    if params.working_dir:
        os.chdir(params.working_dir)

    while True:
        timestamp = datetime.now().strftime("%y%m%d_%H%M%S_%f")
        print('\n' + timestamp + '\n')
        run_scp(params.scp)
Example 4
def main():
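    # Resize the image at src_path (taken from the clipboard if not given on
    # the command line) so that its larger dimension equals params.out_size,
    # preserving the aspect ratio via resizeAR, and save it alongside the
    # source image.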
    params = Params()
    process(params)

    src_path = params.src_path

    if not src_path:
        try:
            from Tkinter import Tk
        except ImportError:
            from tkinter import Tk
        try:
            src_path = Tk().clipboard_get()
        except BaseException as e:
            print('Tk().clipboard_get() failed: {}'.format(e))
            return

    src_path = src_path.replace(os.sep, '/').replace('"', '')
    assert os.path.exists(src_path), "src_path does not exist: {}".format(
        src_path)

    src_dir = os.path.dirname(src_path)
    src_name = os.path.basename(src_path)

    src_name_noext, src_ext = os.path.splitext(src_name)

    out_name = src_name_noext + '_' + src_ext

    out_path = os.path.join(src_dir, out_name)

    assert not os.path.exists(out_path), "out_path already exists: {}".format(out_path)

    img = cv2.imread(src_path)
    assert img is not None, "failed to read image: {}".format(src_path)

    img_h, img_w = img.shape[:2]

    out_h = out_w = 0

    if img_h > img_w:
        out_h = params.out_size
    else:
        out_w = params.out_size

    resized_img = resizeAR(img, width=out_w, height=out_h)

    cv2.imwrite(out_path, resized_img)
Example 5
def main():
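    # Read a list of operations from io_file (a Python-literal list with blank
    # lines and '#' comments stripped), apply each one to the source workbook,
    # and save the result in place or as a '_proc.xlsx' copy.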
    params = RTPNParams()
    paramparse.process(params)

    root_dir = params.root_dir
    io_file = params.io_file
    src_file_path = params.src_file_path

    if root_dir:
        src_file_path = os.path.join(root_dir, src_file_path)
        io_file = os.path.join(root_dir, io_file)

    src_file_path = os.path.abspath(src_file_path)

    with open(io_file, 'r') as fid:
        io_lines = fid.readlines()

        io_lines = [
            k.strip() for k in io_lines if k.strip() and not k.startswith('#')
        ]
        io_str = '\n'.join(io_lines)
        io_str = '[{}]'.format(io_str.strip())

        ops = literal_eval(io_str)
        # io_dict = json.load(fid)

    src_book = load_workbook(src_file_path)

    for op in ops:
        process_io_dict(op, src_book)

    if not params.in_place:
        src_file_name = os.path.splitext(os.path.basename(src_file_path))[0]
        src_file_dir = os.path.dirname(src_file_path)
        dst_file_name = src_file_name + '_proc.xlsx'
        dst_file_path = os.path.join(src_file_dir, dst_file_name)

        src_book.save(dst_file_path)

    else:
        # dst_file_path = src_file_path
        src_book.save(src_file_path)
Example 6
def main():
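    # Launch tensorboard either on a single log directory (--logdir) or on a
    # name:path spec (--logdir_spec) built from a tab-separated list file;
    # when tensorboard exits, prompt for the next folder / list file.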
    params = Params()
    paramparse.process(params)

    list_fname = params.list_fname

    while True:
        if os.path.isfile(list_fname):
            print('reading tb log dir list from: {}'.format(list_fname))
            with open(list_fname, 'r') as list_fid:
                data = list_fid.readlines()

            data_list = [k.strip().split('\t') for k in data]

            if params.start_id > 0:
                data_list = data_list[params.start_id:]

            data_list_str = ['{}:{}'.format(k[0], k[1]) for k in data_list if os.path.isdir(k[1])]
            log_dirs = ','.join(data_list_str)
            log_dirs_arg = '--logdir_spec'

        else:
            log_dirs = list_fname

            log_dirs_arg = '--logdir'

        tb_cmd = "{} {} {}={} --bind_all --samples_per_plugin images={}".format(
            params.python_exe, params.tb_path, log_dirs_arg, log_dirs, params.images,
        )

        print('running: {}'.format(tb_cmd))

        os.system(tb_cmd)

        list_fname = input('\nEnter log folder / list file name\n')

        if not list_fname.strip():
            list_fname = params.list_fname
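Example 7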
def main():
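    # Run the requested subset of the pipeline phases (train, raw_vis, stitch,
    # vis) on the GPUs selected via CUDA_VISIBLE_DEVICES.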
    params = NewDeeplabParams()
    paramparse.process(params)

    params.process()

    phases = params.phases
    if params.start > 0:
        phases = phases[params.start:]

    if params.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = params.gpu

    if Phases.train in phases:
        train.run(params.train)

    if Phases.raw_vis in phases:
        raw_vis.run(params.raw_vis)

    if Phases.stitch in phases:
        stitch.run(params.stitch)

    if Phases.vis in phases:
        vis.run(params.vis)
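Example 8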
def main():
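    # Configure and run AutoDeeplab training; the commented-out argparse block
    # below documents the parameters now supplied through AutoDeeplabParams.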
    # parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training")
    # parser.add_argument('--backbone', type=str, default='resnet',
    #                     choices=['resnet', 'xception', 'drn', 'mobilenet'],
    #                     help='backbone name (default: resnet)')
    # parser.add_argument('--out_stride', type=int, default=16,
    #                     help='network output stride (default: 8)')
    # parser.add_argument('--dataset', type=str, default='pascal',
    #                     choices=['pascal', 'coco', 'cityscapes'],
    #                     help='dataset name (default: pascal)')
    # parser.add_argument('--use_sbd', action='store_true', default=False,
    #                     help='whether to use SBD dataset (default: True)')
    # parser.add_argument('--workers', type=int, default=4,
    #                     metavar='N', help='dataloader threads')
    # parser.add_argument('--base_size', type=int, default=320,
    #                     help='base image size')
    # parser.add_argument('--crop_size', type=int, default=320,
    #                     help='crop image size')
    # parser.add_argument('--resize', type=int, default=512,
    #                     help='resize image size')
    # parser.add_argument('--sync_bn', type=bool, default=None,
    #                     help='whether to use sync bn (default: auto)')
    # parser.add_argument('--freeze_bn', type=bool, default=False,
    #                     help='whether to freeze bn parameters (default: False)')
    # parser.add_argument('--loss_type', type=str, default='ce',
    #                     choices=['ce', 'focal'],
    #                     help='loss func type (default: ce)')
    # # training hyper params
    # parser.add_argument('--epochs', type=int, default=None, metavar='N',
    #                     help='number of epochs to train (default: auto)')
    # parser.add_argument('--start_epoch', type=int, default=0,
    #                     metavar='N', help='start epochs (default:0)')
    # parser.add_argument('--batch_size', type=int, default=None,
    #                     metavar='N', help='input batch size for \
    #                             training (default: auto)')
    # parser.add_argument('--test_batch_size', type=int, default=None,
    #                     metavar='N', help='input batch size for \
    #                             testing (default: auto)')
    # parser.add_argument('--use_balanced_weights', action='store_true', default=False,
    #                     help='whether to use balanced weights (default: False)')
    # # optimizer params
    # parser.add_argument('--lr', type=float, default=0.025, metavar='LR',
    #                     help='learning rate (default: auto)')
    # parser.add_argument('--arch_lr', type=float, default=3e-3,
    #                     help='learning rate for alpha and beta in architect searching process')
    #
    # parser.add_argument('--lr_scheduler', type=str, default='cos',
    #                     choices=['poly', 'step', 'cos'],
    #                     help='lr scheduler mode: (default: cos)')
    # parser.add_argument('--momentum', type=float, default=0.9,
    #                     metavar='M', help='momentum (default: 0.9)')
    # parser.add_argument('--weight_decay', type=float, default=3e-4,
    #                     metavar='M', help='w-decay (default: 5e-4)')
    # parser.add_argument('--arch_weight_decay', type=float, default=1e-3,
    #                     metavar='M', help='w-decay (default: 5e-4)')
    #
    # parser.add_argument('--nesterov', action='store_true', default=False,
    #                     help='whether use nesterov (default: False)')
    # # cuda, seed and logging
    # parser.add_argument('--no_cuda', action='store_true', default=
    # False, help='disables CUDA training')
    # parser.add_argument('--gpu-ids', nargs='*', type=int, default=0,
    #                     help='which GPU to train on (default: 0)')
    # parser.add_argument('--seed', type=int, default=1, metavar='S',
    #                     help='random seed (default: 1)')
    # # checking point
    # parser.add_argument('--resume', type=str, default=None,
    #                     help='put the path to resuming file if needed')
    # parser.add_argument('--checkname', type=str, default=None,
    #                     help='set the checkpoint name')
    # # finetuning pre-trained models
    # parser.add_argument('--ft', action='store_true', default=False,
    #                     help='finetuning on a different dataset')
    # # evaluation option
    # parser.add_argument('--eval_interval', type=int, default=1,
    #                     help='evaluuation interval (default: 1)')
    # parser.add_argument('--no_val', action='store_true', default=False,
    #                     help='skip validation during training')
    #
    # paramparse.fromParser(parser, 'AutoDeeplabParams')
    #
    # args = parser.parse_args()

    args = AutoDeeplabParams()
    paramparse.process(args)

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    if args.sync_bn is None:
        if args.cuda and len(args.gpu_ids) > 1:
            args.sync_bn = True
        else:
            args.sync_bn = False

    # default settings for epochs, batch_size and lr
    if args.epochs is None:
        epoches = {
            'coco': 30,
            'cityscapes': 200,
            'pascal': 50,
        }
        args.epochs = epoches[args.dataset.lower()]

    if args.batch_size is None:
        args.batch_size = 4 * len(args.gpu_ids)

    if args.test_batch_size is None:
        args.test_batch_size = args.batch_size

    if args.lr is None:
        lrs = {
            'coco': 0.1,
            'cityscapes': 0.025,
            'pascal': 0.007,
        }
        args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size

    if args.checkname is None:
        args.checkname = 'deeplab-' + str(args.backbone)
    print(args)
    torch.manual_seed(args.seed)
    trainer = Trainer(args)
    print('Starting Epoch:', trainer.args.start_epoch)
    print('Total Epoches:', trainer.args.epochs)
    for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
        trainer.training(epoch)
        if not trainer.args.no_val and epoch % args.eval_interval == (
                args.eval_interval - 1):
            trainer.validation(epoch)

    trainer.writer.close()
Example 9
def main():
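    # Stack corresponding videos or image sequences from several root
    # directories into a single grid using stackImages, writing the result as
    # a video or image sequence with optional per-source annotations and a
    # live preview that supports pause (space) and quit (q / Esc) keys.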
    params = Params()

    paramparse.process(params)

    root_dirs = params.root_dirs
    annotations = params.annotations
    save_path = params.save_path
    img_ext = params.img_ext
    show_img = params.show_img
    del_src = params.del_src
    start_id = params.start_id
    n_frames = params.n_frames
    width = params.width
    height = params.height
    fps = params.fps
    codec = params.codec
    ext = params.ext
    out_width = params.out_width
    out_height = params.out_height
    grid_size = params.grid_size
    sep_size = params.sep_size
    only_height = params.only_height
    borderless = params.borderless
    preserve_order = params.preserve_order
    ann_fmt = params.ann_fmt
    resize_factor = params.resize_factor
    recursive = params.recursive
    img_seq = params.img_seq
    match_images = params.match_images

    if match_images:
        assert img_seq, "image matching is only supported in image sequence mode"

    vid_exts = ['.mkv', '.mp4', '.avi', '.mjpg', '.wmv']
    image_exts = ['.jpg', '.bmp', '.png', '.tif']

    min_n_sources = None
    min_sources = None
    min_sources_id = None

    sources_list = []
    for root_dir_id, root_dir in enumerate(root_dirs):
        root_dir = os.path.abspath(root_dir)
        if img_seq:
            sources = [k for k in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, k))]
        else:
            sources = [k for k in os.listdir(root_dir) if os.path.splitext(k)[1] in vid_exts]
        sources.sort()

        n_sources = len(sources)
        if min_n_sources is None or n_sources < min_n_sources:
            min_n_sources = n_sources
            min_sources = sources
            min_sources_id = root_dir_id

        sources = [os.path.join(root_dir, k) for k in sources]

        sources_list.append(sources)

    if match_images:
        for sources_id, sources in enumerate(sources_list):
            if sources_id == min_sources_id:
                continue

            sources = [k for k in sources if os.path.basename(k) in min_sources]

            assert len(sources) == min_n_sources, "invalid sources after filtering {}".format(sources)

            sources_list[sources_id] = sources

    src_paths = list(zip(*sources_list))

    print('sources_list:\n{}'.format(pformat(sources_list)))
    print('src_paths:\n{}'.format(pformat(src_paths)))

    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")

    _exit = 0

    for _src_path in src_paths:
        _annotations = annotations
        _save_path = save_path
        _grid_size = grid_size
        n_frames = 0

        src_files = _src_path

        n_videos = len(src_files)
        assert n_videos > 0, 'no input videos found'

        if not _save_path:
            seq_dir = os.path.dirname(src_files[0])
            seq_name = os.path.splitext(os.path.basename(src_files[0]))[0]
            dst_path = os.path.join(seq_dir, 'stacked_{}'.format(timestamp), '{}.{}'.format(seq_name, ext))
        else:
            out_seq_name, out_ext = os.path.splitext(os.path.basename(_save_path))
            dst_path = os.path.join(os.path.dirname(_save_path), '{}_{}{}'.format(
                out_seq_name, datetime.now().strftime("%y%m%d_%H%M%S"), out_ext))

        save_dir = os.path.dirname(dst_path)
        if save_dir and not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        print('Stacking: {} videos:'.format(n_videos))
        print('src_files:\n{}'.format(pformat(src_files)))

        if _annotations:
            if len(_annotations) == 1 and _annotations[0] == 1:
                # use the sequence names (derived from src_files) as annotations
                _annotations = [
                    os.path.splitext(os.path.basename(k))[0] for k in src_files
                ]
            else:
                assert len(_annotations) == n_videos, 'Invalid annotations: {}'.format(_annotations)

                for i in range(n_videos):
                    if _annotations[i] == '__n__':
                        _annotations[i] = ''

            print('Adding annotations:\n{}'.format(pformat(_annotations)))
        else:
            _annotations = None

        if not _grid_size:
            _grid_size = None
        else:
            _grid_size = [int(x) for x in _grid_size.split('x')]
            if len(_grid_size) != 2 or _grid_size[0] * _grid_size[1] != n_videos:
                raise AssertionError('Invalid grid_size: {}'.format(_grid_size))

        n_frames_list = []
        cap_list = []
        size_list = []
        seq_names = []

        min_n_frames = None
        min_n_frames_id = 0

        for src_id, src_file in enumerate(src_files):
            src_file = os.path.abspath(src_file)
            seq_name = os.path.splitext(os.path.basename(src_file))[0]

            seq_names.append(seq_name)

            if os.path.isfile(src_file):
                cap = cv2.VideoCapture()
            elif os.path.isdir(src_file):
                cap = ImageSequenceCapture(src_file, recursive=recursive)
            else:
                raise IOError('Invalid src_file: {}'.format(src_file))

            if not cap.open(src_file):
                raise IOError('The video file ' + src_file + ' could not be opened')

            cv_prop = cv2.CAP_PROP_FRAME_COUNT
            h_prop = cv2.CAP_PROP_FRAME_HEIGHT
            w_prop = cv2.CAP_PROP_FRAME_WIDTH

            total_frames = int(cap.get(cv_prop))
            _height = int(cap.get(h_prop))
            _width = int(cap.get(w_prop))

            cap_list.append(cap)
            n_frames_list.append(total_frames)
            if min_n_frames is None or total_frames < min_n_frames:
                min_n_frames = total_frames
                min_n_frames_id = src_id

            size_list.append((_width, _height))

        if match_images:
            assert all(seq_name == seq_names[0] for seq_name in seq_names), "mismatch in seq_names: {}".format(seq_names)

            frames_list = [os.path.basename(k) for k in cap_list[min_n_frames_id].src_files]
            for src_id, cap in enumerate(cap_list):
                if src_id == min_n_frames_id:
                    continue
                cap_list[src_id].filter_files(frames_list)
                n_frames_list[src_id] = min_n_frames

        frame_id = start_id
        pause_after_frame = 0
        video_out = None

        win_name = 'stacked_{}'.format(datetime.now().strftime("%y%m%d_%H%M%S"))

        min_n_frames = min(n_frames_list)
        max_n_frames = max(n_frames_list)

        if n_frames <= 0:
            n_frames = max_n_frames
        else:
            if max_n_frames < n_frames:
                raise IOError(
                    'Invalid n_frames: {} for sequence list with max_n_frames: {}'.format(n_frames, max_n_frames))

        if show_img == 2:
            vis_only = True
            print('Running in visualization only mode')
        else:
            vis_only = False

        while True:

            images = []
            valid_caps = []
            valid_annotations = []
            for cap_id, cap in enumerate(cap_list):
                ret, image = cap.read()
                if not ret:
                    print('\nFrame {:d} could not be read'.format(frame_id + 1))
                    continue
                images.append(image)
                valid_caps.append(cap)
                if _annotations:
                    valid_annotations.append(_annotations[cap_id])

            cap_list = valid_caps
            if _annotations:
                _annotations = valid_annotations

            # if len(images) != n_videos:
            #     break

            frame_id += 1

            if frame_id <= start_id:
                break

            out_img = stackImages(images, _grid_size, borderless=borderless, preserve_order=preserve_order,
                                  annotations=_annotations, ann_fmt=ann_fmt, only_height=only_height, sep_size=sep_size)
            if resize_factor != 1:
                out_img = cv2.resize(out_img, (0, 0), fx=resize_factor, fy=resize_factor)

            if not vis_only:
                if video_out is None:
                    dst_height, dst_width = sizeAR(out_img, width=out_width, height=out_height)

                    if '.' + ext in vid_exts:
                        fourcc = cv2.VideoWriter_fourcc(*codec)
                        video_out = cv2.VideoWriter(dst_path, fourcc, fps, (dst_width, dst_height))
                    elif '.' + ext in image_exts:
                        video_out = ImageSequenceWriter(dst_path, height=dst_height, width=dst_width)
                    else:
                        raise IOError('Invalid ext: {}'.format(ext))

                    if video_out is None:
                        raise IOError('Output video file could not be opened: {}'.format(dst_path))

                    print('Saving {}x{} output video to {}'.format(dst_width, dst_height, dst_path))

                out_img = resizeAR(out_img, width=dst_width, height=dst_height)
                video_out.write(out_img)

            if show_img:
                # out_img_disp = out_img
                out_img_disp = resizeAR(out_img, 1280)
                cv2.imshow(win_name, out_img_disp)
                k = cv2.waitKey(1 - pause_after_frame) & 0xFF
                if k == ord('q'):
                    _exit = 1
                    break
                elif k == 27:
                    break
                elif k == 32:
                    pause_after_frame = 1 - pause_after_frame

            sys.stdout.write('\rDone {:d}/{:d} frames '.format(frame_id - start_id, n_frames))
            sys.stdout.flush()

            if frame_id - start_id >= n_frames:
                break

        if _exit:
            break

        sys.stdout.write('\n')
        sys.stdout.flush()

        if video_out is not None:
            video_out.release()

        if show_img:
            cv2.destroyWindow(win_name)
Example 10
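# Decrypt params.in_file with the key loaded from params.key_file (or take
# the text from the clipboard), then either copy the plaintext to the
# clipboard, type it out, or write the decrypted bytes to params.out_file.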
            out_txt = copy_from_clipboard()
            # print('out_txt: {}'.format(out_txt))
        else:
            key = load_key(params.key_file)
            decrypted_data = decrypt(params.in_file, key)
            out_txt = decrypted_data.decode('ascii')
            # print('out_txt: {}'.format(out_txt))

        if params.clipboard:
            if params.clipboard == 1:
                copy_to_clipboard(out_txt)
            elif params.clipboard == 2:
                type_string(
                    out_txt,
                    # params.auto_switch,
                    press_enter=params.press_enter)
        else:
            # write the original file
            with open(params.out_file, "wb") as file:
                file.write(decrypted_data)

        return out_txt


if __name__ == '__main__':
    _params = Params()
    paramparse.process(_params)
    _params.process()

    run(_params)
Example 11
def main():
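    # Convert Cell Tracking Challenge (CTC) sequences: copy the TIF frames and
    # write normalized JPG versions, copy the man_track.txt tracking GT, match
    # GT markers to segmentation masks via Hungarian assignment to derive
    # per-object bounding boxes, and write MOT-style annotation files with
    # optional visualization.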
    _params = Params()

    paramparse.process(_params)

    root_dir = _params.root_dir
    start_id = _params.start_id
    end_id = _params.end_id

    write_img = _params.write_img
    write_gt = _params.write_gt

    save_img = _params.save_img
    save_vid = _params.save_vid
    codec = _params.codec

    show_img = _params.show_img
    vis_height = _params.vis_height
    vis_width = _params.vis_width

    # default_obj_size = _params.default_obj_size
    ignore_missing_gt = _params.ignore_missing_gt
    ignore_missing_seg = _params.ignore_missing_seg
    raad_gt = _params.raad_gt

    if save_img:
        if not show_img:
            show_img += 2
    else:
        if show_img > 1:
            save_img = 1

    params = ParamDict()

    actor = 'CTC'
    actor_sequences_dict = params.sequences_ctc

    actor_sequences = list(actor_sequences_dict.keys())

    if end_id <= start_id:
        end_id = len(actor_sequences) - 1

    print('root_dir: {}'.format(root_dir))
    print('start_id: {}'.format(start_id))
    print('end_id: {}'.format(end_id))

    print('actor: {}'.format(actor))
    print('actor_sequences: {}'.format(actor_sequences))
    img_exts = ('.tif', )

    n_frames_list = []
    _pause = 1
    __pause = 1

    ann_cols = ('green', 'blue', 'red', 'cyan', 'magenta', 'gold', 'purple',
                'peach_puff', 'azure', 'dark_slate_gray', 'navy', 'turquoise')

    out_img_tif_root_path = linux_path(root_dir, actor, 'Images_TIF')
    os.makedirs(out_img_tif_root_path, exist_ok=True)

    out_img_jpg_root_path = linux_path(root_dir, actor, 'Images')
    os.makedirs(out_img_jpg_root_path, exist_ok=True)

    if save_img:
        out_vis_root_path = linux_path(root_dir, actor, 'Visualizations')
        os.makedirs(out_vis_root_path, exist_ok=True)

    out_gt_root_path = linux_path(root_dir, actor, 'Annotations')
    os.makedirs(out_gt_root_path, exist_ok=True)

    n_frames_out_file = linux_path(root_dir, actor, 'n_frames.txt')
    n_frames_out_fid = open(n_frames_out_file, 'w')

    _exit = 0
    _pause = 1
    time_stamp = datetime.now().strftime("%y%m%d_%H%M%S")

    log_path = linux_path(root_dir, actor, 'log_{}.log'.format(time_stamp))
    tif_root_dir = linux_path(root_dir, actor, 'tif')
    assert os.path.exists(tif_root_dir), "tif_root_dir does not exist"

    seq_ids = _params.seq_ids
    if not seq_ids:
        seq_ids = list(range(start_id, end_id + 1))

    n_seq = len(seq_ids)
    for __id, seq_id in enumerate(seq_ids):

        seq_name = actor_sequences[seq_id]

        default_obj_size = actor_sequences_dict[seq_name]

        seq_img_path = linux_path(tif_root_dir, seq_name)
        assert os.path.exists(seq_img_path), "seq_img_path does not exist"

        seq_img_src_files = [
            k for k in os.listdir(seq_img_path)
            if os.path.splitext(k.lower())[1] in img_exts
        ]
        seq_img_src_files.sort()

        out_gt_fid = None

        n_frames = len(seq_img_src_files)

        print('seq {} / {}\t{}\t{}\t{} frames'.format(__id + 1, n_seq, seq_id,
                                                      seq_name, n_frames))

        n_frames_out_fid.write("{:d}: ('{:s}', {:d}),\n".format(
            seq_id, seq_name, n_frames))

        n_frames_list.append(n_frames)

        gt_available = 0
        if not raad_gt:
            print('skipping GT reading')
        else:
            seq_gt_path = linux_path(tif_root_dir, seq_name + '_GT', 'TRA')
            if not os.path.exists(seq_gt_path):
                msg = "seq_gt_path does not exist"
                if ignore_missing_gt:
                    print(msg)
                else:
                    raise AssertionError(msg)
            else:
                gt_available = 1
                seq_gt_tra_file = linux_path(seq_gt_path, "man_track.txt")
                if os.path.exists(seq_gt_tra_file):
                    out_tra_file = linux_path(out_gt_root_path,
                                              seq_name + '.tra')
                    print('{} --> {}'.format(seq_gt_tra_file, out_tra_file))
                    if not os.path.exists(out_tra_file):
                        shutil.copy(seq_gt_tra_file, out_tra_file)
                    else:
                        print('skipping existing {}'.format(out_tra_file))
                else:
                    msg = "\nseq_gt_tra_file does not exist: {}".format(
                        seq_gt_tra_file)
                    if ignore_missing_gt:
                        print(msg)
                    else:
                        raise AssertionError(msg)

        if _params.tra_only:
            continue

        seg_available = 0
        if not raad_gt:
            print('skipping segmentation reading')
        else:
            seq_seg_path = linux_path(tif_root_dir, seq_name + '_ST', 'SEG')
            if not os.path.exists(seq_seg_path):
                print("ST seq_seg_path does not exist")
                seq_seg_path = linux_path(tif_root_dir, seq_name + '_GT',
                                          'SEG')
                if not os.path.exists(seq_seg_path):
                    msg = "GT seq_seg_path does not exist"
                    if ignore_missing_seg:
                        print(msg)
                    else:
                        raise AssertionError(msg)
                else:
                    seg_available = 1
            else:
                seg_available = 1

        if write_img:
            out_img_tif_dir_path = linux_path(out_img_tif_root_path, seq_name)
            os.makedirs(out_img_tif_dir_path, exist_ok=True)
            print('copying TIF images to {}'.format(out_img_tif_dir_path))

            out_img_jpg_dir_path = linux_path(out_img_jpg_root_path, seq_name)
            os.makedirs(out_img_jpg_dir_path, exist_ok=True)
            print('Saving JPG images to {}'.format(out_img_jpg_dir_path))

        vid_out = None
        if save_img:
            if save_vid:
                out_vis_path = linux_path(out_vis_root_path, seq_name + '.mkv')
                vid_out = cv2.VideoWriter(out_vis_path,
                                          cv2.VideoWriter_fourcc(*codec), 30,
                                          (vis_width, vis_height))
            else:
                out_vis_path = linux_path(out_vis_root_path, seq_name)
                os.makedirs(out_vis_path, exist_ok=True)
            print('Saving visualizations to {}'.format(out_vis_path))

        from collections import OrderedDict

        file_id_to_gt = OrderedDict()
        obj_id_to_gt_file_ids = OrderedDict()

        if gt_available:
            print('reading GT from {}...'.format(seq_gt_path))

            seq_gt_src_files = [
                k for k in os.listdir(seq_gt_path)
                if os.path.splitext(k.lower())[1] in img_exts
            ]
            seq_gt_src_files.sort()

            assert len(seq_img_src_files) == len(
                seq_gt_src_files
            ), "mismatch between the lengths of seq_img_src_files and seq_gt_src_files"

            for seq_gt_src_file in tqdm(seq_gt_src_files,
                                        disable=_params.disable_tqdm):
                seq_gt_src_file_id = ''.join(k for k in seq_gt_src_file
                                             if k.isdigit())

                file_id_to_gt[seq_gt_src_file_id] = OrderedDict()

                seq_gt_src_path = os.path.join(seq_gt_path, seq_gt_src_file)
                seq_gt_pil = Image.open(seq_gt_src_path)
                seq_gt_np = np.array(seq_gt_pil)

                gt_obj_ids = list(np.unique(seq_gt_np, return_counts=False))
                gt_obj_ids.remove(0)

                for obj_id in gt_obj_ids:
                    obj_locations = np.nonzero(seq_gt_np == obj_id)
                    centroid_y, centroid_x = [
                        np.mean(k) for k in obj_locations
                    ]

                    file_id_to_gt[seq_gt_src_file_id][obj_id] = [
                        obj_locations, centroid_y, centroid_x
                    ]

                    if obj_id not in obj_id_to_gt_file_ids:
                        obj_id_to_gt_file_ids[obj_id] = []

                    obj_id_to_gt_file_ids[obj_id].append(seq_gt_src_file_id)

            if write_gt:
                out_gt_path = linux_path(out_gt_root_path, seq_name + '.txt')
                out_gt_fid = open(out_gt_path, 'w')

        file_id_to_seg = OrderedDict()
        file_id_to_nearest_seg = OrderedDict()

        obj_id_to_seg_file_ids = OrderedDict()
        obj_id_to_seg_sizes = OrderedDict()
        obj_id_to_seg_bboxes = OrderedDict()
        obj_id_to_mean_seg_sizes = OrderedDict()
        obj_id_to_max_seg_sizes = OrderedDict()
        all_seg_sizes = []
        mean_seg_sizes = None
        max_seg_sizes = None

        if seg_available:
            print('reading segmentations from {}...'.format(seq_seg_path))
            seq_seq_src_files = [
                k for k in os.listdir(seq_seg_path)
                if os.path.splitext(k.lower())[1] in img_exts
            ]
            for seq_seq_src_file in tqdm(seq_seq_src_files,
                                         disable=_params.disable_tqdm):

                seq_seq_src_file_id = ''.join(k for k in seq_seq_src_file
                                              if k.isdigit())
                file_gt = file_id_to_gt[seq_seq_src_file_id]

                file_id_to_seg[seq_seq_src_file_id] = OrderedDict()

                seq_seq_src_path = os.path.join(seq_seg_path, seq_seq_src_file)
                seq_seg_pil = Image.open(seq_seq_src_path)
                seq_seg_np = np.array(seq_seg_pil)

                seg_obj_ids = list(np.unique(seq_seg_np, return_counts=False))
                seg_obj_ids.remove(0)

                _gt_obj_ids = list(file_gt.keys())

                if len(_gt_obj_ids) != len(seg_obj_ids):
                    print(
                        "\nmismatch between the number of objects in segmentation: {} and GT: {} in {}"
                        .format(len(seg_obj_ids), len(_gt_obj_ids),
                                seq_seq_src_file))

                from scipy.spatial import distance_matrix

                seg_centroids = []
                seg_id_to_locations = {}

                for seg_obj_id in seg_obj_ids:
                    # obj_id = gt_obj_ids[seg_obj_id - 1]

                    seg_obj_locations = np.nonzero(seq_seg_np == seg_obj_id)
                    seg_centroid_y, seg_centroid_x = [
                        np.mean(k) for k in seg_obj_locations
                    ]

                    seg_centroids.append([seg_centroid_y, seg_centroid_x])

                    seg_id_to_locations[seg_obj_id] = seg_obj_locations

                gt_centroids = [[k[1], k[2]] for k in file_gt.values()]
                gt_centroids = np.asarray(gt_centroids)
                seg_centroids = np.asarray(seg_centroids)

                gt_to_seg_dists = distance_matrix(seg_centroids, gt_centroids)

                # seg_min_dist_ids = np.argmin(gt_to_seg_dists, axis=1)
                # gt_min_dist_ids = np.argmin(gt_to_seg_dists, axis=0)
                # unique_min_dist_ids = np.unique(seg_min_dist_ids)                #
                #
                # assert len(unique_min_dist_ids) == len(seg_min_dist_ids), \
                #     "duplicate matches found between segmentation and GT objects"

                # seg_to_gt_obj_ids = {
                #     seg_obj_id:  _gt_obj_ids[seg_min_dist_ids[_id]] for _id, seg_obj_id in enumerate(seg_obj_ids)
                # }

                from scipy.optimize import linear_sum_assignment

                seg_inds, gt_inds = linear_sum_assignment(gt_to_seg_dists)

                if len(seg_inds) != len(seg_obj_ids):
                    print(
                        "only {} / {} segmentation objects assigned to GT objects"
                        .format(len(seg_inds), len(seg_obj_ids)))

                seg_to_gt_obj_ids = {
                    seg_obj_ids[seg_inds[i]]: _gt_obj_ids[gt_inds[i]]
                    for i in range(len(seg_inds))
                }

                # print()

                for seg_obj_id in seg_obj_ids:
                    seg_obj_locations = seg_id_to_locations[seg_obj_id]
                    _gt_obj_id = seg_to_gt_obj_ids[seg_obj_id]

                    min_y, min_x = [np.amin(k) for k in seg_obj_locations]
                    max_y, max_x = [np.amax(k) for k in seg_obj_locations]

                    size_x, size_y = max_x - min_x, max_y - min_y

                    if _gt_obj_id not in obj_id_to_seg_sizes:
                        obj_id_to_seg_bboxes[_gt_obj_id] = []
                        obj_id_to_seg_sizes[_gt_obj_id] = []
                        obj_id_to_seg_file_ids[_gt_obj_id] = []

                    obj_id_to_seg_file_ids[_gt_obj_id].append(
                        seq_seq_src_file_id)

                    file_id_to_seg[seq_seq_src_file_id][_gt_obj_id] = (
                        seg_obj_locations, [min_x, min_y, max_x, max_y])

                    obj_id_to_seg_bboxes[_gt_obj_id].append(
                        [seq_seq_src_file, min_x, min_y, max_x, max_y])
                    obj_id_to_seg_sizes[_gt_obj_id].append([size_x, size_y])

                    all_seg_sizes.append([size_x, size_y])

            obj_id_to_mean_seg_sizes = OrderedDict({
                k: np.mean(v, axis=0)
                for k, v in obj_id_to_seg_sizes.items()
            })
            obj_id_to_max_seg_sizes = OrderedDict({
                k: np.amax(v, axis=0)
                for k, v in obj_id_to_seg_sizes.items()
            })

            print('segmentations found for {} files'.format(
                len(file_id_to_seg),
                # '\n'.join(file_id_to_seg.keys())
            ))
            print('segmentations include {} objects:\n{}'.format(
                len(obj_id_to_seg_bboxes),
                ', '.join(str(k) for k in obj_id_to_seg_bboxes.keys())))

            mean_seg_sizes = np.mean(all_seg_sizes, axis=0)
            max_seg_sizes = np.amax(all_seg_sizes, axis=0)

            for obj_id in obj_id_to_seg_file_ids:
                seg_file_ids = obj_id_to_seg_file_ids[obj_id]
                seg_file_ids_num = np.asarray(
                    list(int(k) for k in seg_file_ids))
                gt_file_ids = obj_id_to_gt_file_ids[obj_id]
                # gt_file_ids_num = np.asarray(list(int(k) for k in gt_file_ids))
                gt_seg_file_ids_dist = {
                    gt_file_id: np.abs(int(gt_file_id) - seg_file_ids_num)
                    for gt_file_id in gt_file_ids
                }
                file_id_to_nearest_seg[obj_id] = {
                    gt_file_id: seg_file_ids[np.argmin(_dist).item()]
                    for gt_file_id, _dist in gt_seg_file_ids_dist.items()
                }

        nearest_seg_size = OrderedDict()

        for frame_id in tqdm(range(n_frames), disable=_params.disable_tqdm):
            seq_img_src_file = seq_img_src_files[frame_id]

            # assert seq_img_src_file in seq_gt_src_file, \
            #     "mismatch between seq_img_src_file and seq_gt_src_file"

            seq_img_src_file_id = ''.join(k for k in seq_img_src_file
                                          if k.isdigit())

            seq_img_src_path = os.path.join(seq_img_path, seq_img_src_file)

            # seq_img_pil = Image.open(seq_img_src_path)
            # seq_img = np.array(seq_img_pil)

            seq_img = cv2.imread(seq_img_src_path, cv2.IMREAD_UNCHANGED)
            # assert (seq_img == seq_img_cv).all(), "mismatch between PIL and cv2 arrays"
            # seq_img_cv_unique = np.unique(seq_img_cv)
            # n_seq_img_unique_cv = len(seq_img_cv_unique)

            # seq_img_float = seq_img.astype(np.float32) / 65535.

            max_pix, min_pix = np.amax(seq_img), np.amin(seq_img)

            seq_img_float_norm = (seq_img.astype(np.float32) -
                                  min_pix) / (max_pix - min_pix)
            seq_img_uint8 = (seq_img_float_norm * 255.).astype(np.uint8)
            # seq_img_uint8 = (seq_img / 256.).astype(np.uint8)
            max_pix_uint8, min_pix_uint8 = np.amax(seq_img_uint8), np.amin(
                seq_img_uint8)

            seq_img_unique = np.unique(seq_img)
            seq_img_unique_uint8 = np.unique(seq_img_uint8)

            n_seq_img_unique = len(seq_img_unique)
            n_seq_img_unique_uint8 = len(seq_img_unique_uint8)

            if n_seq_img_unique > n_seq_img_unique_uint8:
                # print('{} :: drop in number of unique values from {} ({}, {}) to {} ({}, {})'.format(
                #     seq_img_src_file, n_seq_img_unique, max_pix, min_pix,
                #     n_seq_img_unique_uint8, max_pix_uint8, min_pix_uint8))
                with open(log_path, 'a') as log_fid:
                    log_fid.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(
                        seq_name, seq_img_src_file, n_seq_img_unique,
                        n_seq_img_unique_uint8, max_pix, min_pix,
                        max_pix_uint8, min_pix_uint8))
                # print()

            if write_img:
                # out_img_file = os.path.splitext(seq_img_src_file)[0] + '.png'
                # out_img_file_path = linux_path(out_img_tif_dir_path, out_img_file)

                out_img_file_tif = os.path.splitext(
                    seq_img_src_file)[0] + '.tif'
                out_img_file_path_tif = linux_path(out_img_tif_dir_path,
                                                   out_img_file_tif)

                out_img_file_uint8 = os.path.splitext(
                    seq_img_src_file)[0] + '.jpg'
                out_img_file_path_uint8 = linux_path(out_img_jpg_dir_path,
                                                     out_img_file_uint8)

                if not os.path.exists(out_img_file_path_uint8):
                    # cv2.imwrite(out_img_file_path, seq_img)
                    cv2.imwrite(out_img_file_path_uint8, seq_img_uint8)

                if not os.path.exists(out_img_file_path_tif):
                    # print('{} --> {}'.format(seq_img_src_path, out_img_file_path_tif))
                    shutil.copyfile(seq_img_src_path, out_img_file_path_tif)

            if show_img:
                seq_img_col = seq_img_uint8.copy()
                if len(seq_img_col.shape) == 2:
                    seq_img_col = cv2.cvtColor(seq_img_col, cv2.COLOR_GRAY2BGR)

                # seq_img_col2 = seq_img_col.copy()

                seq_img_col3 = seq_img_col.copy()

            if not gt_available:
                if save_img:
                    if vid_out is not None:
                        vid_out.write(seq_img_col)
                continue

            file_gt = file_id_to_gt[seq_img_src_file_id]

            # seq_gt_src_file = seq_gt_src_files[frame_id]
            # assert seq_gt_src_file_id == seq_img_src_file_id, \
            #     "Mismatch between seq_gt_src_file_id and seq_img_src_file_id"

            gt_obj_ids = list(file_gt.keys())
            for obj_id in gt_obj_ids:
                assert obj_id != 0, "invalid object ID"

                try:
                    nearest_seg_file_ids = file_id_to_nearest_seg[obj_id]
                except KeyError:
                    file_seg = {}
                else:
                    nearest_seg_file_id = nearest_seg_file_ids[
                        seq_img_src_file_id]
                    file_seg = file_id_to_seg[nearest_seg_file_id]

                obj_locations, centroid_y, centroid_x = file_gt[obj_id]

                if file_seg:
                    xmin, ymin, xmax, ymax = file_seg[obj_id][1]
                    size_x, size_y = xmax - xmin, ymax - ymin
                    nearest_seg_size[obj_id] = (size_x, size_y)
                else:
                    try:
                        size_x, size_y = nearest_seg_size[obj_id]
                    except KeyError:
                        try:
                            size_x, size_y = obj_id_to_mean_seg_sizes[obj_id]
                        except KeyError:
                            if mean_seg_sizes is not None:
                                size_x, size_y = mean_seg_sizes
                            else:
                                size_x, size_y = default_obj_size, default_obj_size

                    ymin, xmin = centroid_y - size_y / 2.0, centroid_x - size_x / 2.0
                    ymax, xmax = centroid_y + size_y / 2.0, centroid_x + size_x / 2.0

                width = int(xmax - xmin)
                height = int(ymax - ymin)

                if show_img:
                    col_id = (obj_id - 1) % len(ann_cols)

                    col = col_rgb[ann_cols[col_id]]
                    drawBox(seq_img_col,
                            xmin,
                            ymin,
                            xmax,
                            ymax,
                            label=str(obj_id),
                            box_color=col)
                    # seq_img_col2[obj_locations] = col

                    if file_seg:
                        try:
                            locations, bbox = file_seg[obj_id][:2]
                        except KeyError:
                            print('weird stuff going on here')
                        else:
                            seq_img_col3[locations] = col
                            # min_x, min_y, max_x, max_y = bbox
                            # drawBox(seq_img_col3, min_x, min_y, max_x, max_y, label=str(obj_id), box_color=col)

                if write_gt:
                    out_gt_fid.write(
                        '{:d},{:d},{:.3f},{:.3f},{:d},{:d},1,-1,-1,-1\n'.
                        format(frame_id + 1, obj_id, xmin, ymin, width,
                               height))
                # print()
            skip_seq = 0

            if show_img:
                # images_to_stack = [seq_img_col, seq_img_col2]
                images_to_stack = [
                    seq_img_col,
                ]
                if file_seg:
                    images_to_stack.append(seq_img_col3)
                    __pause = _pause
                else:
                    __pause = _pause

                seq_img_vis = stackImages(images_to_stack, sep_size=5)

                seq_img_vis = resizeAR(seq_img_vis,
                                       height=vis_height,
                                       width=vis_width)

                cv2.putText(seq_img_vis,
                            '{}: {}'.format(seq_name,
                                            seq_img_src_file_id), (20, 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

                if show_img != 2:
                    cv2.imshow('seq_img_vis', seq_img_vis)
                    k = cv2.waitKey(1 - __pause)
                    if k == 32:
                        _pause = 1 - _pause
                    elif k == 27:
                        skip_seq = 1
                        break
                    elif k == ord('q'):
                        _exit = 1
                        break

                if save_img:
                    if vid_out is not None:
                        vid_out.write(seq_img_vis)
                    else:
                        out_vis_file = os.path.splitext(
                            seq_img_src_file)[0] + '.jpg'
                        out_vis_file_path = linux_path(out_vis_path,
                                                       out_vis_file)
                        cv2.imwrite(out_vis_file_path, seq_img_vis)

            if skip_seq or _exit:
                break
        if vid_out is not None:
            vid_out.release()

        if out_gt_fid is not None:
            out_gt_fid.close()

        if _exit:
            break

    n_frames_out_fid.close()
Example 12
def main():
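    # Minimal paramparse usage: populate Params from the command line.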
    params = Params()
    paramparse.process(params)
Example 13
def main():
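    # Train a font classifier: split FontsDataset into training / validation
    # subsets, optionally resume from the latest matching checkpoint, log
    # losses and accuracies to tensorboard, save checkpoints according to the
    # chosen criterion, and finally time an evaluation on the test set.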
    params = A10_Params()
    train_params = params.train

    # optional command line argument parsing
    try:
        import paramparse
    except ImportError:
        pass
    else:
        paramparse.process(train_params)

    # init device
    if params.use_cuda and torch.cuda.is_available():
        device = torch.device("cuda")
        print('Training on GPU: {}'.format(torch.cuda.get_device_name(0)))
    else:
        device = torch.device("cpu")
        print('Training on CPU')

    # load dataset
    train_set = FontsDataset()

    num_train = len(train_set)
    indices = list(range(num_train))

    assert train_params.valid_ratio > 0, "Zero validation ratio is not allowed"
    split = int(np.floor((1.0 - train_params.valid_ratio) * num_train))

    train_idx, valid_idx = indices[:split], indices[split:]

    print('Training samples: {}\n'
          'Validation samples: {}\n'
          ''.format(
              len(train_idx),
              len(valid_idx),
          ))
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_dataloader = torch.utils.data.DataLoader(
        train_set,
        batch_size=train_params.batch_size,
        sampler=train_sampler,
        num_workers=4)
    valid_dataloader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=24,
                                                   sampler=valid_sampler,
                                                   num_workers=4)

    # create modules
    classifier = Classifier().to(device)

    assert isinstance(classifier,
                      nn.Module), 'classifier must be an instance of nn.Module'

    classifier.init_weights()

    # create losses
    criterion = torch.nn.CrossEntropyLoss().to(device)

    parameters = classifier.parameters()

    # create optimizer
    if train_params.optim_type == 0:
        optimizer = torch.optim.SGD(parameters,
                                    lr=train_params.lr,
                                    momentum=train_params.momentum,
                                    weight_decay=train_params.weight_decay)
    elif train_params.optim_type == 1:
        optimizer = torch.optim.Adam(
            parameters,
            lr=train_params.lr,
            weight_decay=train_params.weight_decay,
            eps=train_params.eps,
        )
    else:
        raise IOError('Invalid optim_type: {}'.format(train_params.optim_type))

    # optimizer = torch.optim.Adam(classifier.parameters())

    weights_dir = os.path.dirname(train_params.weights_path)
    weights_name = os.path.basename(train_params.weights_path)

    if not os.path.isdir(weights_dir):
        os.makedirs(weights_dir)

    tb_path = os.path.join(weights_dir, 'tb')
    if not os.path.isdir(tb_path):
        os.makedirs(tb_path)
    writer = SummaryWriter(logdir=tb_path)

    print(f'Saving tensorboard summary to: {tb_path}')

    start_epoch = 0
    max_valid_acc_epoch = 0
    max_valid_acc = 0
    max_train_acc = 0
    min_valid_loss = np.inf
    min_train_loss = np.inf
    valid_loss = valid_acc = -1

    # load weights
    if train_params.load_weights:
        matching_ckpts = [
            k for k in os.listdir(weights_dir)
            if os.path.isfile(os.path.join(weights_dir, k))
            and k.startswith(weights_name)
        ]
        if not matching_ckpts:
            msg = 'No checkpoints found matching {} in {}'.format(
                weights_name, weights_dir)
            if train_params.load_weights == 1:
                raise IOError(msg)
            print(msg)
        else:
            matching_ckpts.sort(
                key=lambda x:
                [int(c) if c.isdigit() else c for c in re.split(r'(\d+)', x)])

            weights_path = os.path.join(weights_dir, matching_ckpts[-1])

            chkpt = torch.load(weights_path,
                               map_location=device)  # load checkpoint

            print('Loading weights from: {} with:\n'
                  '\tepoch: {}\n'
                  '\ttrain_loss: {}\n'
                  '\ttrain_acc: {}\n'
                  '\tvalid_loss: {}\n'
                  '\tvalid_acc: {}\n'
                  '\ttimestamp: {}\n'.format(weights_path, chkpt['epoch'],
                                             chkpt['train_loss'],
                                             chkpt['train_acc'],
                                             chkpt['valid_loss'],
                                             chkpt['valid_acc'],
                                             chkpt['timestamp']))

            classifier.load_state_dict(chkpt['classifier'])
            optimizer.load_state_dict(chkpt['optimizer'])

            max_valid_acc = chkpt['valid_acc']
            min_valid_loss = chkpt['valid_loss']

            max_train_acc = chkpt['train_acc']
            min_train_loss = chkpt['train_loss']

            max_valid_acc_epoch = chkpt['epoch']
            start_epoch = chkpt['epoch'] + 1

    if train_params.load_weights != 1:
        # start / continue training
        for epoch in range(start_epoch, train_params.n_epochs):
            # set CNN to training mode
            classifier.train()

            train_loss = 0
            train_total = 0
            train_correct = 0
            batch_idx = 0

            save_weights = 0

            for batch_idx, (inputs,
                            targets) in tqdm(enumerate(train_dataloader)):
                inputs = inputs.to(device)
                targets = targets.to(device)

                optimizer.zero_grad()

                outputs = classifier(inputs)

                loss = criterion(outputs, targets)

                mean_loss = loss.item()
                train_loss += mean_loss

                loss.backward()
                optimizer.step()

                _, predicted = outputs.max(1)
                train_total += targets.size(0)
                train_correct += predicted.eq(targets).sum().item()

            mean_train_loss = train_loss / (batch_idx + 1)

            train_acc = 100. * train_correct / train_total

            # write training data for tensorboard
            writer.add_scalar('train_loss', mean_train_loss, epoch)
            writer.add_scalar('train_acc', train_acc, epoch)

            if epoch % train_params.valid_gap == 0:

                valid_loss, valid_acc, _ = evaluate(classifier,
                                                    valid_dataloader,
                                                    criterion,
                                                    train_params.vis, device)

                if valid_acc > max_valid_acc:
                    max_valid_acc = valid_acc
                    max_valid_acc_epoch = epoch
                    if train_params.save_criterion == 0:
                        save_weights = 1

                if valid_loss < min_valid_loss:
                    min_valid_loss = valid_loss
                    if train_params.save_criterion == 1:
                        save_weights = 1

                if train_acc > max_train_acc:
                    max_train_acc = train_acc
                    if train_params.save_criterion == 2:
                        save_weights = 1

                if train_loss < min_train_loss:
                    min_train_loss = train_loss
                    if train_params.save_criterion == 3:
                        save_weights = 1

                # write validation data for tensorboard
                writer.add_scalar('valid_loss', valid_loss, epoch)
                writer.add_scalar('valid_acc', valid_acc, epoch)

                print('Epoch: %d Train-Loss: %.6f  | Train-Acc: %.3f%% | '
                      'Validation-Loss: %.6f | Validation-Acc: %.3f%% | '
                      'Max Validation-Acc: %.3f%% (epoch: %d)' %
                      (epoch, mean_train_loss, train_acc, valid_loss,
                       valid_acc, max_valid_acc, max_valid_acc_epoch))

            # Save checkpoint.
            if save_weights:
                model_dict = {
                    'classifier': classifier.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'train_loss': mean_train_loss,
                    'train_acc': train_acc,
                    'valid_loss': valid_loss,
                    'valid_acc': valid_acc,
                    'epoch': epoch,
                    'timestamp': datetime.now().strftime("%y/%m/%d %H:%M:%S"),
                }
                weights_path = '{}.{:d}'.format(train_params.weights_path,
                                                epoch)
                print('Saving weights to {}'.format(weights_path))
                torch.save(model_dict, weights_path)

    if params.enable_test:
        test_set = FontsDataset('test_data.npz')
        test_dataloader = torch.utils.data.DataLoader(test_set,
                                                      batch_size=24,
                                                      num_workers=4)
    else:
        test_dataloader = valid_dataloader

    start_t = time.time()
    _, test_acc, n_test = evaluate(classifier, test_dataloader, criterion,
                                   train_params.vis, device)
    end_t = time.time()
    test_time = end_t - start_t
    fps = n_test / test_time

    print('test accuracy: {:.4f}%'.format(test_acc))
    print(
        'test time: {:.4f} sec. with {:d} images ({:.4f} images / sec)'.format(
            test_time, n_test, fps))
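The evaluate helper called above is not shown here; a minimal hypothetical sketch consistent with its call sites (it must return mean loss, accuracy in percent, and the sample count; the vis argument is accepted but unused in this sketch):

# hypothetical sketch of the evaluate() helper used above; the original
# implementation is not part of this example
import torch


def evaluate(classifier, dataloader, criterion, vis, device):
    classifier.eval()
    total_loss = 0.
    correct = total = batch_idx = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = classifier(inputs)
            total_loss += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    mean_loss = total_loss / (batch_idx + 1)
    acc = 100. * correct / total
    return mean_loss, acc, total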
Example 14
def main():
    global _pause, _quit
    params = VisParams()
    paramparse.process(params)

    seq_paths = params.seq_paths
    root_dir = params.root_dir
    csv_paths = params.csv_paths
    csv_root_dir = params.csv_root_dir
    class_names_path = params.class_names_path
    data_type = params.data_type
    n_frames = params.n_frames
    seq_prefix = params.seq_prefix
    n_vis = params.n_vis
    vis_size = params.vis_size
    enable_masks = params.enable_masks
    show_img = params.show_img
    save = params.save
    save_fmt = params.save_fmt
    save_dir = params.save_dir
    labels = params.labels
    grid_size = params.grid_size
    only_boxes = params.only_boxes
    crop_size = params.crop_size

    if crop_size:
        crop_size = tuple([int(x) for x in crop_size.split('x')])
        print('Cropping a region of size {}x{} around the box'.format(
            *crop_size))
    else:
        crop_size = ()

    if grid_size:
        grid_size = [int(k) for k in grid_size.split('x')]
        print('Using a grid size of {}x{}'.format(*grid_size))
    else:
        grid_size = None

    if vis_size:
        vis_size = [int(x) for x in vis_size.split('x')]

    # setup logger
    logging_fmt = '%(levelname)s::%(module)s::%(funcName)s::%(lineno)s :  %(message)s'
    logging_level = logging.INFO
    logging.basicConfig(level=logging_level, format=logging_fmt)
    _logger = logging.getLogger()

    if seq_paths:
        if os.path.isfile(seq_paths):
            seq_paths = [
                x.strip() for x in open(seq_paths).readlines() if x.strip()
            ]
        else:
            seq_paths = seq_paths.split(',')
        if root_dir:
            seq_paths = [os.path.join(root_dir, name) for name in seq_paths]

    elif root_dir:
        seq_paths = [
            os.path.join(root_dir, name) for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        ]
        seq_paths.sort(key=sortKey)
    else:
        raise IOError('Either seq_paths or root_dir must be provided')

    if csv_paths:
        if os.path.isfile(csv_paths):
            csv_paths = [
                x.strip() for x in open(csv_paths).readlines() if x.strip()
            ]
        else:
            csv_paths = csv_paths.split(',')
        if csv_root_dir:
            csv_paths = [
                os.path.join(csv_root_dir, name) for name in csv_paths
            ]
    elif csv_root_dir:
        csv_paths = [
            os.path.join(csv_root_dir, name)
            for name in os.listdir(csv_root_dir)
            if os.path.isfile(os.path.join(csv_root_dir, name))
            and name.endswith('.csv')
        ]
        csv_paths.sort(key=sortKey)
    else:
        csv_paths = [
            os.path.join(seq_path, data_type + '.csv')
            for seq_path in seq_paths
        ]

    seq_path_ids = []

    if seq_prefix:
        seq_path_ids = [
            _id for _id, seq_path in enumerate(seq_paths)
            if os.path.basename(seq_path).startswith(seq_prefix)
        ]
        seq_paths = [seq_paths[_id] for _id in seq_path_ids]
        csv_paths = [csv_paths[_id] for _id in seq_path_ids]

    n_seq, n_csv = len(seq_paths), len(csv_paths)
    if n_seq != n_csv:
        raise IOError(
            'Mismatch between image {} and annotation {} lengths'.format(
                n_seq, n_csv))

    class_names = open(class_names_path, 'r').readlines()
    class_dict = {x.strip(): i for (i, x) in enumerate(class_names)}
    print('class_dict: ', class_dict)
    print('labels: ', labels)

    if n_vis > 0:
        if save:
            save_fname = '{:s}_{:s}.{:s}'.format(save_dir, getDateTime(),
                                                 save_fmt)
            save_path = os.path.join('log', save_fname)
            writer = ImageWriter(save_path, _logger)
            _logger.info('Saving {:s} image sequence to {:s}'.format(
                save_fmt, save_path))

        if n_seq % n_vis != 0:
            raise AssertionError('n_seq: {} not multiple of n_vis: {}'.format(
                n_seq, n_vis))
        n_groups = int(n_seq / n_vis)
        seq_id = 0
        label = ''
        for i in range(n_groups):

            vis_gen = []
            for j in range(n_vis):
                if labels:
                    label = labels[j]
                vis_gen.append(
                    visualize(params.vis,
                              _logger,
                              seq_paths[seq_id],
                              csv_paths[seq_id],
                              class_dict,
                              n_frames=n_frames,
                              generator_mode=1,
                              enable_masks=enable_masks,
                              label=label,
                              only_boxes=only_boxes,
                              crop_size=crop_size))
                seq_id += 1
            for imgs in zip(*vis_gen):
                stack_params = {'grid_size': grid_size, 'preserve_order': 1}
                if crop_size:
                    stack_params['annotations'] = labels
                img_stacked = stackImages_ptf(imgs, **stack_params)
                img_stacked = cv2.cvtColor(img_stacked, cv2.COLOR_RGB2BGR)

                if vis_size:
                    img_stacked = resizeAR(img_stacked, vis_size[0],
                                           vis_size[1])
                if save:
                    writer.write(img_stacked)

                if show_img:
                    cv2.imshow('img_stacked', img_stacked)
                    key = cv2.waitKey(1 - _pause) % 256
                    if key == 27:
                        break
                    elif key == ord('q'):
                        _quit = 1
                        break
                    elif key == 32:
                        _pause = 1 - _pause
            if _quit:
                break
        if save:
            writer.release()

    else:
        for i in range(n_seq):
            visualize(params.vis,
                      _logger,
                      seq_paths[i],
                      csv_paths[i],
                      class_dict,
                      n_frames=n_frames,
                      enable_masks=enable_masks)
            if _quit:
                break
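The grouped visualization above hinges on zipping several frame generators so that corresponding frames from every sequence arrive together; a self-contained toy illustration of that pattern (frame_gen stands in for visualize(..., generator_mode=1)):

# zip(*generators) pulls one frame from every generator per iteration, so
# frame k of each sequence gets stacked with frame k of all the others
import numpy as np


def frame_gen(seq_id, n_frames=3, h=4, w=4):
    # stand-in for a per-sequence frame generator
    for frame_id in range(n_frames):
        yield np.full((h, w), 10 * seq_id + frame_id, dtype=np.uint8)


vis_gen = [frame_gen(seq_id) for seq_id in range(2)]
for imgs in zip(*vis_gen):
    img_stacked = np.hstack(imgs)  # current frame of every sequence, side by side
    print(img_stacked.shape, img_stacked[0, 0], img_stacked[0, -1])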
Example 15
def main():
    params = Params()
    paramparse.process(params)

    in_txt = copy_from_clipboard()
    if in_txt is None:
        lines = []
    else:
        lines = in_txt.split('\n')
        lines = [line.strip() for line in lines if line.strip()]

    if lines:
        is_path = all(
            line.startswith('"') and line.endswith('"') for line in lines)
        if is_path:
            stripped_lines = [line.strip('"') for line in lines]
        else:
            stripped_lines = lines[:]

        try:
            is_ogg = all(line.endswith('.ogg') for line in stripped_lines)
            is_folder = all(os.path.isdir(line) for line in stripped_lines)
        except BaseException as e:
            print('exception during ogg / folder check : {}'.format(e))
            is_ogg = 0
            is_folder = 0

        print('is_ogg: {}'.format(is_ogg))
        print('stripped_lines: {}'.format(stripped_lines))
        print('lines: {}'.format(lines))

        if is_ogg:
            process_ogg(stripped_lines,
                        lines,
                        params.category,
                        is_path,
                        params.cmd,
                        pause_for_input=1)
            return
        elif is_folder:
            stripped_lines.sort()
            for folder in stripped_lines:
                ogg_paths = [
                    '{}'.format(os.path.join(folder, k))
                    for k in os.listdir(folder) if k.endswith('.ogg')
                ]
                ogg_lines = ['"{}"'.format(k) for k in ogg_paths]
                process_ogg(ogg_paths,
                            ogg_lines,
                            params.category,
                            is_path,
                            params.cmd,
                            pause_for_input=0)
            return
        else:
            try:
                out_txt = process(in_txt)
            except BaseException as e:
                print('process failed: {}'.format(e))
            else:
                copy_to_clipboard(out_txt, print_txt=1)
                time.sleep(0.5)
                return

    if params.ffs.enable:
        for _ffs_file in params.ffs.files:
            ffs_path = os.path.join(params.ffs.root,
                                    _ffs_file + '.' + params.ffs.ext)
            ffs_cmd = '{} "{}"'.format(params.ffs.exe, ffs_path)
            print(ffs_cmd)
            os.system(ffs_cmd)

    if params.txt_path:

        assert os.path.isdir(params.txt_path), "invalid text path: {}".format(
            params.txt_path)

        if params.recursive:
            files_gen = [[
                linux_path(dirpath, f) for f in filenames
                if f.endswith('.txt') and f.startswith('Timing')
            ] for (dirpath, dirnames,
                   filenames) in os.walk(params.txt_path, followlinks=True)]
            files = [item for sublist in files_gen for item in sublist]
        else:
            files = os.listdir(params.txt_path)
            files = [
                linux_path(params.txt_path, k) for k in files
                if k and k.endswith('.txt')
            ]

        txt_proc_list_path = linux_path(params.txt_path, params.txt_proc_list)

        if os.path.isfile(txt_proc_list_path):
            processed_files = open(txt_proc_list_path, 'r').readlines()
            processed_files = [
                k.strip().split('\t')[1] for k in processed_files if k.strip()
            ]

            files = [k for k in files if k not in processed_files]

        files.sort(key=os.path.getmtime)
        n_files = len(files)

        if n_files > 0:
            _ = input('\nfound {} new files:\n{}\npress enter to continue\n'.
                      format(n_files, files))
        else:
            _ = input('\nfound no new files. press enter to exit\n')

        for file_id, file in enumerate(files[::-1]):
            if file_id > 0:
                _ = input('\nDone {} / {}. press enter to continue\n'.format(
                    file_id, n_files))

            print('reading file {} / {}: {}'.format(file_id + 1, n_files,
                                                    file))

            in_txt = open(file, 'r').read()
            out_txt = process(in_txt, verbose=0)
            print(out_txt)

            copy_to_clipboard(out_txt)
            time.sleep(0.5)

            out_txt_lines = [k for k in out_txt.split('\n') if k]

            n_out_txt_lines = len(out_txt_lines)

            print('out_txt_lines: {}'.format(out_txt_lines))
            print('n_out_txt_lines: {}'.format(n_out_txt_lines))

            if n_out_txt_lines == 1:
                os.system("vscode {}".format(file))

            # 'r+' requires an existing file; fall back to 'w+' on the first run
            proc_mode = 'r+' if os.path.isfile(txt_proc_list_path) else 'w+'
            with open(txt_proc_list_path, proc_mode) as f:
                content = f.read()
                f.seek(0, 0)

                timestamp_str = datetime.now().strftime(
                    "%y%m%d %H:%M:%S.%f")[:-4]

                txt = '{}\t{}\n'.format(timestamp_str, file)
                f.write(txt + content)

        return

    try:
        orig_x, orig_y = win32api.GetCursorPos()
        print('GetCursorPos x: {}'.format(orig_x))
        print('GetCursorPos y: {}'.format(orig_y))

        win32gui.EnumWindows(foreach_window, None)

        target_title = [
            k[1] for k in titles
            if all(title in k[1] for title in params.win_titles)
        ]

        if not target_title:
            raise IOError('Window with win_titles: {} not found'.format(
                params.win_titles))

        target_title = target_title[0]

        target_handle = win32gui.FindWindow(None, target_title)
        rect = win32gui.GetWindowRect(target_handle)

        x = int((rect[0] + rect[2]) / 2)
        y = int((rect[1] + rect[3]) / 2)

        print('target_title: {}'.format(target_title))
        print('rect: {}'.format(rect))
        print('x: {}'.format(x))
        print('y: {}'.format(y))

        try:
            app = application.Application().connect(title=target_title,
                                                    found_index=0)
        except BaseException as e:
            print('Failed to connect to app for window {}: {}'.format(
                target_title, e))
            exit(0)
        try:
            app_win = app.window(title=target_title)
        except BaseException as e:
            print('Failed to access app window for {}: {}'.format(
                target_title, e))
            exit(0)
        app_win.type_keys("^a")
        app_win.type_keys("^c")

        mouse.move(coords=(x, y))

        mouse.click(button='left', coords=(x, y))
        mouse.click(button='left', coords=(x, y))

        mouse.move(coords=(orig_x, orig_y))

    except BaseException as e:
        print('BaseException: {}'.format(e))

    in_txt = copy_from_clipboard()
    out_txt = process(in_txt)

    copy_to_clipboard(out_txt, print_txt=1)
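The clipboard helpers copy_from_clipboard and copy_to_clipboard used above are not shown; hypothetical sketches built on the standard-library tkinter clipboard, as one possible implementation:

# hypothetical clipboard helpers; the originals are not part of this example
from tkinter import Tk


def copy_from_clipboard():
    try:
        return Tk().clipboard_get()
    except BaseException as e:
        print('clipboard_get failed: {}'.format(e))
        return None


def copy_to_clipboard(txt, print_txt=0):
    root = Tk()
    root.withdraw()
    root.clipboard_clear()
    root.clipboard_append(txt)
    # without a clipboard manager the content may vanish once the Tk root is
    # destroyed, so the update() here only flushes the request
    root.update()
    if print_txt:
        print(txt)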
Example 16
def main():
    params = Params()
    paramparse.process(params)

    _convert_dataset(params)
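The Params classes handed to paramparse.process throughout these examples are plain attribute containers: paramparse exposes each attribute as a command-line argument of the same name. A minimal hypothetical container (the attribute names here are illustrative, not the originals):

import paramparse


class Params:
    """hypothetical params container; paramparse turns each attribute into a
    command-line argument, e.g. --src_dir or --out_size"""

    def __init__(self):
        self.src_dir = ''
        self.out_size = 224
        self.verbose = 0


def main():
    params = Params()
    paramparse.process(params)
    print(params.src_dir, params.out_size, params.verbose)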
Example 17
def main():
    params = A9_Params()

    # optional command line argument parsing
    try:
        import paramparse
    except ImportError:
        pass
    else:
        paramparse.process(params)

    # init device
    if params.use_cuda and torch.cuda.is_available():
        device = torch.device("cuda")
        print('Training on GPU: {}'.format(torch.cuda.get_device_name(0)))
    else:
        device = torch.device("cpu")
        print('Training on CPU')

    # load dataset
    if params.dataset == 0:
        print('Using MNIST dataset')
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,)),
                                        ])
        train_set = datasets.MNIST('data', train=True, download=True, transform=transform)
        test_set = datasets.MNIST('data', train=False, download=True, transform=transform)
        valid_set = datasets.MNIST('data', train=True, download=True, transform=transform)
        train_params = params.mnist
    elif params.dataset == 1:
        print('Using Fashion MNIST dataset')
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (0.5,)),
                                        ])
        train_set = datasets.FashionMNIST('data', train=True, download=True, transform=transform)
        test_set = datasets.FashionMNIST('data', train=False, download=True, transform=transform)
        valid_set = datasets.FashionMNIST('data', train=True, download=True, transform=transform)
        train_params = params.fmnist
    else:
        raise IOError('Invalid dataset: {}'.format(params.dataset))

    num_train = len(train_set)
    indices = list(range(num_train))
    split = int(np.floor(params.train_split * num_train))

    train_idx, valid_idx = indices[:split], indices[split:]
    train_set = PartiallyLabeled(train_set, train_idx, labeled_percent=params.labeled_split)

    print('Training samples: {}\n'
          'Validation samples: {}\n'
          'Labeled training samples: {}'
          ''.format(
        len(train_idx),
        len(valid_idx),
        train_set.n_labeled_data
    ))

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SequentialSampler(valid_idx)

    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=train_params.batch_size, sampler=train_sampler,
                                                   num_workers=4)
    valid_dataloader = torch.utils.data.DataLoader(valid_set, batch_size=24, sampler=valid_sampler,
                                                   num_workers=4)
    test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=10, shuffle=False, num_workers=4)
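    # train_set and valid_set wrap the same underlying data; the disjoint
    # train_idx / valid_idx lists handed to the two samplers are what keep the
    # splits from overlapping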

    # create modules
    encoder = Encoder(device).to(device)
    decoder = Decoder(device).to(device)
    classifier = Classifier(device).to(device)

    assert isinstance(encoder, nn.Module), 'encoder must be an instance of nn.Module'
    assert isinstance(decoder, nn.Module), 'decoder must be an instance of nn.Module'
    assert isinstance(classifier, nn.Module), 'classifier must be an instance of nn.Module'

    modules = nn.ModuleList((encoder, decoder, classifier))

    # init weights
    encoder.init_weights()
    decoder.init_weights(encoder.get_weights())
    classifier.init_weights()

    # create losses
    criterion_rec = torch.nn.MSELoss().to(device)
    criterion_cls = torch.nn.CrossEntropyLoss().to(device)

    parameters = list(modules.parameters())
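    # c0 == 0 selects a learned trade-off: CompositeLoss is an nn.Module whose
    # weighting parameters are appended to `parameters` and trained jointly;
    # any other c0 fixes the combination to loss_rec + c0 * loss_cls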
    if train_params.c0 == 0:
        composite_loss = CompositeLoss(device)
        composite_loss.init_weights()
        assert isinstance(composite_loss, nn.Module), 'composite_loss must be an instance of nn.Module'
        parameters += list(composite_loss.parameters())
    else:
        def composite_loss(x, y):
            return x + train_params.c0 * y

    # create optimizer
    if train_params.optim_type == 0:
        optimizer = torch.optim.SGD(parameters, lr=train_params.lr, momentum=train_params.momentum,
                                    weight_decay=train_params.weight_decay)
    elif train_params.optim_type == 1:
        optimizer = torch.optim.Adam(parameters, lr=train_params.lr, weight_decay=train_params.weight_decay)
    else:
        raise IOError('Invalid optim_type: {}'.format(train_params.optim_type))

    weights_dir = os.path.dirname(train_params.weights_path)
    weights_name = os.path.basename(train_params.weights_path)

    if not os.path.isdir(weights_dir):
        os.makedirs(weights_dir)

    start_epoch = 0
    max_valid_acc_epoch = 0
    max_valid_acc = 0
    max_train_acc = 0
    min_valid_loss = np.inf
    min_train_loss = np.inf

    # load weights
    if train_params.load_weights:
        matching_ckpts = [k for k in os.listdir(weights_dir) if
                          os.path.isfile(os.path.join(weights_dir, k)) and
                          k.startswith(weights_name)]
        if not matching_ckpts:
            msg = 'No checkpoints found matching {} in {}'.format(weights_name, weights_dir)
            if train_params.load_weights == 1:
                raise IOError(msg)
            print(msg)
        else:
            matching_ckpts.sort(key=lambda x: [int(c) if c.isdigit() else c for c in re.split(r'(\d+)', x)])

            weights_path = os.path.join(weights_dir, matching_ckpts[-1])

            chkpt = torch.load(weights_path, map_location=device)  # load checkpoint

            print('Loading weights from: {} with:\n'
                  '\tepoch: {}\n'
                  '\ttrain_loss: {}\n'
                  '\ttrain_acc: {}\n'
                  '\tvalid_loss: {}\n'
                  '\tvalid_acc: {}\n'
                  '\ttimestamp: {}\n'.format(
                weights_path, chkpt['epoch'],
                chkpt['train_loss'], chkpt['train_acc'],
                chkpt['valid_loss'], chkpt['valid_acc'],
                chkpt['timestamp']))

            encoder.load_state_dict(chkpt['encoder'])
            decoder.load_state_dict(chkpt['decoder'])
            classifier.load_state_dict(chkpt['classifier'])
            optimizer.load_state_dict(chkpt['optimizer'])

            if train_params.c0 == 0 and 'composite_loss' in chkpt:
                composite_loss.load_state_dict(chkpt['composite_loss'])

            max_valid_acc = chkpt['valid_acc']
            min_valid_loss = chkpt['valid_loss']

            max_train_acc = chkpt['train_acc']
            min_train_loss = chkpt['train_loss']

            max_valid_acc_epoch = chkpt['epoch']
            start_epoch = chkpt['epoch'] + 1

    if train_params.load_weights != 1:
        # train; load_weights == 1 is test-only mode that skips training and
        # uses the loaded checkpoint as-is
        for epoch in range(start_epoch, train_params.n_epochs):
            # Training
            modules.train()

            train_loss_rec = 0
            train_loss_cls = 0
            train_loss = 0
            train_total = 0
            train_correct = 0
            batch_idx = 0

            save_weights = 0

            for batch_idx, (inputs, targets, is_labeled) in tqdm(enumerate(train_dataloader)):
                inputs = inputs.to(device)
                targets = targets.to(device)

                # skip batches with no labeled samples: the classification loss
                # below indexes by is_labeled and is undefined on an empty selection
                if not is_labeled.any():
                    continue

                is_labeled = is_labeled.squeeze().to(device)

                optimizer.zero_grad()

                outputs_enc = encoder(inputs)
                outputs_rec = decoder(outputs_enc)
                outputs_cls = classifier(outputs_enc)

                loss_rec = criterion_rec(outputs_rec, inputs)
                loss_cls = criterion_cls(outputs_cls[is_labeled, :], targets[is_labeled])

                loss = composite_loss(loss_rec, loss_cls)

                mean_loss_rec = loss_rec.item()
                mean_loss_cls = loss_cls.item()
                train_loss_rec += mean_loss_rec
                train_loss_cls += mean_loss_cls

                loss.backward()
                optimizer.step()

                mean_loss = loss.item()
                train_loss += mean_loss

                _, predicted = outputs_cls.max(1)
                train_total += targets.size(0)
                train_correct += predicted.eq(targets).sum().item()

            mean_train_loss_rec = train_loss_rec / (batch_idx + 1)
            mean_train_loss_cls = train_loss_cls / (batch_idx + 1)
            mean_train_loss = train_loss / (batch_idx + 1)

            train_acc = 100. * train_correct / train_total

            valid_loss, valid_acc, valid_psnr = eval(
                modules, valid_dataloader, (criterion_rec, criterion_cls), train_params.vis, device)

            if valid_acc > max_valid_acc:
                max_valid_acc = valid_acc
                max_valid_acc_epoch = epoch
                if train_params.save_criterion == 0:
                    save_weights = 1

            if valid_loss < min_valid_loss:
                min_valid_loss = valid_loss
                if train_params.save_criterion == 1:
                    save_weights = 1

            if train_acc > max_train_acc:
                max_train_acc = train_acc
                if train_params.save_criterion == 2:
                    save_weights = 1

            if train_loss < min_train_loss:
                min_train_loss = train_loss
                if train_params.save_criterion == 3:
                    save_weights = 1

            print(
                'Epoch: %d Train-Loss: %.6f (rec: %.6f, cls: %.6f) | Train-Acc: %.3f%% | '
                'Validation-Loss: %.6f | Validation-Acc: %.3f%% | Validation-PSNR: %.3f | '
                'Max Validation-Acc: %.3f%% (epoch: %d)' % (
                    epoch, mean_train_loss, mean_train_loss_rec, mean_train_loss_cls, train_acc,
                    valid_loss, valid_acc, valid_psnr, max_valid_acc, max_valid_acc_epoch))

            # Save checkpoint.
            if save_weights:
                model_dict = {
                    'encoder': encoder.state_dict(),
                    'decoder': decoder.state_dict(),
                    'classifier': classifier.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'train_loss': mean_train_loss,
                    'train_acc': train_acc,
                    'valid_loss': valid_loss,
                    'valid_acc': valid_acc,
                    'epoch': epoch,
                    'timestamp': datetime.now().strftime("%y/%m/%d %H:%M:%S"),
                }
                if train_params.c0 == 0:
                    model_dict['composite_loss'] = composite_loss.state_dict()

                weights_path = '{}.{:d}'.format(train_params.weights_path, epoch)
                print('Saving weights to {}'.format(weights_path))
                torch.save(model_dict, weights_path)

    print('Testing...')
    start_t = time.time()
    test_loss, test_acc, test_psnr = eval(
        modules, test_dataloader, (criterion_rec, criterion_cls), train_params.vis, device)
    end_t = time.time()
    test_time = end_t - start_t

    print('Test-Loss: %.6f | Test-Acc: %.3f%% | Test-PSNR: %.3f | Test-Time: %.3f sec' % (
        test_loss, test_acc, test_psnr, test_time))
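The PSNR reported by eval above can be recovered from the reconstruction MSE; a minimal sketch, assuming images scaled so that peak is the maximum possible pixel value:

# minimal PSNR-from-MSE sketch; peak = 1.0 assumes unit-range images
import math


def psnr(mse, peak=1.0):
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(peak * peak / mse)


# e.g. a reconstruction MSE of 0.001 on unit-range images gives 30 dB
print('{:.2f} dB'.format(psnr(0.001)))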