Example #1
def vis_results(state, steps, *args, immediate=False, **kwargs):
    # Visualize distilled steps, either right away or through the background
    # visualization queue attached to `state`.
    if not state.get_output_flag():
        logging.warning('Skipping result visualization because output_flag is False')
        return

    if isinstance(steps[0][0], torch.Tensor):
        steps = to_np(steps)

    _, _, nc, input_size, _, (mean, std), label_names = datasets.get_info(state)
    dataset_vis_info = (state.dataset, nc, input_size, np.array(mean), np.array(std), label_names)

    vis_args = (steps, state.distilled_images_per_class_per_step, dataset_vis_info, state.arch, state.image_dpi) + args

    if not immediate:
        state.vis_queue.enqueue(_vis_results_fn, *vis_args, **kwargs)
    else:
        _vis_results_fn(*vis_args, **kwargs)
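The snippet above assumes module-level helpers such as `to_np` and `_vis_results_fn`. A minimal sketch of what a `to_np` converter could look like, assuming each step is a tuple of tensors (illustrative only, not the original implementation):

import torch

def to_np(steps):
    # Detach every tensor in each step and move it to a CPU numpy array so the
    # steps can be pickled and handed off to the visualization worker.
    return [tuple(x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x
                  for x in step)
            for step in steps]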
Example #2
    def set_state(self, state, dummy=False):
        if state.opt.sample_n_nets is None:
            state.opt.sample_n_nets = state.opt.n_nets

        base_dir = state.get_base_directory()
        save_dir = state.get_save_directory()

        state.opt.start_time = time.strftime(r"%Y-%m-%d %H:%M:%S")

        # Usually only rank 0 writes to files (exceptions: per-rank logging,
        # training many nets, etc.), so set the output flag before anything else.
        state.opt.distributed = state.world_size > 1
        if state.distributed:
            # read from os.environ
            def set_val_from_environ(key,
                                     save_obj,
                                     ty=str,
                                     fmt="distributed_{}"):
                if key not in os.environ:
                    raise ValueError(
                        "expected environment variable {} to be set when using distributed"
                        .format(key))
                setattr(save_obj, fmt.format(key.lower()), ty(os.environ[key]))

            set_val_from_environ("RANK", state, int, "world_rank")

            state.opt.distributed_file_init = 'INIT_FILE' in os.environ
            if state.opt.distributed_file_init:

                def absolute_path(val):
                    return os.path.abspath(os.path.expanduser(str(val)))

                set_val_from_environ("INIT_FILE", state.opt, ty=absolute_path)
            else:
                os.environ['WORLD_SIZE'] = str(state.world_size)
                set_val_from_environ("MASTER_ADDR", state.opt)
                set_val_from_environ("MASTER_PORT", state.opt, int)

            state.set_output_flag(state.world_rank == 0)
        else:
            state.world_rank = 0
            state.set_output_flag(not dummy)

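        # Filesystem and logging setup is skipped for dummy states.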
        if not dummy:
            utils.mkdir(save_dir)

            # First thing: set logging config:
            if not state.opt.no_log:
                log_filename = 'output'
                if state.distributed:
                    log_filename += '_rank{:02}'.format(state.world_rank)
                log_filename += '.log'
                state.opt.log_file = os.path.join(save_dir, log_filename)
            else:
                state.opt.log_file = None

            state.opt.log_level = state.opt.log_level.upper()

            if state.distributed:
                logging_prefix = 'rank {:02d} / {:02d} - '.format(
                    state.world_rank, state.world_size)
            else:
                logging_prefix = ''
            utils.logging.configure(state.opt.log_file,
                                    getattr(logging, state.opt.log_level),
                                    prefix=logging_prefix)

            logging.info("=" * 40 + " " + state.opt.start_time + " " +
                         "=" * 40)
            logging.info('Base directory is {}'.format(base_dir))

            if state.phase == 'test' and not os.path.isdir(base_dir):
                logging.warning("Base directory doesn't exist")

        _, state.opt.dataset_root, state.opt.nc, state.opt.input_size, state.opt.num_classes, \
            state.opt.dataset_normalization, state.opt.dataset_labels = datasets.get_info(state)
        if not state.opt.num_distill_classes:
            state.opt.num_distill_classes = state.opt.num_classes
        if not state.opt.init_labels:
            state.opt.init_labels = list(range(state.opt.num_distill_classes))

        # Write yaml
        yaml_str = yaml.dump(state.merge(public_only=True),
                             default_flow_style=False,
                             indent=4)
        logging.info("Options:\n\t" + yaml_str.replace("\n", "\n\t"))

        if state.get_output_flag():
            yaml_name = os.path.join(save_dir, 'opt.yaml')
            if os.path.isfile(yaml_name):
                old_opt_dir = os.path.join(save_dir, 'old_opts')
                utils.mkdir(old_opt_dir)
                with open(yaml_name, 'r') as f:
                    # register a catch-all constructor so unknown custom YAML tags are ignored
                    yaml.add_multi_constructor(
                        '', lambda loader, suffix, node: None)
                    old_yaml = yaml.load(f)  # this is a dict
                old_yaml_time = old_yaml.get('start_time', 'unknown_time')
                for c in ':-':
                    old_yaml_time = old_yaml_time.replace(c, '_')
                old_yaml_time = old_yaml_time.replace(' ', '__')
                old_opt_new_name = os.path.join(
                    old_opt_dir, 'opt_{}.yaml'.format(old_yaml_time))
                try:
                    os.rename(yaml_name, old_opt_new_name)
                    logging.warning('{} already exists, moved to {}'.format(
                        yaml_name, old_opt_new_name))
                except FileNotFoundError:
                    logging.warning(
                        ('{} already exists, tried to move to {}, but failed, '
                         'possibly because another process has already done it'
                         ).format(yaml_name, old_opt_new_name))

            with open(yaml_name, 'w') as f:
                f.write(yaml_str)

        # FROM HERE, we have saved options into yaml,
        #            can start assigning objects to opt, and
        #            modify the values for process-specific things
        def assert_divided_by_world_size(key, strict=True):
            val = getattr(state, key)
            if strict:
                assert val % state.world_size == 0, \
                    "expected {}={} to be divisible by the world size={}".format(key, val, state.world_size)
                val = val // state.world_size
            else:
                val = math.ceil(val / state.world_size)
            setattr(state, 'local_{}'.format(key), val)

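        # Split the per-run network counts evenly across the participating ranks.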
        assert_divided_by_world_size('n_nets')

        if state.mode != 'train':
            assert_divided_by_world_size('test_n_nets')
            assert_divided_by_world_size('sample_n_nets')

        if state.device_id < 0:
            state.opt.device = torch.device("cpu")
        else:
            torch.cuda.set_device(state.device_id)
            state.opt.device = torch.device("cuda:{}".format(state.device_id))

        if not dummy:
            if state.device.type == 'cuda' and torch.backends.cudnn.enabled:
                torch.backends.cudnn.benchmark = True

            seed = state.base_seed
            if state.distributed:
                seed += state.world_rank
                logging.info(
                    "In distributed mode, use arg.seed + rank as seed: {}".
                    format(seed))
            state.opt.seed = seed

            # torch.manual_seed would seed ALL GPUs, so seed the CPU generator and
            # only the current GPU explicitly, then numpy's and Python's RNGs.
            torch.random.default_generator.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        if not dummy and state.distributed:
            logging.info('Initializing distributed process group...')

            if state.distributed_file_init:
                dist.init_process_group("NCCL",
                                        init_method="file://{}".format(
                                            state.distributed_init_file),
                                        rank=state.world_rank,
                                        world_size=state.world_size)
            else:
                dist.init_process_group("NCCL", init_method="env://")

            utils.distributed.barrier()
            logging.info('done!')

            # Check command args consistency across ranks
            # Use a raw parsed dict because we assigned a bunch of things already
            # so this doesn't include things like seed (which can be rank-specific),
            # but includes base_seed.
            opt_dict = vars(self.parser.parse_args())
            opt_dict.pop('device_id')  # don't compare this
            opt_bytes = yaml.dump(opt_dict, encoding='utf-8')
            bytes_storage = torch.ByteStorage.from_buffer(opt_bytes)
            opt_tensor = torch.tensor(
                (), dtype=torch.uint8).set_(bytes_storage).to(state.opt.device)
            for other, ts in enumerate(
                    utils.distributed.all_gather_coalesced([opt_tensor])):
                other_t = ts[0]
                if not torch.equal(other_t, opt_tensor):
                    other_str = bytearray(
                        other_t.cpu().storage().tolist()).decode(
                            encoding="utf-8")
                    this_str = opt_bytes.decode(encoding="utf-8")
                    raise ValueError("Rank {} opt is different from rank {}:\n"
                                     .format(state.world_rank, other) +
                                     utils.diff_str(this_str, other_str))

        # in case of downloading, to avoid race, let rank 0 download.
        if state.world_rank == 0:
            train_dataset = datasets.get_dataset(state, 'train')
            test_dataset = datasets.get_dataset(state, 'test')

        if not dummy and state.distributed:
            utils.distributed.barrier()

        if state.world_rank != 0:
            train_dataset = datasets.get_dataset(state, 'train')
            test_dataset = datasets.get_dataset(state, 'test')

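        # Build train/test loaders: data.Iterator for text datasets, standard DataLoaders otherwise.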
        if state.opt.textdata:
            state.opt.train_loader = data.Iterator(
                train_dataset,
                batch_size=state.batch_size,
                device=state.device,
                repeat=False,
                sort_key=lambda x: len(x.train_dataset),
                shuffle=True)
            state.opt.test_loader = data.Iterator(
                test_dataset,
                batch_size=state.test_batch_size,
                device=state.device,
                repeat=False,
                sort_key=lambda x: len(x.test_dataset),
                shuffle=True)
        else:
            state.opt.train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=state.batch_size,
                num_workers=state.num_workers,
                pin_memory=True,
                shuffle=True)

            state.opt.test_loader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=state.test_batch_size,
                num_workers=state.num_workers,
                pin_memory=True,
                shuffle=True)

        if not dummy:
            logging.info('train dataset size:\t{}'.format(len(train_dataset)))
            logging.info('test dataset size: \t{}'.format(len(test_dataset)))
            logging.info('datasets built!')

            state.vis_queue = utils.multiprocessing.FixSizeProcessQueue(2)
        return state
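The option-consistency check near the end of set_state serializes the parsed options to YAML bytes, packs them into a uint8 tensor, and all-gathers them so every rank can verify it was launched with the same arguments. A standalone sketch of the same idea using plain torch.distributed primitives (the helper name check_options and the length-padding scheme are illustrative, not the repo's utils.distributed API):

import yaml
import torch
import torch.distributed as dist

def check_options(opt_dict, device):
    # Serialize this rank's options and pack them into a uint8 tensor.
    payload = yaml.dump(opt_dict, encoding='utf-8')  # bytes
    local = torch.tensor(list(payload), dtype=torch.uint8, device=device)

    # all_gather needs identically shaped tensors, so exchange lengths first
    # and pad every payload to the maximum length.
    length = torch.tensor([local.numel()], device=device)
    lengths = [torch.zeros_like(length) for _ in range(dist.get_world_size())]
    dist.all_gather(lengths, length)
    max_len = int(max(l.item() for l in lengths))

    padded = torch.zeros(max_len, dtype=torch.uint8, device=device)
    padded[:local.numel()] = local
    gathered = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, padded)

    # Compare every rank's options against our own and fail loudly on mismatch.
    for rank, (t, n) in enumerate(zip(gathered, lengths)):
        other = bytes(t[:int(n.item())].cpu().tolist()).decode('utf-8')
        if other != payload.decode('utf-8'):
            raise ValueError('rank {} was launched with different options'.format(rank))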
Example #3
def train(M, src=None, trg=None, saver=None, model_name=None):
    """Main training function

    Creates log file, manages datasets, trains model

    M          - (TensorDict) the model
    src        - (obj) source domain. Contains train/test Data obj
    trg        - (obj) target domain. Contains train/test Data obj
    saver      - (Saver) saves models during training
    model_name - (str) name of the model being run, with relevant params info
    """
    # Training settings
    bs = 64
    iterep = 1000
    n_epoch = 80
    epoch = 0
    feed_dict = {}

    # Create a log directory and FileWriter
    log_dir = os.path.join(args.logdir, model_name)
    delete_existing(log_dir)
    train_writer = tf.summary.FileWriter(log_dir)

    # Create a save directory
    if saver:
        model_dir = os.path.join('checkpoints', model_name)
        delete_existing(model_dir)
        os.makedirs(model_dir)

    if src: get_info('Source mnist', src)
    if trg: get_info('Target svhn', trg)
    print("Batch size:", bs)
    print("Iterep:", iterep)
    print("Total iterations:", n_epoch * iterep)
    print("Log directory:", log_dir)

    for i in range(n_epoch * iterep):
        # Run main optimizer
        update_dict(M, feed_dict, src, trg, bs)
        summary, _ = M.sess.run(M.ops_main, feed_dict)
        train_writer.add_summary(summary, i + 1)
        train_writer.flush()

        end_epoch, epoch = tb.utils.progbar(i,
                                            iterep,
                                            message='{}/{}'.format(epoch, i),
                                            display=args.run >= 999)

        # Log end-of-epoch values
        if end_epoch:
            print_list = M.sess.run(M.ops_print, feed_dict)

            if src:
                save_acc(M,
                         'fn_ema_acc',
                         'test/src_test_ema_1k',
                         src.test,
                         train_writer,
                         i + 1,
                         print_list,
                         full=False)

            if trg:
                save_acc(M, 'fn_ema_acc', 'test/trg_test_ema', trg.test,
                         train_writer, i + 1, print_list)
                save_acc(M,
                         'fn_ema_acc',
                         'test/trg_train_ema_1k',
                         trg.train,
                         train_writer,
                         i + 1,
                         print_list,
                         full=False)

            print_list += ['epoch', epoch]
            print(print_list)

        if saver and (i + 1) % 20000 == 0:
            save_model(saver, M, model_dir, i + 1)

    # Saving final model
    if saver:
        save_model(saver, M, model_dir, i + 1)
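delete_existing is referenced above but not shown. A plausible minimal version, assuming it just clears any previous run's directory so the FileWriter and checkpoints start from a clean slate (hypothetical, not the original helper):

import os
import shutil

def delete_existing(path):
    # Remove the directory tree if it already exists; a missing path is fine.
    if os.path.exists(path):
        shutil.rmtree(path)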