Code Example #1
def learn_graph(args):

    elap = time.time()

    # No need to log detailed computation stats
    common.debugger = utils.FakeLogger()

    common.ensure_object_targets(True)

    set_seed(args['seed'])
    task = common.create_env(args['house'], task_name=args['task_name'], false_rate=args['false_rate'],
                             success_measure=args['success_measure'],
                             depth_input=args['depth_input'],
                             target_mask_input=args['target_mask_input'],
                             segment_input=args['segmentation_input'],
                             cacheAllTarget=True,
                             render_device=args['render_gpu'],
                             use_discrete_action=True,
                             include_object_target=True,
                             include_outdoor_target=True,
                             discrete_angle=True)

    # create motion
    __graph_warmstart = args['warmstart']
    args['warmstart'] = args['motion_warmstart']
    motion = create_motion(args, task)

    # create graph
    args['warmstart'] = __graph_warmstart
    graph = GraphPlanner(motion)

    # logger
    logger = utils.MyLogger(args['save_dir'], True)

    logger.print("> Training Mode = {}".format(args['training_mode']))
    logger.print("> Graph Eps = {}".format(args['graph_eps']))
    logger.print("> N_Trials = {}".format(args['n_trials']))
    logger.print("> Max Exploration Steps = {}".format(args['max_exp_steps']))

    # Graph Building
    logger.print('Start Graph Building ...')

    if args['warmstart'] is not None:
        filename = args['warmstart']
        logger.print(' >>> Loading Pre-Trained Graph from {}'.format(filename))
        with open(filename, 'rb') as file:
            g_params = pickle.load(file)
        graph.set_parameters(g_params)

    train_mode = args['training_mode']
    if train_mode in ['mle', 'joint']:
        graph.learn(n_trial=args['n_trials'], max_allowed_steps=args['max_exp_steps'], eps=args['graph_eps'], logger=logger)

    if train_mode in ['evolution', 'joint']:
        graph.evolve()   # TODO: not implemented yet

    logger.print('######## Final Stats ###########')
    graph._show_prior_room(logger=logger)
    graph._show_prior_object(logger=logger)
    return graph
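
A minimal driver for learn_graph() is sketched below for orientation. Every key is one the function body above actually reads, but the values (and the whole dict) are illustrative assumptions, not settings shipped with the repository.

args = {
    'seed': 0,
    'house': 0,
    'task_name': 'roomnav',
    'false_rate': 0.0,
    'success_measure': 'center',
    'depth_input': False,
    'target_mask_input': False,
    'segmentation_input': 'none',
    'render_gpu': 0,
    'warmstart': None,            # optional path to a pickled pre-trained graph
    'motion_warmstart': None,
    'save_dir': './log/graph',
    'training_mode': 'mle',       # 'mle', 'evolution', or 'joint'
    'graph_eps': 0.0,
    'n_trials': 100,
    'max_exp_steps': 300,
}
graph = learn_graph(args)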
Code Example #2
File: zmq_train.py  Project: qiming-zou/HouseNavAgent
def train(args=None, warmstart=None):

    # Process Observation Shape
    common.process_observation_shape(model='rnn',
                                     resolution_level=args['resolution_level'],
                                     segmentation_input=args['segment_input'],
                                     depth_input=args['depth_input'],
                                     target_mask_input=args['target_mask_input'],
                                     history_frame_len=1)

    args['logger'] = utils.MyLogger(args['log_dir'], True, keep_file_handler=not args['append_file'])

    name = 'ipc://@whatever' + args['job_name']
    name2 = 'ipc://@whatever' + args['job_name'] + '2'
    n_proc = args['n_proc']
    config = create_zmq_config(args)
    procs = [ZMQSimulator(k, name, name2, config) for k in range(n_proc)]
    for proc in procs:
        proc.start()
    ensure_proc_terminate(procs)

    trainer = create_zmq_trainer(args['algo'], model='rnn', args=args)
    if warmstart is not None:
        if os.path.exists(warmstart):
            print('Warmstarting from <{}> ...'.format(warmstart))
            trainer.load(warmstart)
        else:
            save_dir = args['save_dir']
            print('Warmstarting from save_dir <{}> with version <{}> ...'.format(save_dir, warmstart))
            trainer.load(save_dir, warmstart)


    master = ZMQMaster(name, name2, trainer=trainer, config=args)

    try:
        # both loops must be running
        print('Start Iterations ....')
        send_thread = threading.Thread(target=master.send_loop, daemon=True)
        send_thread.start()
        master.recv_loop()
        print('Done!')
        trainer.save(args['save_dir'], version='final')
    except KeyboardInterrupt:
        trainer.save_all(args['save_dir'], version='interrupt')
        raise
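
The master needs both loops running at once: the send loop is parked on a daemon thread while recv_loop() blocks the main thread. Below is a self-contained sketch of the same pattern using a queue as a stand-in for the ZMQ sockets; send_loop/recv_loop here are hypothetical toy functions, not the ZMQMaster API.

import queue
import threading
import time

def send_loop(q):
    # Producer side; a daemon thread dies with the main thread, mirroring
    # threading.Thread(target=master.send_loop, daemon=True) above.
    for i in range(5):
        q.put(i)
        time.sleep(0.05)

def recv_loop(q):
    # Consumer side; blocks on the main thread until work arrives.
    for _ in range(5):
        print('received', q.get())

q = queue.Queue()
threading.Thread(target=send_loop, args=(q,), daemon=True).start()
recv_loop(q)  # returns only after all items are consumed, like master.recv_loop()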
Code Example #3
def run(config):
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    if config['resume']:
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config))
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    G3 = model.Generator(**config).to(device)
    D3 = model.Discriminator(**config).to(device)
    if config['ema']:
        G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    if config['G_fp16']:
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        D = D.half()
    GD = model.G_D(G, D, config['conditional'])
    GD3 = model.G_D(G3, D3, config['conditional'])
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 'best_IS': 0, 'best_FID': 999999, 'config': config}
    if config['resume']:
        utils.load_weights(G, D, state_dict, config['weights_root'], experiment_name, config['load_weights'] if config['load_weights'] else None, G_ema if config['ema'] else None)
    #utils.load_weights(G, D, state_dict, '../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    #utils.load_weights(G, D, state_dict, '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    #utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'last0', G_ema if config['ema'] else None)
    #utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'last0', G_ema if config['ema'] else None)
    utils.load_weights(G, D, state_dict, '../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    test_log = utils.MetricsLogger(test_metrics_fname, reinitialize=(not config['resume']))
    train_log = utils.MyLogger(train_metrics_fname, reinitialize=(not config['resume']), logstyle=config['logstyle'])
    utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] * config['num_D_accumulations'])
    # Use: config['abnormal_class']
    #print(config['abnormal_class'])
    abnormal_class = config['abnormal_class']
    select_dataset = config['select_dataset']
    #print(config['select_dataset'])
    #print(select_dataset)
    loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size, 'start_itr': state_dict['itr'], 'abnormal_class': abnormal_class, 'select_dataset': select_dataset})
    # Usage: --select_dataset cifar10 --abnormal_class 0 --shuffle --batch_size 64 --parallel --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 --dataset C10 --data_root ../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/data/ --G_ortho 0.0 --G_attn 0 --D_attn 0 --G_init N02 --D_init N02 --ema --use_ema --ema_start 1000 --start_eval 50 --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --loss_type kl_5 --seed 2 --which_best FID --model BigGAN --experiment_name C10Ukl5
    # Use: --select_dataset mnist --abnormal_class 1 --shuffle --batch_size 64 --parallel --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 --dataset C10 --data_root ../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/data/ --G_ortho 0.0 --G_attn 0 --D_attn 0 --G_init N02 --D_init N02 --ema --use_ema --ema_start 1000 --start_eval 50 --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --loss_type kl_5 --seed 2 --which_best FID --model BigGAN --experiment_name C10Ukl5
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'])
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G3, D3, GD3, G3, D3, GD3, G, D, GD, z_, y_, ema, state_dict, config)
    else:
        train = train_fns.dummy_training_function()
    sample = functools.partial(utils.sample, G=(G_ema if config['ema'] and config['use_ema'] else G), z_=z_, y_=y_, config=config)
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
        #'../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz'
        #data_moments = '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz'
    else:
        print("Cannot find the data set.")
        sys.exit()
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0], displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            state_dict['itr'] += 1
            G.eval()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            print('')
            # Random seed
            #print(config['seed'])
            if epoch == 0 and i == 0:
                print(config['seed'])
            metrics = train(x, y)
            # We double the learning rate if we double the batch size.
            train_log.log(itr=int(state_dict['itr']), **metrics)
            if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']), **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']] + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]), end=' ')
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config, experiment_name)
            experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config))
            if (not (state_dict['itr'] % config['test_every'])) and (epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                utils.sample_inception(
                    G_ema if config['ema'] and config['use_ema'] else G, config, str(epoch))
                folder_number = str(epoch)
                sample_moments = '%s/%s/%s/samples.npz' % (config['samples_root'], experiment_name, folder_number)
                FID = fid_score.calculate_fid_given_paths([data_moments, sample_moments], batch_size=50, cuda=True, dims=2048)
                train_fns.update_FID(G, D, G_ema, state_dict, config, FID, experiment_name, test_log)
        state_dict['epoch'] += 1
    #utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'be01Bes01Best%d' % state_dict['save_best_num'], G_ema if config['ema'] else None)
    utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'last%d' % 0, G_ema if config['ema'] else None)
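
utils.ema(G, G_ema, config['ema_decay'], config['ema_start']) keeps G_ema as an exponential moving average of G's weights. Below is a framework-free sketch of that update rule, assuming the standard EMA formula; SimpleEMA is a hypothetical stand-in, not the repository's class.

class SimpleEMA:
    """After `start` iterations: theta_ema = decay * theta_ema + (1 - decay) * theta."""

    def __init__(self, source, target, decay=0.9999, start=1000):
        self.source, self.target = source, target  # dicts mapping name -> parameter
        self.decay, self.start = decay, start

    def update(self, itr):
        # Before `start`, copy weights verbatim so the average begins warm.
        decay = self.decay if itr >= self.start else 0.0
        for k in self.source:
            self.target[k] = decay * self.target[k] + (1.0 - decay) * self.source[k]

source = {'w': 1.0}
target = dict(source)
ema = SimpleEMA(source, target, decay=0.9, start=0)
source['w'] = 0.0
ema.update(itr=1)
print(target['w'])  # 0.9 = 0.9 * 1.0 + 0.1 * 0.0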
Code Example #4
File: eval.py  Project: qiming-zou/HouseNavAgent
def evaluate_aux_pred(house,
                      seed=0,
                      iters=1000,
                      max_episode_len=10,
                      algo='a3c',
                      model_name='rnn',
                      model_file=None,
                      log_dir='./log/eval',
                      store_history=False,
                      use_batch_norm=True,
                      rnn_units=None,
                      rnn_layers=None,
                      rnn_cell=None,
                      multi_target=True,
                      use_target_gating=False,
                      segmentation_input='none',
                      depth_input=False,
                      resolution='normal'):

    # TODO: currently not supported
    assert False, 'Aux Prediction Not Supported!'

    # No need to log detailed computation stats
    assert algo in ['a3c', 'nop']
    flag_run_random_policy = (algo == 'nop')
    common.debugger = utils.FakeLogger()
    args = common.create_default_args(algo,
                                      model=model_name,
                                      use_batch_norm=use_batch_norm,
                                      replay_buffer_size=50,
                                      episode_len=max_episode_len,
                                      rnn_units=rnn_units,
                                      rnn_layers=rnn_layers,
                                      rnn_cell=rnn_cell,
                                      segmentation_input=segmentation_input,
                                      resolution_level=resolution,
                                      depth_input=depth_input,
                                      history_frame_len=1)
    # TODO: add code for evaluating the aux-task (concept learning)
    args['multi_target'] = multi_target
    args['target_gating'] = use_target_gating
    args['aux_task'] = True
    import zmq_train
    set_seed(seed)
    env = common.create_env(house,
                            hardness=1e-8,
                            success_measure='stay',
                            depth_input=depth_input,
                            segment_input=args['segment_input'],
                            genRoomTypeMap=True,
                            cacheAllTarget=True,
                            use_discrete_action=True)
    trainer = zmq_train.create_zmq_trainer(algo, model_name, args)
    if model_file is not None:
        trainer.load(model_file)
    trainer.eval()  # evaluation mode

    logger = utils.MyLogger(log_dir, True)
    logger.print('Start Evaluating Auxiliary Task ...')
    logger.print(
        '  --> Episode (Left) Turning Steps = {}'.format(max_episode_len))
    episode_err = []
    episode_succ = []
    episode_good = []
    episode_rews = []
    episode_stats = []
    elap = time.time()

    for it in range(iters):
        trainer.reset_agent()
        set_seed(seed + it + 1)  # reset seed
        obs = env.reset() if multi_target else env.reset(
            target=env.get_current_target())
        target_id = common.target_instruction_dict[env.get_current_target()]
        if multi_target and hasattr(trainer, 'set_target'):
            trainer.set_target(env.get_current_target())
        cur_infos = []
        if store_history:
            cur_infos.append(proc_info(env.info))
            # cur_images.append(env.render(renderMapLoc=env.cam_info['loc'], display=False))
        if model_name != 'rnn': obs = obs.transpose([1, 0, 2])
        episode_succ.append(0)
        episode_err.append(0)
        episode_good.append(0)
        cur_rew = []
        cur_pred = []
        if flag_run_random_policy:
            predefined_aux_pred = common.all_aux_predictions[random.choice(
                common.all_target_instructions)]
        for _st in range(max_episode_len):
            # get action
            if flag_run_random_policy:
                aux_pred = predefined_aux_pred
            else:
                if multi_target:
                    _, _, aux_prob = trainer.action(obs,
                                                    return_numpy=True,
                                                    target=[[target_id]],
                                                    return_aux_pred=True,
                                                    return_aux_logprob=False)
                else:
                    _, _, aux_prob = trainer.action(obs,
                                                    return_numpy=True,
                                                    return_aux_pred=True,
                                                    return_aux_logprob=False)
                aux_prob = aux_prob.squeeze()  # [n_pred]
                aux_pred = int(
                    np.argmax(aux_prob)
                )  # greedy action, takes the output with the maximum confidence
            aux_rew = trainer.get_aux_task_reward(
                aux_pred, env.get_current_room_pred_mask())
            cur_rew.append(aux_rew)
            cur_pred.append(common.all_aux_prediction_list[aux_pred])
            if aux_rew < 0:
                episode_err[-1] += 1
            if aux_rew >= 0.9:  # currently a hack
                episode_succ[-1] += 1
            if aux_rew > 0:
                episode_good[-1] += 1
            action = 5  # Left Rotation
            # environment step
            obs, rew, done, info = env.step(action)
            if store_history:
                cur_infos.append(proc_info(info))
                cur_infos[-1]['aux_pred'] = cur_pred
                #cur_images.append(env.render(renderMapLoc=env.cam_info['loc'], display=False))
            if model_name != 'rnn': obs = obs.transpose([1, 0, 2])
        if episode_err[-1] > 0:
            episode_succ[-1] = 0
        room_mask = env.get_current_room_pred_mask()
        cur_room_types = []
        for i in range(common.n_aux_predictions):
            if (room_mask & (1 << i)) > 0:
                cur_room_types.append(common.all_aux_prediction_list[i])

        cur_stats = dict(err=episode_err[-1],
                         good=episode_good[-1],
                         succ=episode_succ[-1],
                         rew=cur_rew,
                         err_rate=episode_err[-1] / max_episode_len,
                         good_rate=episode_good[-1] / max_episode_len,
                         succ_rate=episode_succ[-1] / max_episode_len,
                         target=env.get_current_target(),
                         mask=room_mask,
                         room_types=cur_room_types,
                         length=max_episode_len)
        if store_history:
            cur_stats['infos'] = cur_infos
        episode_stats.append(cur_stats)

        dur = time.time() - elap
        logger.print('Episode#%d, Elapsed = %.3f min' % (it + 1, dur / 60))
        logger.print('  ---> Target Room = {}'.format(cur_stats['target']))
        logger.print('  ---> Aux Rew = {}'.format(cur_rew))
        if (episode_succ[-1] > 0) and (episode_err[-1] == 0):
            logger.print('  >>>> Success!')
        elif episode_err[-1] == 0:
            logger.print('  >>>> Good!')
        else:
            logger.print('  >>>> Failed!')
        logger.print(
            "  ---> Indep. Prediction: Succ Rate = %.3f, Good Rate = %.3f, Err Rate = %.3f"
            % (episode_succ[-1] * 100.0 / max_episode_len,
               episode_good[-1] * 100.0 / max_episode_len,
               episode_err[-1] * 100.0 / max_episode_len))
        logger.print(
            "  > Accu. Succ = %.3f, Good = %.3f, Fail = %.3f" %
            (float(np.mean([float(s == max_episode_len)
                            for s in episode_succ])) * 100.0,
             float(np.mean([float(e == 0) for e in episode_err])) * 100,
             float(np.mean([float(e > 0) for e in episode_err])) * 100))
        logger.print(
            "  > Accu. Rate: Succ Rate = %.3f, Good Rate = %.3f, Fail Rate = %.3f"
            % (float(np.mean([s / max_episode_len
                              for s in episode_succ])) * 100.0,
               float(np.mean([g / max_episode_len
                              for g in episode_good])) * 100,
               float(np.mean([e / max_episode_len
                              for e in episode_err])) * 100))
    return episode_stats
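
The room-prediction mask used above is a bit field: bit i is set when room type i applies to the agent's current location, hence the room_mask & (1 << i) test. A standalone illustration with made-up labels (the three room names are illustrative, not the real common.all_aux_prediction_list):

all_aux_prediction_list = ['kitchen', 'bedroom', 'bathroom']
room_mask = 0b101  # bits 0 and 2 set
cur_room_types = [name for i, name in enumerate(all_aux_prediction_list)
                  if room_mask & (1 << i)]
print(cur_room_types)  # ['kitchen', 'bathroom']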
Code Example #5
File: eval.py  Project: qiming-zou/HouseNavAgent
def evaluate(house,
             seed=0,
             render_device=None,
             iters=1000,
             max_episode_len=1000,
             task_name='roomnav',
             false_rate=0.0,
             hardness=None,
             max_birthplace_steps=None,
             success_measure='center',
             multi_target=False,
             fixed_target=None,
             algo='nop',
             model_name='cnn',
             model_file=None,
             log_dir='./log/eval',
             store_history=False,
             use_batch_norm=True,
             rnn_units=None,
             rnn_layers=None,
             rnn_cell=None,
             use_action_gating=False,
             use_residual_critic=False,
             use_target_gating=False,
             segmentation_input='none',
             depth_input=False,
             target_mask_input=False,
             resolution='normal',
             history_len=4,
             include_object_target=False,
             include_outdoor_target=True,
             aux_task=False,
             no_skip_connect=False,
             feed_forward=False,
             greedy_execution=False,
             greedy_aux_pred=False):

    assert not aux_task, 'Do not support Aux-Task now!'

    elap = time.time()

    # No need to log detailed computation stats
    common.debugger = utils.FakeLogger()

    args = common.create_default_args(algo,
                                      model=model_name,
                                      use_batch_norm=use_batch_norm,
                                      replay_buffer_size=50,
                                      episode_len=max_episode_len,
                                      rnn_units=rnn_units,
                                      rnn_layers=rnn_layers,
                                      rnn_cell=rnn_cell,
                                      segmentation_input=segmentation_input,
                                      resolution_level=resolution,
                                      depth_input=depth_input,
                                      target_mask_input=target_mask_input,
                                      history_frame_len=history_len)
    args['action_gating'] = use_action_gating
    args['residual_critic'] = use_residual_critic
    args['multi_target'] = multi_target
    args['object_target'] = include_object_target
    args['target_gating'] = use_target_gating
    args['aux_task'] = aux_task
    args['no_skip_connect'] = no_skip_connect
    args['feed_forward'] = feed_forward
    if (fixed_target is not None) and (fixed_target
                                       not in ['any-room', 'any-object']):
        assert fixed_target in common.all_target_instructions, 'invalid fixed target <{}>'.format(
            fixed_target)

    __backup_CFG = common.CFG.copy()
    if fixed_target == 'any-room':
        common.ensure_object_targets(False)

    if hardness is not None:
        print('>>>> Hardness = {}'.format(hardness))
    if max_birthplace_steps is not None:
        print('>>>> Max BirthPlace Steps = {}'.format(max_birthplace_steps))
    set_seed(seed)
    env = common.create_env(house,
                            task_name=task_name,
                            false_rate=false_rate,
                            hardness=hardness,
                            max_birthplace_steps=max_birthplace_steps,
                            success_measure=success_measure,
                            depth_input=depth_input,
                            target_mask_input=target_mask_input,
                            segment_input=args['segment_input'],
                            genRoomTypeMap=aux_task,
                            cacheAllTarget=multi_target,
                            render_device=render_device,
                            use_discrete_action=('dpg' not in algo),
                            include_object_target=include_object_target
                            and (fixed_target != 'any-room'),
                            include_outdoor_target=include_outdoor_target,
                            discrete_angle=True)

    if (fixed_target is not None) and (fixed_target != 'any-room') and (
            fixed_target != 'any-object'):
        env.reset_target(fixed_target)

    if fixed_target == 'any-room':
        common.CFG = __backup_CFG
        common.ensure_object_targets(True)

    # create model
    if model_name == 'rnn':
        import zmq_train
        trainer = zmq_train.create_zmq_trainer(algo, model_name, args)
    else:
        trainer = common.create_trainer(algo, model_name, args)
    if model_file is not None:
        trainer.load(model_file)
    trainer.eval()  # evaluation mode
    if greedy_execution and hasattr(trainer, 'set_greedy_execution'):
        trainer.set_greedy_execution()
    else:
        print('[Eval] WARNING!!! Greedy Policy Execution NOT Available!!!')
        greedy_execution = False
    if greedy_aux_pred and hasattr(trainer, 'set_greedy_aux_prediction'):
        trainer.set_greedy_aux_prediction()
    else:
        print(
            '[Eval] WARNING!!! Greedy Execution of Auxiliary Task NOT Available!!!'
        )
        greedy_aux_pred = False

    if aux_task: assert trainer.is_rnn()  # only rnn supports aux_task

    #flag_random_reset_target = multi_target and (fixed_target is None)

    logger = utils.MyLogger(log_dir, True)
    logger.print('Start Evaluating ...')

    episode_success = []
    episode_good = []
    episode_stats = []
    t = 0
    for it in range(iters):
        cur_infos = []
        trainer.reset_agent()
        set_seed(seed + it + 1)  # reset seed
        obs = env.reset(target=fixed_target)
        #if multi_target and (fixed_target is not None) and (fixed_target != 'kitchen'):
        #    # TODO: Currently a hacky solution
        #    env.reset(target=fixed_target)
        #    if house < 0:  # multi-house env
        #        obs = env.reset(reset_target=False, keep_world=True)
        #    else:
        #        obs = env.reset(reset_target=False)
        #else:
        #    # TODO: Only support multi-target + fixed kitchen; or fixed-target (kitchen)
        #    obs = env.reset(reset_target=flag_random_reset_target)
        target_id = common.target_instruction_dict[env.get_current_target()]
        if multi_target and hasattr(trainer, 'set_target'):
            trainer.set_target(env.get_current_target())
        if store_history:
            cur_infos.append(proc_info(env.info))
            #cur_images.append(env.render(renderMapLoc=env.cam_info['loc'], display=False))
        if model_name != 'rnn': obs = obs.transpose([1, 0, 2])
        episode_success.append(0)
        episode_good.append(0)
        cur_stats = dict(best_dist=1e50,
                         success=0,
                         good=0,
                         reward=0,
                         target=env.get_current_target(),
                         meters=env.info['meters'],
                         optstep=env.info['optsteps'],
                         length=max_episode_len,
                         images=None)
        if aux_task:
            cur_stats['aux_pred_rew'] = 0
            cur_stats['aux_pred_err'] = 0
        if hasattr(env.house, "_id"):
            cur_stats['world_id'] = env.house._id
        episode_step = 0
        for _st in range(max_episode_len):
            # get action
            if trainer.is_rnn():
                idx = 0
                if multi_target:
                    if aux_task:
                        action, _, aux_pred = trainer.action(
                            obs,
                            return_numpy=True,
                            target=[[target_id]],
                            return_aux_pred=True)
                    else:
                        action, _ = trainer.action(obs,
                                                   return_numpy=True,
                                                   target=[[target_id]])
                else:
                    if aux_task:
                        action, _, aux_pred = trainer.action(
                            obs, return_numpy=True, return_aux_pred=True)
                    else:
                        action, _ = trainer.action(obs, return_numpy=True)
                action = action.squeeze()
                if greedy_execution:
                    action = int(np.argmax(action))
                else:
                    action = int(action)
                if aux_task:
                    aux_pred = aux_pred.squeeze()
                    if greedy_aux_pred:
                        aux_pred = int(np.argmax(aux_pred))
                    else:
                        aux_pred = int(aux_pred)
                    aux_rew = trainer.get_aux_task_reward(
                        aux_pred, env.get_current_room_pred_mask())
                    cur_stats['aux_pred_rew'] += aux_rew
                    if aux_rew < 0: cur_stats['aux_pred_err'] += 1
            else:
                idx = trainer.process_observation(obs)
                action = trainer.action(
                    None if greedy_execution else 1.0)  # use gumbel noise
            # environment step
            obs, rew, done, info = env.step(action)
            if store_history:
                cur_infos.append(proc_info(info))
                #cur_images.append(env.render(renderMapLoc=env.cam_info['loc'], display=False))
            if model_name != 'rnn': obs = obs.transpose([1, 0, 2])
            cur_dist = info['dist']
            if cur_dist == 0:
                cur_stats['good'] += 1
                episode_good[-1] = 1
            t += 1
            if cur_dist < cur_stats['best_dist']:
                cur_stats['best_dist'] = cur_dist
            episode_step += 1
            # collect experience
            trainer.process_experience(idx, action, rew, done,
                                       (_st + 1 >= max_episode_len), info)
            if done:
                if rew > 5:  # magic number
                    episode_success[-1] = 1
                    cur_stats['success'] = 1
                cur_stats['length'] = episode_step
                if aux_task:
                    cur_stats['aux_pred_err'] /= episode_step
                    cur_stats['aux_pred_rew'] /= episode_step
                break
        if store_history:
            cur_stats['infos'] = cur_infos
        episode_stats.append(cur_stats)

        dur = time.time() - elap
        logger.print('Episode#%d, Elapsed = %.3f min' % (it + 1, dur / 60))
        if multi_target:
            logger.print('  ---> Target Room = {}'.format(cur_stats['target']))
        logger.print('  ---> Total Samples = {}'.format(t))
        logger.print('  ---> Success = %d  (rate = %.3f)' %
                     (cur_stats['success'], np.mean(episode_success)))
        logger.print(
            '  ---> Times of Reaching Target Room = %d  (rate = %.3f)' %
            (cur_stats['good'], np.mean(episode_good)))
        logger.print('  ---> Best Distance = %d' % cur_stats['best_dist'])
        logger.print('  ---> Birth-place Distance = %d' % cur_stats['optstep'])
        if aux_task:
            logger.print(
                '    >>>>>> Aux-Task: Avg Rew = %.4f, Avg Err = %.4f' %
                (cur_stats['aux_pred_rew'], cur_stats['aux_pred_err']))

    logger.print('######## Final Stats ###########')
    logger.print('Success Rate = %.3f' % np.mean(episode_success))
    logger.print(
        '> Avg Ep-Length per Success = %.3f' %
        np.mean([s['length'] for s in episode_stats if s['success'] > 0]))
    logger.print(
        '> Avg Birth-Meters per Success = %.3f' %
        np.mean([s['meters'] for s in episode_stats if s['success'] > 0]))
    logger.print('Reaching Target Rate = %.3f' % np.mean(episode_good))
    logger.print('> Avg Ep-Length per Target Reach = %.3f' %
                 np.mean([s['length']
                          for s in episode_stats if s['good'] > 0]))
    logger.print('> Avg Birth-Meters per Target Reach = %.3f' %
                 np.mean([s['meters']
                          for s in episode_stats if s['good'] > 0]))
    if multi_target:
        all_targets = list(set([s['target'] for s in episode_stats]))
        for tar in all_targets:
            n = sum([1.0 for s in episode_stats if s['target'] == tar])
            succ = [
                float(s['success'] > 0) for s in episode_stats
                if s['target'] == tar
            ]
            good = [
                float(s['good'] > 0) for s in episode_stats
                if s['target'] == tar
            ]
            length = [s['length'] for s in episode_stats if s['target'] == tar]
            meters = [s['meters'] for s in episode_stats if s['target'] == tar]
            good_len = np.mean([l for l, g in zip(length, good) if g > 0.5])
            succ_len = np.mean([l for l, s in zip(length, succ) if s > 0.5])
            good_mts = np.mean([l for l, g in zip(meters, good) if g > 0.5])
            succ_mts = np.mean([l for l, s in zip(meters, succ) if s > 0.5])
            logger.print(
                '>>>>> Multi-Target <%s>: Rate = %.3f (n=%d), Good = %.3f (AvgLen=%.3f; Mts=%.3f), Succ = %.3f (AvgLen=%.3f; Mts=%.3f)'
                % (tar, n / len(episode_stats), n, np.mean(good), good_len,
                   good_mts, np.mean(succ), succ_len, succ_mts))

    if aux_task:
        logger.print(
            ' -->>> Auxiliary-Task: Mean Episode Avg Rew = %.6f, Mean Episode Avg Err = %.6f'
            % (np.mean([float(s['aux_pred_rew']) for s in episode_stats]),
               np.mean([float(s['aux_pred_err']) for s in episode_stats])))

    return episode_stats
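
The multi-target summary at the end groups episode_stats by target and reports per-target success and reach rates. Here is the same aggregation in miniature, over mock records shaped like the dicts the loop builds (the three entries are fabricated purely for illustration):

import numpy as np

episode_stats = [
    {'target': 'kitchen', 'success': 1, 'good': 1, 'length': 40, 'meters': 3.0},
    {'target': 'kitchen', 'success': 0, 'good': 1, 'length': 100, 'meters': 5.0},
    {'target': 'bedroom', 'success': 0, 'good': 0, 'length': 100, 'meters': 7.5},
]
for tar in sorted({s['target'] for s in episode_stats}):
    grp = [s for s in episode_stats if s['target'] == tar]
    succ = np.mean([float(s['success'] > 0) for s in grp])
    good = np.mean([float(s['good'] > 0) for s in grp])
    print('%s: n=%d, succ=%.3f, good=%.3f' % (tar, len(grp), succ, good))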
Code Example #6
def evaluate(args):

    elap = time.time()

    # No need to log detailed computation stats
    common.debugger = utils.FakeLogger()

    # ensure observation shape
    common.process_observation_shape(
        'rnn',
        args['resolution'],
        segmentation_input=args['segmentation_input'],
        depth_input=args['depth_input'],
        history_frame_len=1,
        target_mask_input=args['target_mask_input'])

    fixed_target = args['fixed_target']
    if (fixed_target is not None) and (fixed_target != 'any-room') and (
            fixed_target != 'any-object'):
        assert fixed_target in common.all_target_instructions, 'invalid fixed target <{}>'.format(
            fixed_target)

    __backup_CFG = common.CFG.copy()
    if fixed_target == 'any-room':
        common.ensure_object_targets(False)

    if args['hardness'] is not None:
        print('>>>> Hardness = {}'.format(args['hardness']))
    if args['max_birthplace_steps'] is not None:
        print('>>>> Max BirthPlace Steps = {}'.format(
            args['max_birthplace_steps']))
    set_seed(args['seed'])
    task = common.create_env(args['house'],
                             task_name=args['task_name'],
                             false_rate=args['false_rate'],
                             hardness=args['hardness'],
                             max_birthplace_steps=args['max_birthplace_steps'],
                             success_measure=args['success_measure'],
                             depth_input=args['depth_input'],
                             target_mask_input=args['target_mask_input'],
                             segment_input=args['segmentation_input'],
                             genRoomTypeMap=False,
                             cacheAllTarget=args['multi_target'],
                             render_device=args['render_gpu'],
                             use_discrete_action=True,
                             include_object_target=args['object_target']
                             and (fixed_target != 'any-room'),
                             include_outdoor_target=args['outdoor_target'],
                             discrete_angle=True,
                             min_birthplace_grids=args['min_birthplace_grids'])

    if (fixed_target is not None) and (fixed_target != 'any-room') and (
            fixed_target != 'any-object'):
        task.reset_target(fixed_target)

    if fixed_target == 'any-room':
        common.CFG = __backup_CFG
        common.ensure_object_targets(True)

    # create semantic classifier
    if args['semantic_dir'] is not None:
        assert os.path.exists(
            args['semantic_dir']
        ), '[Error] Semantic Dir <{}> does not exist!'.format(args['semantic_dir'])
        assert not args[
            'object_target'], '[ERROR] --object-target is currently not supported!'
        print('Loading Semantic Oracle from dir <{}>...'.format(
            args['semantic_dir']))
        if args['semantic_gpu'] is None:
            args['semantic_gpu'] = common.get_gpus_for_rendering()[0]
        oracle = SemanticOracle(model_dir=args['semantic_dir'],
                                model_device=args['semantic_gpu'],
                                include_object=args['object_target'])
        oracle_func = OracleFunction(
            oracle,
            threshold=args['semantic_threshold'],
            filter_steps=args['semantic_filter_steps'])
    else:
        oracle_func = None

    # create motion
    motion = create_motion(args, task, oracle_func)

    logger = utils.MyLogger(args['log_dir'], True)
    logger.print('Start Evaluating ...')

    episode_success = []
    episode_good = []
    episode_stats = []
    t = 0
    seed = args['seed']
    max_episode_len = args['max_episode_len']

    plan_req = args['plan_dist_iters'] if 'plan_dist_iters' in args else None

    for it in range(args['max_iters']):
        cur_infos = []
        motion.reset()
        set_seed(seed + it + 1)  # reset seed
        if plan_req is not None:
            while True:
                task.reset(target=fixed_target)
                m = len(task.get_optimal_plan())
                if (m in plan_req) and plan_req[m] > 0:
                    break
            plan_req[m] -= 1
        else:
            task.reset(target=fixed_target)
        info = task.info

        episode_success.append(0)
        episode_good.append(0)
        cur_stats = dict(best_dist=info['dist'],
                         success=0,
                         good=0,
                         reward=0,
                         target=task.get_current_target(),
                         meters=task.info['meters'],
                         optstep=task.info['optsteps'],
                         length=max_episode_len,
                         images=None)
        if hasattr(task.house, "_id"):
            cur_stats['world_id'] = task.house._id

        store_history = args['store_history']
        if store_history:
            cur_infos.append(proc_info(task.info))

        if args['temperature'] is not None:
            ep_data = motion.run(task.get_current_target(),
                                 max_episode_len,
                                 temperature=args['temperature'])
        else:
            ep_data = motion.run(task.get_current_target(), max_episode_len)

        for dat in ep_data:
            info = dat[4]
            if store_history:
                cur_infos.append(proc_info(info))
            cur_dist = info['dist']
            if cur_dist == 0:
                cur_stats['good'] += 1
                episode_good[-1] = 1
            if cur_dist < cur_stats['best_dist']:
                cur_stats['best_dist'] = cur_dist

        episode_step = len(ep_data)
        if ep_data[-1][3]:  # done
            if ep_data[-1][2] > 5:  # magic number
                episode_success[-1] = 1
                cur_stats['success'] = 1
        cur_stats['length'] = episode_step  # store length

        if store_history:
            cur_stats['infos'] = cur_infos
        episode_stats.append(cur_stats)

        dur = time.time() - elap
        logger.print('Episode#%d, Elapsed = %.3f min' % (it + 1, dur / 60))
        if args['multi_target']:
            logger.print('  ---> Target Room = {}'.format(cur_stats['target']))
        logger.print('  ---> Total Samples = {}'.format(t))
        logger.print('  ---> Success = %d  (rate = %.3f)' %
                     (cur_stats['success'], np.mean(episode_success)))
        logger.print(
            '  ---> Times of Reaching Target Room = %d  (rate = %.3f)' %
            (cur_stats['good'], np.mean(episode_good)))
        logger.print('  ---> Best Distance = %d' % cur_stats['best_dist'])
        logger.print('  ---> Birth-place Distance = %d' % cur_stats['optstep'])

    logger.print('######## Final Stats ###########')
    logger.print('Success Rate = %.3f' % np.mean(episode_success))
    logger.print(
        '> Avg Ep-Length per Success = %.3f' %
        np.mean([s['length'] for s in episode_stats if s['success'] > 0]))
    logger.print(
        '> Avg Birth-Meters per Success = %.3f' %
        np.mean([s['meters'] for s in episode_stats if s['success'] > 0]))
    logger.print('Reaching Target Rate = %.3f' % np.mean(episode_good))
    logger.print('> Avg Ep-Length per Target Reach = %.3f' %
                 np.mean([s['length']
                          for s in episode_stats if s['good'] > 0]))
    logger.print('> Avg Birth-Meters per Target Reach = %.3f' %
                 np.mean([s['meters']
                          for s in episode_stats if s['good'] > 0]))
    if args['multi_target']:
        all_targets = list(set([s['target'] for s in episode_stats]))
        for tar in all_targets:
            n = sum([1.0 for s in episode_stats if s['target'] == tar])
            succ = [
                float(s['success'] > 0) for s in episode_stats
                if s['target'] == tar
            ]
            good = [
                float(s['good'] > 0) for s in episode_stats
                if s['target'] == tar
            ]
            length = [s['length'] for s in episode_stats if s['target'] == tar]
            meters = [s['meters'] for s in episode_stats if s['target'] == tar]
            good_len = np.mean([l for l, g in zip(length, good) if g > 0.5])
            succ_len = np.mean([l for l, s in zip(length, succ) if s > 0.5])
            good_mts = np.mean([l for l, g in zip(meters, good) if g > 0.5])
            succ_mts = np.mean([l for l, s in zip(meters, succ) if s > 0.5])
            logger.print(
                '>>>>> Multi-Target <%s>: Rate = %.3f (n=%d), Good = %.3f (AvgLen=%.3f; Mts=%.3f), Succ = %.3f (AvgLen=%.3f; Mts=%.3f)'
                % (tar, n / len(episode_stats), n, np.mean(good), good_len,
                   good_mts, np.mean(succ), succ_len, succ_mts))

    return episode_stats
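
When 'plan_dist_iters' is supplied, episodes are rebalanced by optimal-plan length: task.reset() is retried until it draws a plan length whose quota is still open, and that quota is then decremented. A self-contained mock of the rejection loop, with a random integer standing in for len(task.get_optimal_plan()):

import random

plan_req = {1: 2, 2: 2, 3: 1}  # desired episode count per optimal-plan length

def sample_plan_length():
    return random.randint(1, 4)  # stand-in for len(task.get_optimal_plan())

episodes = []
while any(v > 0 for v in plan_req.values()):
    m = sample_plan_length()
    if plan_req.get(m, 0) > 0:  # reject lengths whose quota is exhausted
        plan_req[m] -= 1
        episodes.append(m)
print(episodes)  # e.g. [2, 1, 3, 1, 2]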
Code Example #7
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    ## *** Adds resolution; uses the I128_hdf5 dataset here, may need the C10 dataset instead
    config['resolution'] = utils.imsize_dict[config['dataset']]
    ## *** Adds nclass_dict to load the I128_hdf5 class count; may need C10's 10 classes here
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    ## Load the activation functions for G and D, both ReLU; 'relu' is lowercase here, unclear whether it should be capitalized
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]

    ## Training from scratch with no saved parameters; no change needed, this is the default
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True

    ## Logging setup; this probably needs no change either
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    ## Set the initial RNG seed (all 0); *** needs to be switched to the Paddle equivalent
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    ## Set the log root directories; this should also be fine as-is
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    ## @@@ No change needed here; simply commented out, Paddle may not need this setting
    ## Speeds up networks with a fixed architecture
    # torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    ## *** !!! Neat trick: dynamically imports the BigGAN model file; check the network configuration inside BigGAN
    model = __import__(config['model'])
    ## No change needed; packs the config settings into the experiment name
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    ## *** Loads the config into the models; these two methods need porting
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA, prepare it
    ## *** Off by default; the EMA part can be left unchanged for now
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    ## C10 is small; the G and D fp16 parts can also stay unchanged for now, using default precision
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    ## Wrap the configured G and D into the combined G_D module
    GD = model.G_D(G, D)
    ## *** These two prints could probably be removed; they rely on printing behavior inherited from nn.Module
    print(G)
    print(D)
    ## *** parameters() is likewise inherited from torch's nn.Module
    print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    ## Initialize the bookkeeping table of statistics; no change needed
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    ## Not using pre-training for now, so this block needs no change
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    ## Can be ignored for now; GD is not parallelized by default
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    ## Logging hub; probably also fine as-is. If needed, see whether the IS and FID results can be extracted from it
    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])

    ## This one matters: it is used for collecting the result statistics.
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)

    ## *** Data loading for D; get_data_loaders uses torchvision's transforms along the way
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })

    ## Prepare the evaluation metrics, i.e. the FID and IS pipeline; an np-based version works, so no change needed
    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    ## Prepare the noise and the randomly sampled label set
    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    ## *** Contains some torch/numpy usage that needs porting; obtains the noise and labels
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])

    # Prepare a fixed z & y to see individual sample evolution throughout training
    ## *** Contains some torch/numpy usage that needs porting; obtains the noise and labels
    ## TODO Why are two sets of noise and labels prepared? Is there a purpose?
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])

    ## *** The sampling methods come from Distribution; Gaussian and categorical sampling are available
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    ## *** Instantiate the GAN_training_function training pipeline
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    ## If no training function is specified, run a dummy pass through the pipeline for debugging
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    ## *** Pre-bind some of utils.sample's arguments, defining the new function sample
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            ## This part needs no changes
            ## !!! loaders[0] is the data-sampling object
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            ## *** train() is inherited from nn.Module
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()

            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            ## *** Feed the data and labels into the training function; train itself needs substantial rewriting
            metrics = train(x, y)
            ## Logging: write all the metrics into the log
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            ## Log the singular values as they change
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            ## By default, record results every 2000 steps
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    ## *** A method inherited from nn.Module
                    G.eval()
                    ## If using the exponential moving average
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

            # Test every specified interval
            ## By default, test every 5000 steps
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
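
The sample helper above is built with functools.partial, which pre-binds utils.sample's arguments so downstream code can call it with no arguments at all. A tiny standalone demonstration of the pattern (sample_images and its parameters are hypothetical):

import functools

def sample_images(G, z_, y_, config):
    return 'sampling with %s, z=%s, y=%s, cfg=%s' % (G, z_, y_, config)

# Every argument is bound up front; the resulting callable takes none,
# matching how `sample` is later handed to the test routine.
sample = functools.partial(sample_images, G='G_ema', z_='z', y_='y', config={'n': 8})
print(sample())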
Code Example #8
File: eval_HRL.py  Project: qiming-zou/HouseNavAgent
def evaluate(args, data_saver=None):

    args['segment_input'] = args['segmentation_input']

    backup_rate = args['backup_rate']

    elap = time.time()

    # No need to log detailed computation stats
    common.debugger = utils.FakeLogger()

    # ensure observation shape
    common.process_observation_shape(
        'rnn',
        args['resolution'],
        args['segmentation_input'],
        args['depth_input'],
        target_mask_input=args['target_mask_input'])

    fixed_target = args['fixed_target']
    if (fixed_target is not None) and (fixed_target != 'any-room') and (
            fixed_target != 'any-object'):
        assert fixed_target in common.all_target_instructions, 'invalid fixed target <{}>'.format(
            fixed_target)

    __backup_CFG = common.CFG.copy()
    if fixed_target == 'any-room':
        common.ensure_object_targets(False)

    if args['hardness'] is not None:
        print('>>>> Hardness = {}'.format(args['hardness']))
    if args['max_birthplace_steps'] is not None:
        print('>>>> Max BirthPlace Steps = {}'.format(
            args['max_birthplace_steps']))
    set_seed(args['seed'])
    task = common.create_env(args['house'],
                             task_name=args['task_name'],
                             false_rate=args['false_rate'],
                             hardness=args['hardness'],
                             max_birthplace_steps=args['max_birthplace_steps'],
                             success_measure=args['success_measure'],
                             depth_input=args['depth_input'],
                             target_mask_input=args['target_mask_input'],
                             segment_input=args['segmentation_input'],
                             genRoomTypeMap=False,
                             cacheAllTarget=args['multi_target'],
                             render_device=args['render_gpu'],
                             use_discrete_action=True,
                             include_object_target=args['object_target']
                             and (fixed_target != 'any-room'),
                             include_outdoor_target=args['outdoor_target'],
                             discrete_angle=True,
                             min_birthplace_grids=args['min_birthplace_grids'])

    if (fixed_target is not None) and (fixed_target != 'any-room') and (
            fixed_target != 'any-object'):
        task.reset_target(fixed_target)

    if fixed_target == 'any-room':
        common.CFG = __backup_CFG
        common.ensure_object_targets(True)

    # logger
    logger = utils.MyLogger(args['log_dir'], True)
    logger.print('Start Evaluating ...')

    # create semantic classifier
    if args['semantic_dir'] is not None:
        assert os.path.exists(
            args['semantic_dir']
        ), '[Error] Semantic Dir <{}> not exists!'.format(args['semantic_dir'])
        assert not args['object_target'], \
            '[ERROR] currently do not support --object-target!'
        print('Loading Semantic Oracle from dir <{}>...'.format(
            args['semantic_dir']))
        if args['semantic_gpu'] is None:
            args['semantic_gpu'] = common.get_gpus_for_rendering()[0]
        oracle = SemanticOracle(model_dir=args['semantic_dir'],
                                model_device=args['semantic_gpu'],
                                include_object=args['object_target'])
        oracle_func = OracleFunction(
            oracle,
            threshold=args['semantic_threshold'],
            filter_steps=args['semantic_filter_steps'],
            batched_size=args['semantic_batch_size'])
    else:
        oracle_func = None

    # create motion
    motion = create_motion(args, task, oracle_func=oracle_func)
    if args['motion'] == 'random':
        motion.set_skilled_rate(args['random_motion_skill'])
    flag_interrupt = args['interruptive_motion']

    # create planner
    graph = None
    max_motion_steps = args['n_exp_steps']
    if (args['planner'] is None) or (args['planner'] == 'void'):
        graph = VoidPlanner(motion)
    elif args['planner'] == 'oracle':
        graph = OraclePlanner(motion)
    elif args['planner'] == 'rnn':
        #assert False, 'Currently only support Graph-planner'
        graph = RNNPlanner(motion,
                           args['planner_units'],
                           args['planner_filename'],
                           oracle_func=oracle_func)
    else:
        graph = GraphPlanner(motion)
        if not args['outdoor_target']:
            graph.add_excluded_target('outdoor')
        filename = args['planner_filename']
        if filename == 'None': filename = None
        if filename is not None:
            logger.print(' > Loading Graph from file = <{}>'.format(filename))
            with open(filename, 'rb') as f:
                _params = pickle.load(f)
            graph.set_parameters(_params)
        # hack
        if args['planner_obs_noise'] is not None:
            graph.set_param(-1, args['planner_obs_noise'])  # default 0.95

    episode_success = []
    episode_good = []
    episode_stats = []
    t = 0
    seed = args['seed']
    max_episode_len = args['max_episode_len']

    plan_req = args['plan_dist_iters'] if 'plan_dist_iters' in args else None

    ####################
    accu_plan_time = 0
    accu_exe_time = 0
    accu_mask_time = 0
    ####################

    for it in range(args['max_iters']):

        if (it > 0) and (backup_rate > 0) and (it % backup_rate == 0) \
                and (data_saver is not None):
            data_saver.save(episode_stats, ep_id=it)

        cur_infos = []
        motion.reset()
        set_seed(seed + it + 1)  # reset seed
        if plan_req is not None:
            while True:
                task.reset(target=fixed_target)
                m = len(task.get_optimal_plan())
                if (m in plan_req) and plan_req[m] > 0:
                    break
            plan_req[m] -= 1
        else:
            task.reset(target=fixed_target)
        info = task.info

        episode_success.append(0)
        episode_good.append(0)
        task_target = task.get_current_target()
        cur_stats = dict(best_dist=info['dist'],
                         success=0,
                         good=0,
                         reward=0,
                         target=task_target,
                         plan=[],
                         meters=task.info['meters'],
                         optstep=task.info['optsteps'],
                         length=max_episode_len,
                         images=None)
        if hasattr(task.house, "_id"):
            cur_stats['world_id'] = task.house._id

        store_history = args['store_history']
        if store_history:
            cur_infos.append(proc_info(task.info))

        episode_step = 0

        # reset planner
        if graph is not None:
            graph.reset()

        while episode_step < max_episode_len:
            if flag_interrupt and motion.is_interrupt():
                graph_target = task.get_current_target()
            else:
                # TODO #####################
                tt = time.time()
                mask_feat = (oracle_func.get(task)
                             if oracle_func is not None
                             else task.get_feature_mask())
                accu_mask_time += time.time() - tt
                tt = time.time()
                graph_target = graph.plan(mask_feat, task_target)
                accu_plan_time += time.time() - tt
                ################################
            graph_target_id = common.target_instruction_dict[graph_target]
            allowed_steps = min(max_episode_len - episode_step,
                                max_motion_steps)

            ###############
            # TODO
            tt = time.time()
            motion_data = motion.run(graph_target, allowed_steps)
            accu_exe_time += time.time() - tt

            cur_stats['plan'].append(
                (graph_target, len(motion_data),
                 (motion_data[-1][0][graph_target_id] > 0)))

            # store stats
            for dat in motion_data:
                info = dat[4]
                if store_history:
                    cur_infos.append(proc_info(info))
                cur_dist = info['dist']
                if cur_dist == 0:
                    cur_stats['good'] += 1
                    episode_good[-1] = 1
                if cur_dist < cur_stats['best_dist']:
                    cur_stats['best_dist'] = cur_dist

            # update graph
            ## TODO ############
            tt = time.time()
            graph.observe(motion_data, graph_target)
            accu_plan_time += time.time() - tt

            episode_step += len(motion_data)

            # check done
            if motion_data[-1][3]:
                if motion_data[-1][2] > 5:  # magic number
                    episode_success[-1] = 1
                    cur_stats['success'] = 1
                break

        cur_stats['length'] = episode_step  # store length

        if store_history:
            cur_stats['infos'] = cur_infos
        episode_stats.append(cur_stats)

        dur = time.time() - elap
        logger.print('Episode#%d, Elapsed = %.3f min' % (it + 1, dur / 60))
        #TODO #################
        logger.print(' >>> Mask Time = %.4f min' % (accu_mask_time / 60))
        logger.print(' >>> Plan Time = %.4f min' % (accu_plan_time / 60))
        logger.print(' >>> Motion Time = %.4f min' % (accu_exe_time / 60))
        if args['multi_target']:
            logger.print('  ---> Target Room = {}'.format(cur_stats['target']))
        logger.print('  ---> Total Samples = {}'.format(t))
        logger.print('  ---> Success = %d  (rate = %.3f)' %
                     (cur_stats['success'], np.mean(episode_success)))
        logger.print(
            '  ---> Times of Reaching Target Room = %d  (rate = %.3f)' %
            (cur_stats['good'], np.mean(episode_good)))
        logger.print('  ---> Best Distance = %d' % cur_stats['best_dist'])
        logger.print('  ---> Birth-place Meters = %.4f (optstep = %d)' %
                     (cur_stats['meters'], cur_stats['optstep']))
        logger.print('  ---> Planner Results = {}'.format(cur_stats['plan']))

    logger.print('######## Final Stats ###########')
    logger.print('Success Rate = %.3f' % np.mean(episode_success))
    logger.print(
        '> Avg Ep-Length per Success = %.3f' %
        np.mean([s['length'] for s in episode_stats if s['success'] > 0]))
    logger.print(
        '> Avg Birth-Meters per Success = %.3f' %
        np.mean([s['meters'] for s in episode_stats if s['success'] > 0]))
    logger.print('Reaching Target Rate = %.3f' % np.mean(episode_good))
    logger.print('> Avg Ep-Length per Target Reach = %.3f' %
                 np.mean([s['length']
                          for s in episode_stats if s['good'] > 0]))
    logger.print('> Avg Birth-Meters per Target Reach = %.3f' %
                 np.mean([s['meters']
                          for s in episode_stats if s['good'] > 0]))
    if args['multi_target']:
        all_targets = list(set([s['target'] for s in episode_stats]))
        for tar in all_targets:
            n = sum([1.0 for s in episode_stats if s['target'] == tar])
            succ = [
                float(s['success'] > 0) for s in episode_stats
                if s['target'] == tar
            ]
            good = [
                float(s['good'] > 0) for s in episode_stats
                if s['target'] == tar
            ]
            length = [s['length'] for s in episode_stats if s['target'] == tar]
            meters = [s['meters'] for s in episode_stats if s['target'] == tar]
            good_len = np.mean([l for l, g in zip(length, good) if g > 0.5])
            succ_len = np.mean([l for l, s in zip(length, succ) if s > 0.5])
            good_mts = np.mean([l for l, g in zip(meters, good) if g > 0.5])
            succ_mts = np.mean([l for l, s in zip(meters, succ) if s > 0.5])
            logger.print(
                '>>>>> Multi-Target <%s>: Rate = %.3f (n=%d), Good = %.3f (AvgLen=%.3f; Mts=%.3f), Succ = %.3f (AvgLen=%.3f; Mts=%.3f)'
                % (tar, n / len(episode_stats), n, np.mean(good), good_len,
                   good_mts, np.mean(succ), succ_len, succ_mts))

    return episode_stats
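For context, a hypothetical invocation of evaluate() might look like the sketch below. The key names follow the dictionary accesses visible above, but every value is an illustrative assumption rather than a default from this repository's CLI, and the function (via create_motion and the semantic/planner branches) reads further keys that a real caller must also supply.

# Hypothetical args dict; values are assumptions, not repo defaults.
args = {
    'house': 0, 'task_name': 'roomnav', 'false_rate': 0.0,
    'seed': 0, 'render_gpu': 0, 'resolution': 'normal',
    'segmentation_input': 'none', 'depth_input': False,
    'target_mask_input': False, 'success_measure': 'see',
    'hardness': None, 'max_birthplace_steps': None,
    'min_birthplace_grids': 0, 'multi_target': True,
    'object_target': False, 'outdoor_target': False,
    'fixed_target': None, 'backup_rate': 0, 'log_dir': './log',
    'semantic_dir': None, 'motion': 'random', 'random_motion_skill': 5,
    'interruptive_motion': False, 'planner': 'graph',
    'planner_filename': 'None', 'planner_obs_noise': None,
    'n_exp_steps': 50, 'max_episode_len': 300, 'max_iters': 100,
    'store_history': False,
}
episode_stats = evaluate(args)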
Code Example #9
File: train.py Project: qiming-zou/HouseNavAgent
def train(
        args=None,
        houseID=0,
        reward_type='indicator',
        success_measure='center',
        multi_target=False,
        include_object_target=False,
        algo='pg',
        model_name='cnn',  # NOTE: optional: model_name='rnn'
        iters=2000000,
        report_rate=20,
        save_rate=1000,
        eval_range=200,
        log_dir='./temp',
        save_dir='./_model_',
        warmstart=None,
        log_debug_info=True):

    if args is None:
        args = common.create_default_args(algo)

    if 'scheduler' in args:
        scheduler = args['scheduler']
    else:
        scheduler = None

    hardness = args['hardness']
    max_birthplace_steps = args['max_birthplace_steps']
    if hardness is not None:
        print('>>> Hardness Level = {}'.format(hardness))
    if max_birthplace_steps is not None:
        print('>>>> Max BirthPlace Steps = {}'.format(max_birthplace_steps))

    env = common.create_env(houseID,
                            task_name=args['task_name'],
                            false_rate=args['false_rate'],
                            reward_type=reward_type,
                            hardness=hardness,
                            max_birthplace_steps=max_birthplace_steps,
                            success_measure=success_measure,
                            segment_input=args['segment_input'],
                            depth_input=args['depth_input'],
                            render_device=args['render_gpu'],
                            cacheAllTarget=args['multi_target'],
                            use_discrete_action=('dpg' not in algo),
                            include_object_target=include_object_target)
    trainer = common.create_trainer(algo, model_name, args)
    logger = utils.MyLogger(log_dir, True)
    if multi_target:
        assert hasattr(trainer, 'set_target')

    if warmstart is not None:
        if os.path.exists(warmstart):
            logger.print('Warmstarting from <{}> ...'.format(warmstart))
            trainer.load(warmstart)
        else:
            logger.print(
                'Warmstarting from save_dir <{}> with version <{}> ...'.format(
                    save_dir, warmstart))
            trainer.load(save_dir, warmstart)

    logger.print('Start Training')

    if log_debug_info:
        common.debugger = utils.MyLogger(log_dir, True, 'full_logs.txt')
    else:
        common.debugger = utils.FakeLogger()

    episode_rewards = [0.0]
    episode_success = [0.0]
    episode_length = [0.0]
    episode_targets = ['kitchen']

    trainer.reset_agent()
    if multi_target:
        obs = env.reset()
        target_room = env.info['target_room']
        trainer.set_target(target_room)
        episode_targets[-1] = target_room
    else:
        obs = env.reset(target='kitchen')
    assert not np.any(np.isnan(obs)), 'nan detected in the observation!'
    obs = obs.transpose([1, 0, 2])
    logger.print('Observation Shape = {}'.format(obs.shape))

    episode_step = 0
    t = 0
    best_res = -1e50
    elap = time.time()
    update_times = 0
    print('Starting iterations...')
    try:
        while (len(episode_rewards) <= iters):
            idx = trainer.process_observation(obs)
            # get action
            if scheduler is not None:
                noise_level = scheduler.value(len(episode_rewards) - 1)
                action = trainer.action(noise_level)
            else:
                action = trainer.action()
            #proc_action = [np.exp(a) for a in action]
            # environment step
            obs, rew, done, info = env.step(action)
            assert not np.any(
                np.isnan(obs)), 'nan detected in the observation!'
            obs = obs.transpose([1, 0, 2])
            episode_step += 1
            episode_length[-1] += 1
            terminal = (episode_step >= args['episode_len'])
            # collect experience
            trainer.process_experience(idx, action, rew, done, terminal, info)
            episode_rewards[-1] += rew
            if rew > 5:  # magic number
                episode_success[-1] = 1.0

            if done or terminal:
                trainer.reset_agent()
                if multi_target:
                    obs = env.reset()
                    target_room = env.info['target_room']
                    trainer.set_target(target_room)
                    episode_targets.append(target_room)
                else:
                    obs = env.reset(target='kitchen')
                assert not np.any(
                    np.isnan(obs)), 'nan detected in the observation!'
                obs = obs.transpose([1, 0, 2])
                episode_step = 0
                episode_rewards.append(0)
                episode_success.append(0)
                episode_length.append(0)

            # update all trainers
            trainer.preupdate()
            stats = trainer.update()
            if stats is not None:
                update_times += 1
                if common.debugger is not None:
                    common.debugger.print(
                        '>>>>>> Update#{} Finished!!!'.format(update_times),
                        False)

            # save results
            if ((done or terminal) and (len(episode_rewards) % save_rate == 0)) or\
               (len(episode_rewards) > iters):
                trainer.save(save_dir)
                logger.print(
                    'Successfully Saved to <{}>'.format(save_dir + '/' +
                                                        trainer.name + '.pkl'))
                if np.mean(episode_rewards[-eval_range:]) > best_res:
                    best_res = np.mean(episode_rewards[-eval_range:])
                    trainer.save(save_dir, "best")

            # display training output
            if ((update_times % report_rate == 0) and (algo != 'pg') and (stats is not None)) or \
                ((update_times == 0) and (algo != 'pg') and (len(episode_rewards) % 100 == 0) and (done or terminal)) or \
                ((algo == 'pg') and (stats is not None)):
                logger.print(
                    'Episode#%d, Updates=%d, Time Elapsed = %.3f min' %
                    (len(episode_rewards), update_times,
                     (time.time() - elap) / 60))
                logger.print('-> Total Samples: %d' % t)
                logger.print('-> Avg Episode Length: %.4f' %
                             (t / len(episode_rewards)))
                if stats is not None:
                    for k in stats:
                        logger.print('  >> %s = %.4f' % (k, stats[k]))
                logger.print('  >> Reward  = %.4f' %
                             np.mean(episode_rewards[-eval_range:]))
                logger.print('  >> Success Rate  = %.4f' %
                             np.mean(episode_success[-eval_range:]))
                if multi_target:
                    ep_rew = episode_rewards[-eval_range:]
                    ep_suc = episode_success[-eval_range:]
                    ep_tar = episode_targets[-eval_range:]
                    ep_len = episode_length[-eval_range:]
                    total_n = len(ep_rew)
                    tar_stats = dict()
                    for k, r, s, l in zip(ep_tar, ep_rew, ep_suc, ep_len):
                        if k not in tar_stats:
                            tar_stats[k] = [0.0, 0.0, 0.0, 0.0]
                        tar_stats[k][0] += 1
                        tar_stats[k][1] += r
                        tar_stats[k][2] += s
                        tar_stats[k][3] += l
                    for k in tar_stats.keys():
                        n, r, s, l = tar_stats[k]
                        logger.print(
                            '  --> Multi-Room<%s> Freq = %.4f, Rew = %.4f, Succ = %.4f (AvgLen = %.3f)'
                            % (k, n / total_n, r / n, s / n, l / n))
                print('----> Data Loading Time = %.4f min' %
                      (time_counter[-1] / 60))
                print('----> GPU Data Transfer Time = %.4f min' %
                      (time_counter[0] / 60))
                print('----> Training Time = %.4f min' %
                      (time_counter[1] / 60))
                print('----> Target Net Update Time = %.4f min' %
                      (time_counter[2] / 60))

            t += 1
    except KeyboardInterrupt:
        print('Keyboard Interrupt!!!!!!')
    trainer.save(save_dir, "final")
    with open(save_dir + '/final_training_stats.pkl', 'wb') as f:
        pickle.dump([
            episode_rewards, episode_success, episode_targets, episode_length
        ], f)
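A hypothetical call, relying on common.create_default_args() (used inside train() itself) to populate the environment and episode keys the function reads; the keyword values here are illustrative only.

# Hypothetical usage; create_default_args is assumed to fill hardness,
# segment/depth inputs, episode_len, render_gpu, etc.
args = common.create_default_args('pg')
train(args=args, houseID=0, reward_type='indicator',
      success_measure='center', multi_target=True, algo='pg',
      model_name='rnn', iters=100000, log_dir='./temp',
      save_dir='./_model_', warmstart=None)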
Code Example #10
def run(config):
    def len_parallelloader(self):
        return len(self._loader._loader)
    pl.PerDeviceLoader.__len__ = len_parallelloader

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        xm.master_print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different
    # files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    xm.master_print('Experiment name is %s' % experiment_name)

    device = xm.xla_device(devkind='TPU')

    # Next, build the model
    G = model.Generator(**config)
    D = model.Discriminator(**config)

    # If using EMA, prepare it
    if config['ema']:
        xm.master_print(
            'Preparing EMA for G with decay of {}'.format(
                config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True})
    else:
        xm.master_print('Not using ema...')
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        xm.master_print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        xm.master_print('Casting D to fp16...')
        D = D.half()

    # Prepare state dict, which holds things like itr #
    state_dict = {'itr': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        xm.master_print('Loading weights...')
        utils.load_weights(
            G,
            D,
            state_dict,
            config['weights_root'],
            experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # move everything to TPU
    G = G.to(device)
    D = D.to(device)

    G.optim = optim.Adam(params=G.parameters(), lr=G.lr,
                         betas=(G.B1, G.B2), weight_decay=0,
                         eps=G.adam_eps)
    D.optim = optim.Adam(params=D.parameters(), lr=D.lr,
                         betas=(D.B1, D.B2), weight_decay=0,
                         eps=D.adam_eps)

    # for key, val in G.optim.state.items():
    #  G.optim.state[key]['exp_avg'] = G.optim.state[key]['exp_avg'].to(device)
    #  G.optim.state[key]['exp_avg_sq'] = G.optim.state[key]['exp_avg_sq'].to(device)

    # for key, val in D.optim.state.items():
    #  D.optim.state[key]['exp_avg'] = D.optim.state[key]['exp_avg'].to(device)
    #  D.optim.state[key]['exp_avg_sq'] = D.optim.state[key]['exp_avg_sq'].to(device)

    if config['ema']:
        G_ema = G_ema.to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])

    # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    xm.master_print(G)
    xm.master_print(D)
    xm.master_print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D]]))

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    xm.master_print(
        'Test Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    xm.master_print(
        'Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])

    if xm.is_master_ordinal():
        # Write metadata
        utils.write_metadata(
            config['logs_root'],
            experiment_name,
            config,
            state_dict)

    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
    xm.master_print('Preparing data...')
    loader = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                       'start_itr': state_dict['itr']})

    # Prepare inception metrics: FID and IS
    xm.master_print('Preparing metrics...')

    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'],
        no_inception=config['no_inception'],
        no_fid=config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])

    def sample(): return utils.prepare_z_y(G_batch_size, G.dim_z,
                                           config['n_classes'], device=device,
                                           fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout
    # training
    fixed_z, fixed_y = sample()

    train = train_fns.GAN_training_function(G, D, GD, sample, ema, state_dict,
                                            config)

    xm.master_print('Beginning training...')

    if xm.is_master_ordinal():
        pbar = tqdm(total=config['total_steps'])
        pbar.n = state_dict['itr']
        pbar.refresh()

    xm.rendezvous('training_starts')
    while (state_dict['itr'] < config['total_steps']):
        pl_loader = pl.ParallelLoader(
            loader, [device]).per_device_loader(device)

        for i, (x, y) in enumerate(pl_loader):
            if xm.is_master_ordinal():
                # Increment the iteration counter
                pbar.update(1)

            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter
            # much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()

            xm.rendezvous('data_collection')
            metrics = train(x, y)

            # train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                if xm.is_master_ordinal():
                    train_log.log(itr=int(state_dict['itr']),
                        **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
                xm.rendezvous('Log SVs.')

            # Save weights and copies as configured at specified interval
            if (not (state_dict['itr'] % config['save_every'])):
                if config['G_eval_mode']:
                    xm.master_print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(
                    G,
                    D,
                    G_ema,
                    sample,
                    fixed_z,
                    fixed_y,
                    state_dict,
                    config,
                    experiment_name)

            # Test every specified interval
            if (not (state_dict['itr'] % config['test_every'])):

                which_G = G_ema if config['ema'] and config['use_ema'] else G
                if config['G_eval_mode']:
                    xm.master_print('Switching G to eval mode...')
                    which_G.eval()

                def G_sample():
                    z, y = sample()
                    return which_G(z, which_G.shared(y))

                train_fns.test(
                    G,
                    D,
                    G_ema,
                    sample,
                    state_dict,
                    config,
                    G_sample,
                    get_inception_metrics,
                    experiment_name,
                    test_log)
            
            # Debug : Message print
            # if True:
            #     xm.master_print(met.metrics_report())

            if state_dict['itr'] >= config['total_steps']:
                break
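The run(config) above executes inside a single XLA device process; the per-process entry point is not part of this excerpt. With torch_xla, the usual pattern for launching it on all TPU cores is a sketch like the following (the construction of `config` is not shown here, and nprocs=8 assumes a v3-8 style TPU):

import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, config):
    # Each spawned process receives its ordinal as `index`; run() picks up
    # its own device via xm.xla_device() internally.
    run(config)

# `config` is the parsed configuration dict (built elsewhere, not shown).
xmp.spawn(_mp_fn, args=(config,), nprocs=8)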
Code Example #11
def run(config):
    logger = logging.getLogger('tl')
    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = importlib.import_module(config['model'])
    # model = __import__(config['model'])
    experiment_name = 'exp'
    # experiment_name = (config['experiment_name'] if config['experiment_name']
    #                      else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config, cfg=getattr(global_cfg, 'generator',
                                              None)).to(device)
    D = model.Discriminator(**config,
                            cfg=getattr(global_cfg, 'discriminator',
                                        None)).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        },
                                cfg=getattr(global_cfg, 'generator',
                                            None)).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D)
    logger.info(G)
    logger.info(D)
    logger.info('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()])
          for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(G=G,
                           D=D,
                           state_dict=state_dict,
                           weights_root=global_cfg.resume_cfg.weights_root,
                           experiment_name='',
                           name_suffix=config['load_weights']
                           if config['load_weights'] else None,
                           G_ema=G_ema if config['ema'] else None)
        logger.info(f"Resume IS={state_dict['best_IS']}")
        logger.info(f"Resume FID={state_dict['best_FID']}")

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(
        **{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr'],
            **getattr(global_cfg, 'train_dataloader', {})
        })

    val_loaders = None
    if hasattr(global_cfg, 'val_dataloader'):
        val_loaders = utils.get_data_loaders(
            **{
                **config, 'batch_size': config['batch_size'],
                'start_itr': state_dict['itr'],
                **global_cfg.val_dataloader
            })[0]
        val_loaders = iter(val_loaders)
    # Prepare inception metrics: FID and IS
    if global_cfg.get('use_unofficial_FID', False):
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['inception_file'], config['parallel'], config['no_fid'])
    else:
        get_inception_metrics = inception_utils.prepare_FID_IS(global_cfg)
    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config,
                                                val_loaders)
    # Else, assume debugging and use the dummy train fn
    elif config['which_train_fn'] == 'dummy':
        train = train_fns.dummy_training_function()
    else:
        train_fns_module = importlib.import_module(config['which_train_fn'])
        train = train_fns_module.GAN_training_function(G, D, GD, z_, y_, ema,
                                                       state_dict, config,
                                                       val_loaders)

    # Prepare Sample function for use with inception metrics
    if global_cfg.get('use_unofficial_FID', False):
        sample = functools.partial(
            utils.sample,
            G=(G_ema if config['ema'] and config['use_ema'] else G),
            z_=z_,
            y_=y_,
            config=config)
    else:
        sample = functools.partial(
            utils.sample_imgs,
            G=(G_ema if config['ema'] and config['use_ema'] else G),
            z_=z_,
            y_=y_,
            config=config)

    state_dict['shown_images'] = state_dict['itr'] * D_batch_size

    if global_cfg.get('resume_cfg', {}).get('eval', False):
        logger.info(f'Evaluating model.')
        G_ema.eval()
        G.eval()
        train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log)
        return

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  desc=f'Epoch:{epoch}, Itr: ',
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)

            default_dict = train(x, y)

            state_dict['shown_images'] += D_batch_size

            metrics = default_dict['D_loss']
            train_log.log(itr=int(state_dict['itr']), **metrics)

            summary_defaultdict2txtfig(default_dict=default_dict,
                                       prefix='train',
                                       step=state_dict['shown_images'],
                                       textlogger=textlogger)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ',
                      flush=True)

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

            # Test every specified interval
            if state_dict['itr'] == 1 or \
                  (config['test_every'] > 0 and state_dict['itr'] % config['test_every'] == 0) or \
                  (state_dict['shown_images'] % global_cfg.get('test_every_images', float('inf'))) < D_batch_size:
                if config['G_eval_mode']:
                    print('Switching G to eval mode...', flush=True)
                    G.eval()
                print('\n' + config['tl_outdir'])
                train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                               get_inception_metrics, experiment_name,
                               test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
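utils.ema(G, G_ema, decay, start_itr), constructed in the examples above, is not defined in this excerpt. A minimal sketch of the parameter-level exponential-moving-average update such a helper conventionally performs is below; the repository's real implementation may differ (e.g. operate on state_dicts and buffers as well as parameters).

import torch

class SimpleEMA(object):
    """Minimal EMA sketch, assuming the ema(source, target, decay, start_itr)
    interface seen above."""
    def __init__(self, source, target, decay=0.9999, start_itr=0):
        self.source = source
        self.target = target
        self.decay = decay
        self.start_itr = start_itr
        # Start the target as an exact copy of the source.
        with torch.no_grad():
            for p_t, p_s in zip(target.parameters(), source.parameters()):
                p_t.copy_(p_s)

    def update(self, itr):
        # Before start_itr, track the source exactly (decay = 0); afterwards,
        # blend: target = decay * target + (1 - decay) * source.
        decay = self.decay if itr >= self.start_itr else 0.0
        with torch.no_grad():
            for p_t, p_s in zip(self.target.parameters(),
                                self.source.parameters()):
                p_t.mul_(decay).add_(p_s, alpha=1.0 - decay)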
Code Example #12
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    D = model.Discriminator(**config).to(device)

    # FP16?
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?

    print(D)

    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'config': config}

    # If parallel, parallelize the GD module
    if config['parallel']:
        D = nn.DataParallel(D)

        if config['cross_replica']:
            patch_replication_callback(D)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # set tensorboard logger
    tb_logdir = '%s/%s/tblogs' % (config['logs_root'], experiment_name)
    if os.path.exists(tb_logdir):
        for filename in os.listdir(tb_logdir):
            if filename.startswith('events'):
                os.remove(os.path.join(tb_logdir,
                                       filename))  # remove previous event logs
    tb_writer = SummaryWriter(log_dir=tb_logdir)
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })

    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'MINE':
        train = train_fns.MINE_training_function(D, state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()

    print('Beginning training at epoch %d...' % state_dict['epoch'])

    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            D.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y)
            print(metrics)
            train_log.log(itr=int(state_dict['itr']), **metrics)
            for metric_name in metrics:
                tb_writer.add_scalar('Train/%s' % metric_name,
                                     metrics[metric_name], state_dict['itr'])

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
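train_fns.MINE_training_function(D, state_dict, config) is not shown here. For orientation, MINE trains a statistics network T to maximize the Donsker-Varadhan lower bound on mutual information, I(X;Y) >= E_joint[T(x,y)] - log E_marginal[exp(T(x,y))]. A minimal sketch of that objective follows; the callable T and its batched signature are assumptions, not the repo's actual training function.

import math
import torch

def mine_lower_bound(T, x, y):
    """Donsker-Varadhan estimate of I(X; Y); T maps a batch of (x, y)
    pairs to a vector of scalars. Minimal sketch only."""
    # Joint term: T evaluated on the aligned (x, y) pairs.
    t_joint = T(x, y).mean()
    # Marginal term: shuffle y within the batch to break the pairing,
    # approximating samples from p(x)p(y).
    y_shuffled = y[torch.randperm(y.size(0))]
    t_marginal = torch.logsumexp(T(x, y_shuffled), dim=0) - math.log(y.size(0))
    return t_joint - t_marginal  # maximize this; loss = -(bound)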
Code Example #13
def learn_controller(args):

    elap = time.time()

    # Do not need to log detailed computation stats
    common.debugger = utils.FakeLogger()

    if args['object_target']:
        common.ensure_object_targets()

    set_seed(args['seed'])
    task = common.create_env(args['house'],
                             task_name=args['task_name'],
                             false_rate=args['false_rate'],
                             success_measure=args['success_measure'],
                             depth_input=args['depth_input'],
                             target_mask_input=args['target_mask_input'],
                             segment_input=args['segmentation_input'],
                             cacheAllTarget=True,
                             render_device=args['render_gpu'],
                             use_discrete_action=True,
                             include_object_target=args['object_target'],
                             include_outdoor_target=args['outdoor_target'],
                             discrete_angle=True)

    # create motion
    __controller_warmstart = args['warmstart']
    args['warmstart'] = args['motion_warmstart']
    motion = create_motion(args, task)
    args['warmstart'] = __controller_warmstart

    # logger
    logger = utils.MyLogger(args['save_dir'], True)

    logger.print("> Planner Units = {}".format(args['units']))
    logger.print("> Max Planner Steps = {}".format(args['max_planner_steps']))
    logger.print("> Max Exploration Steps = {}".format(args['max_exp_steps']))
    logger.print("> Reward = {} & {}".format(args['time_penalty'],
                                             args['success_reward']))

    # Planner Learning
    logger.print('Start RNN Planner Learning ...')

    planner = RNNPlanner(motion, args['units'], args['warmstart'])

    fixed_target = None
    if args['only_eval_room']:
        fixed_target = 'any-room'
    elif args['only_eval_object']:
        fixed_target = 'any-object'
    train_stats, eval_stats = \
        planner.learn(args['iters'], args['max_episode_len'],
                      target=fixed_target,
                      motion_steps=args['max_exp_steps'],
                      planner_steps=args['max_planner_steps'],
                      batch_size=args['batch_size'],
                      lrate=args['lrate'], grad_clip=args['grad_clip'],
                      weight_decay=args['weight_decay'], gamma=args['gamma'],
                      entropy_penalty=args['entropy_penalty'],
                      save_dir=args['save_dir'],
                      report_rate=5, eval_rate=20, save_rate=100,
                      logger=logger, seed=args['seed'])

    logger.print('######## Done ###########')
    filename = args['save_dir']
    if filename[-1] != '/': filename = filename + '/'
    filename = filename + 'train_stats.pkl'
    with open(filename, 'wb') as f:
        pickle.dump([train_stats, eval_stats], f)
    logger.print('  --> Training Stats Saved to <{}>!'.format(filename))
    return planner
Code Example #14
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    E = model.ImgEncoder(**config).to(device)
    # E = model.Encoder(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GDE = model.G_D_E(G, D, E)

    print('Number of params in G: {} D: {} E: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()])
        for net in [G, D, E]
    ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, E, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GDE = nn.DataParallel(GDE)
        if config['cross_replica']:
            patch_replication_callback(GDE)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders, train_dataset = utils.get_data_loaders(
        **{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr']
        })

    # # Prepare inception metrics: FID and IS
    # get_inception_metrics = inception_utils.prepare_inception_metrics(
    #     config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    print("fixed_y original: {} {}".format(fixed_y.shape, fixed_y[:10]))
    ## TODO: change the sample method to sample x and y
    fixed_x, fixed_y_of_x = utils.prepare_x_y(G_batch_size, train_dataset,
                                              experiment_name, config)

    # Build image pool to prevent mode collapes
    if config['img_pool_size'] != 0:
        img_pool = ImagePool(config['img_pool_size'], train_dataset.num_class,\
                                    save_dir=os.path.join(config['imgbuffer_root'], experiment_name),
                                    resume_buffer=config['resume_buffer'])
    else:
        img_pool = None

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, E, GDE, ema, state_dict,
                                                config, img_pool)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    # print('Beginning training at epoch %f...' % (state_dict['itr'] * D_batch_size / len(train_dataset)))
    print("Beginning training at Epoch {} (iteration {})".format(
        state_dict['epoch'], state_dict['itr']))
    # # Train for specified number of epochs, although we mostly track G iterations.
    # for epoch in range(state_dict['epoch'], config['num_epochs']):
    # Which progressbar to use? TQDM or my own?
    if config['pbar'] == 'mine':
        pbar = utils.progress(
            loaders[0],
            displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
        pbar = tqdm(loaders[0])

    for i, (x, y) in enumerate(pbar):
        # Increment the iteration counter
        state_dict['itr'] += 1
        # Note: unlike the sibling loops below, this variant keeps G and D in
        # eval mode during training steps.
        G.eval()
        D.eval()
        if config['ema']:
            G_ema.eval()
        if config['D_fp16']:
            x, y = x.to(device).half(), y.to(device)
        else:
            x, y = x.to(device), y.to(device)
        # Run one training step and log metrics (this call exists in the
        # sibling loops below and appears to have been dropped here).
        metrics = train(x, y)
        train_log.log(itr=int(state_dict['itr']), **metrics)

        # Every sv_log_interval, log singular values
        if (config['sv_log_interval'] >
                0) and (not (state_dict['itr'] % config['sv_log_interval'])):
            train_log.log(itr=int(state_dict['itr']),
                          **{
                              **utils.get_SVs(G, 'G'),
                              **utils.get_SVs(D, 'D')
                          })

        # If using my progbar, print metrics.
        if config['pbar'] == 'mine':
            print(', '.join(
                ['itr: %d' % state_dict['itr']] +
                ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                  end=' ')

        # Save weights and copies as configured at specified interval
        if (not state_dict['itr'] % config['save_img_every']) or (
                not state_dict['itr'] % config['save_model_every']):
            if config['G_eval_mode']:
                print('Switching G to eval mode...')
                G.eval()
                if config['ema']:
                    G_ema.eval()
            save_weights = config['save_weights']
            if state_dict['itr'] % config['save_model_every']:
                save_weights = False
            train_fns.save_and_sample(G,
                                      D,
                                      E,
                                      G_ema,
                                      fixed_x,
                                      fixed_y_of_x,
                                      z_,
                                      y_,
                                      state_dict,
                                      config,
                                      experiment_name,
                                      img_pool,
                                      save_weights=save_weights)

        # # Test every specified interval
        # if not (state_dict['itr'] % config['test_every']):
        #     if config['G_eval_mode']:
        #         print('Switching G to eval mode...')
        #         G.eval()
        #     train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
        #                    get_inception_metrics, experiment_name, test_log)
        # Track (fractional) epoch progress from the iteration counter
        state_dict['epoch'] = state_dict['itr'] * D_batch_size / (
            len(train_dataset))
        print("Finished Epoch {} (iteration {})".format(
            state_dict['epoch'], state_dict['itr']))
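
The loop above routes discriminator inputs through an ImagePool to fight mode collapse: D sees a mix of fresh and historical fakes, so it cannot co-adapt to G's current output alone. The ImagePool used here additionally tracks classes and persists to disk (save_dir, resume_buffer); the sketch below shows only the core replay idea, with a hypothetical query() interface, and makes no claim to match the real class:

import random
import torch

class SimpleImagePool:
    """Minimal replay buffer: returns a mix of fresh and stored images so D
    also sees fakes produced by earlier states of G (illustrative stand-in
    for the ImagePool used above)."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:              # pool disabled: pass through
            return images
        out = []
        for img in images:
            img = img.detach().unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(img)      # still filling the pool
                out.append(img)
            elif random.random() < 0.5:      # swap a stored image in
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx])
                self.images[idx] = img
            else:                            # use the fresh image
                out.append(img)
        return torch.cat(out, dim=0)

Returning roughly half historical samples is the same trick used by the image buffer in CycleGAN; the pool size trades memory for how far back in G's history the discriminator can look.
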
Code Example #15
File: train.py Project: duxiaodan/BigVidGAN
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    num_devices = torch.cuda.device_count()
    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.ImageDiscriminator(**config).to(device)
    if config['no_Dv'] == False:
        Dv = model.VideoDiscriminator(**config).to(device)
    else:
        Dv = None

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        if config['no_Dv'] == False:
            Dv = Dv.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(
        G, D, Dv, config['k'],
        config['T_into_B'])  # xiaodan: added arguments k and T_into_B
    # print('GD.k in train.py line 91',GD.k)
    # print(G) # xiaodan: print disabled by xiaodan. Too many stuffs
    # print(D)
    if config['no_Dv'] == False:
        print('Number of params in G: {} D: {} Dv: {}'.format(*[
            sum([p.data.nelement() for p in net.parameters()])
            for net in [G, D, Dv]
        ]))
    else:
        print('Number of params in G: {} D: {}'.format(*[
            sum([p.data.nelement() for p in net.parameters()])
            for net in [G, D]
        ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # If loading from a pre-trained BigGAN model, load weights
    if config['biggan_init']:
        print('Loading weights from pre-trained BigGAN...')
        utils.load_biggan_weights(G,
                                  D,
                                  state_dict,
                                  config['biggan_weights_root'],
                                  G_ema if config['ema'] else None,
                                  load_optim=False)

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, Dv, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    if config['dataset'] == 'C10':
        loaders = utils.get_video_cifar_data_loader(
            **{
                **config, 'batch_size': D_batch_size,
                'start_itr': state_dict['itr']
            })
    else:
        loaders = utils.get_video_data_loaders(**{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr']
        })
    # print(loaders)
    # print(loaders[0])
    print('D loss weight:', config['D_loss_weight'])
    # Prepare inception metrics: FID and IS
    if config['skip_testing'] == False:
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(
        config['G_batch_size'], config['batch_size']
    )  # * num_devices #xiaodan: num_devices added by xiaodan
    # print('num_devices:',num_devices,'G_batch_size:',G_batch_size)
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'])
    # print('z_,y_ shapes after prepare_z_y:',z_.shape,y_.shape)
    # print('z_,y_ size:',z_.shape,y_.shape)
    # print('G.dim_z:',G.dim_z)
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, Dv, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    unique_id = datetime.datetime.now().strftime('%Y%m-%d%H-%M%S-')
    tensorboard_path = os.path.join(config['logs_root'], 'tensorboard_logs',
                                    unique_id)
    os.makedirs(tensorboard_path)
    # Train for specified number of epochs, although we mostly track G iterations.
    writer = SummaryWriter(log_dir=tensorboard_path)
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        iteration = epoch * len(pbar)
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['no_Dv'] == False:
                Dv.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y, writer, iteration + i)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                if config['no_Dv'] == False:
                    train_log.log(itr=int(state_dict['itr']),
                                  **{
                                      **utils.get_SVs(G, 'G'),
                                      **utils.get_SVs(D, 'D'),
                                      **utils.get_SVs(Dv, 'Dv')
                                  })
                else:
                    train_log.log(itr=int(state_dict['itr']),
                                  **{
                                      **utils.get_SVs(G, 'G'),
                                      **utils.get_SVs(D, 'D')
                                  })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, Dv, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)
            #xiaodan: Disabled test for now because we don't have inception data
            # Test every specified interval
            if not (state_dict['itr'] %
                    config['test_every']) and config['skip_testing'] == False:
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                IS_mean, IS_std, FID = train_fns.test(
                    G, D, Dv, G_ema, z_, y_, state_dict, config, sample,
                    get_inception_metrics, experiment_name, test_log)
                writer.add_scalar('Inception/IS', IS_mean, iteration + i)
                writer.add_scalar('Inception/IS_std', IS_std, iteration + i)
                writer.add_scalar('Inception/FID', FID, iteration + i)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1
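
This loop, like the others in this file, maintains an exponential moving average of G's weights through utils.ema and draws evaluation samples from G_ema. A minimal sketch of what such an updater typically does, assuming it is stepped once per G update and that averaging only begins after ema_start iterations (the class and method names are illustrative, not the actual utils.ema API):

import torch

class SimpleEMA:
    """Keep target's weights as an exponential moving average of source's
    weights (illustrative stand-in for utils.ema above)."""

    def __init__(self, source, target, decay=0.9999, start_itr=0):
        self.source, self.target = source, target
        self.decay, self.start_itr = decay, start_itr
        # Start from an exact copy of the source weights.
        target.load_state_dict(source.state_dict())

    @torch.no_grad()
    def update(self, itr):
        # Before start_itr, just track the raw weights (decay of zero).
        decay = self.decay if itr >= self.start_itr else 0.0
        for p_ema, p in zip(self.target.parameters(),
                            self.source.parameters()):
            p_ema.copy_(decay * p_ema + (1.0 - decay) * p)

With a decay near 0.9999 the averaged weights lag thousands of steps behind the raw ones, which is why these scripts only switch to G_ema for sampling and testing rather than for training.
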
Code Example #16
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[
        config['dataset']] * config['cluster_per_class']
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    if config['is_encoder']:
        config['E_fp16'] = float(config['D_fp16'])
        config['num_E_accumulations'] = int(config['num_D_accumulations'])
        config['dataset_channel'] = utils.channel_dict[config['dataset']]
        config['lambda_encoder'] = config['resolution']**2 * config[
            'dataset_channel']

    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    if config['is_encoder']:
        E = model.Encoder(**{**config, 'D': D}).to(device)
    Prior = layers.Prior(**config).to(device)
    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
        if not config['prior_type'] == 'default':
            Prior = Prior.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    if config['is_encoder'] and config['E_fp16']:
        print('Casting E to fp16...')
        E = E.half()

    print(G)
    print(D)
    if config['is_encoder']:
        print(E)
    print(Prior)
    if not config['is_encoder']:
        GD = model.G_D(G, D)
        print('Number of params in G: {} D: {}'.format(*[
            sum([p.data.nelement() for p in net.parameters()])
            for net in [G, D]
        ]))
    else:
        GD = model.G_D(G, D, E, Prior)
        GE = model.G_E(G, E, Prior)
        print('Number of params in G: {} D: {} E: {}'.format(*[
            sum([p.data.nelement() for p in net.parameters()])
            for net in [G, D, E]
        ]))

    # Prepare state dict, which holds things like epoch # and itr #
    # TODO: should reconstruction error, discriminator loss, and generator
    # loss also be tracked here?
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'best_error_rec': 99999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G,
            D,
            state_dict,
            config['weights_root'],
            experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None,
            E=None if not config['is_encoder'] else E,
            Prior=Prior if not config['prior_type'] == 'default' else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # If parallel, parallelize the GD module
    #if config['parallel'] and config['is_encoder']:
    #  GE = nn.DataParallel(GE)
    #  if config['cross_replica']:
    #    patch_replication_callback(GE)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{
        **config, 'batch_size': D_batch_size,
        'start_itr': state_dict['itr']
    })
    if config['is_encoder']:
        config_aux = config.copy()
        config_aux['augment'] = False
        dataloader_noaug = utils.get_data_loaders(
            **{
                **config_aux, 'batch_size': D_batch_size,
                'start_itr': state_dict['itr']
            })

    # Prepare inception metrics: FID and IS
    if (config['dataset'] in ['C10']):
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['dataset'], config['parallel'], config['no_fid'])
    else:
        get_inception_metrics = None

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(
            G, D, GD, Prior, ema, state_dict, config,
            losses.Loss_obj(**config), None if not config['is_encoder'] else E)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        Prior=Prior,
        config=config)

    # Create fixed z & y for sample sheets
    fixed_z, fixed_y = Prior.sample_noise_and_y()
    fixed_z, fixed_y = fixed_z.clone(), fixed_y.clone()
    iter_num = 0
    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['is_encoder']:
                E.train()
            if not config['prior_type'] == 'default':
                Prior.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)

            metrics = train(x, y, iter_num)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })
                if config['is_encoder']:
                    train_log.log(itr=int(state_dict['itr']),
                                  **{**utils.get_SVs(E, 'E')})

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if not config['prior_type'] == 'default':
                        Prior.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(
                    G, D, G_ema, Prior, fixed_z, fixed_y, state_dict, config,
                    experiment_name, None if not config['is_encoder'] else E)

            if not (state_dict['itr'] %
                    config['test_every']) and config['is_encoder']:
                if not config['prior_type'] == 'default':
                    test_acc, test_acc_iter, error_rec = train_fns.test_accuracy(
                        GE, dataloader_noaug, device, config['D_fp16'], config)
                    p_mse, p_lik = train_fns.test_p_acc(GE, device, config)
                if config['n_classes'] == 10:
                    utils.reconstruction_sheet(
                        GE,
                        classes_per_sheet=utils.classes_per_sheet_dict[
                            config['dataset']],
                        num_classes=config['n_classes'],
                        samples_per_class=20,
                        parallel=config['parallel'],
                        samples_root=config['samples_root'],
                        experiment_name=experiment_name,
                        folder_number=state_dict['itr'],
                        dataloader=dataloader_noaug,
                        device=device,
                        D_fp16=config['D_fp16'],
                        config=config)

            # Test every specified interval
            if not (state_dict['itr'] % config['test_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    if not config['prior_type'] == 'default':
                        Prior.eval()
                    G.eval()
                train_fns.test(
                    G, D, G_ema, Prior, state_dict, config, sample,
                    get_inception_metrics, experiment_name, test_log,
                    None if not config['is_encoder'] else E,
                    None if config['prior_type'] == 'default' else
                    (test_acc, test_acc_iter, error_rec, p_mse, p_lik))

            if not (state_dict['itr'] % config['test_every']):
                utils.create_curves(train_metrics_fname,
                                    plot_sv=False,
                                    prior_type=config['prior_type'],
                                    is_E=config['is_encoder'])
                utils.plot_IS_FID(train_metrics_fname)

        # Increment epoch counter at end of epoch
        iter_num += 1
        state_dict['epoch'] += 1
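
All of these loops periodically call utils.get_SVs, which in BigGAN-style code works because every spectrally normalized layer keeps its running singular-value estimates in buffers whose names contain 'sv'. A rough sketch under that assumption (the helper name is illustrative):

def get_svs_sketch(net, prefix):
    # Scan the state dict for spectral-norm singular-value buffers and key
    # them as '<prefix>_<path>', replacing dots so the names are loggable.
    d = net.state_dict()
    return {('%s_%s' % (prefix, key)).replace('.', '_'): float(d[key].item())
            for key in d if 'sv' in key}

Plotting these values over training is a cheap early-warning signal: a singular value that blows up in G or D usually precedes visible divergence in the losses.
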
Code Example #17
File: train.py Project: Chandanpanda/kaggle-2
def run(config):

  # Update the config dict as necessary
  # This is for convenience, to add settings derived from the user-specified
  # configuration into the config-dict (e.g. inferring the number of classes
  # and size of the images from the dataset, passing in a pytorch object
  # for the activation specified as a string)
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True
  config = utils.update_config_roots(config)
  device = 'cuda'
  if config['base_root']:
    os.makedirs(config['base_root'],exist_ok=True)

  # Seed RNG
  utils.seed_rng(config['seed'])

  # Prepare root folders if necessary
  utils.prepare_root(config)

  # Setup cudnn.benchmark for free speed
  torch.backends.cudnn.benchmark = True

  # Import the model--this line allows us to dynamically select different files.
  model = __import__(config['model'])
  experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  G = model.Generator(**config).to(device)
  D = model.Discriminator(**config).to(device)
  
   # If using EMA, prepare it
  if config['ema']:
    print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    G_ema = model.Generator(**{**config, 'skip_init':True, 
                               'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
  else:
    G_ema, ema = None, None
  
  # FP16?
  if config['G_fp16']:
    print('Casting G to float16...')
    G = G.half()
    if config['ema']:
      G_ema = G_ema.half()
  if config['D_fp16']:
    print('Casting D to fp16...')
    D = D.half()
    # Consider automatically reducing SN_eps?
  GD = model.G_D(G, D)
  print(G)
  print(D)
  print('Number of params in G: {} D: {}'.format(
    *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]]))
  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    print('Loading weights...')
    utils.load_weights(G, D, state_dict,
                       config['weights_root'], experiment_name, 
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None,
                       )
    if G.lr_sched is not None:
      G.lr_sched.step(state_dict['epoch'])
    if D.lr_sched is not None:
      D.lr_sched.step(state_dict['epoch'])

  # If parallel, parallelize the GD module
  if config['parallel']:
    GD = nn.DataParallel(GD)
    if config['cross_replica']:
      patch_replication_callback(GD)

  # Prepare loggers for stats; metrics holds test metrics,
  # lmetrics holds any desired training metrics.
  test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                            experiment_name)
  train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
  print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
  test_log = utils.MetricsLogger(test_metrics_fname, 
                                 reinitialize=(not config['resume']))
  print('Training Metrics will be saved to {}'.format(train_metrics_fname))
  train_log = utils.MyLogger(train_metrics_fname, 
                             reinitialize=(not config['resume']),
                             logstyle=config['logstyle'])
  # Write metadata
  utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
  # Prepare data; the Discriminator's batch size is all that needs to be passed
  # to the dataloader, as G doesn't require dataloading.
  # Note that at every loader iteration we pass in enough data to complete
  # a full D iteration (regardless of number of D steps and accumulations)
  D_batch_size = (config['batch_size'] * config['num_D_steps']
                  * config['num_D_accumulations'])
  loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                      'start_itr': state_dict['itr']})

  # Prepare inception metrics: FID and IS
  if not config['on_kaggle']:
    get_inception_metrics = inception_utils.prepare_inception_metrics(config['base_root'],config['dataset'], config['parallel'], config['no_fid'])

  # Prepare noise and randomly sampled label arrays
  # Allow for different batch sizes in G
  G_batch_size = max(config['G_batch_size'], config['batch_size'])
  if config['use_dog_cnt']:
    y_dist='categorical_dog_cnt'
  else:
    y_dist = 'categorical'

  dim_z=G.dim_z*2 if config['mix_style'] else G.dim_z
  z_, y_ = utils.prepare_z_y(G_batch_size, dim_z, config['n_classes'],
                             device=device, fp16=config['G_fp16'],z_dist=config['z_dist'],
                             threshold=config['truncated_threshold'],y_dist=y_dist)
  # Prepare a fixed z & y to see individual sample evolution throughout training
  fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, dim_z,
                                       config['n_classes'], device=device,
                                       fp16=config['G_fp16'],z_dist=config['z_dist'],
                                       threshold=config['truncated_threshold'],y_dist=y_dist)
  fixed_z.sample_()
  fixed_y.sample_()
  # Loaders are loaded, prepare the training function
  if config['which_train_fn'] == 'GAN':
    train = train_fns.GAN_training_function(G, D, GD, z_, y_, 
                                            ema, state_dict, config)
  # Else, assume debugging and use the dummy train fn
  else:
    train = train_fns.dummy_training_function()
  # Prepare Sample function for use with inception metrics
  sample = functools.partial(utils.sample,
                              G=(G_ema if config['ema'] and config['use_ema']
                                 else G),
                              z_=z_, y_=y_, config=config)

  print('Beginning training at epoch %d...' % state_dict['epoch'])
  # Working by epoch is often more convenient; when save_every <= 100, treat
  # the save/test intervals as epoch counts rather than iteration counts.
  by_epoch = config['save_every'] <= 100


  # Train for specified number of epochs, although we mostly track G iterations.
  start_time = time.time()
  for epoch in range(state_dict['epoch'], config['num_epochs']):
    # Which progressbar to use? TQDM or my own?
    if config['on_kaggle']:
      pbar = loaders[0]
    elif config['pbar'] == 'mine':
      pbar = utils.progress(loaders[0],displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
      pbar = tqdm(loaders[0])
    epoch_start_time = time.time()
    for i, (x, y) in enumerate(pbar):
      # Increment the iteration counter
      state_dict['itr'] += 1
      # Make sure G and D are in training mode, just in case they got set to eval
      # For D, which typically doesn't have BN, this shouldn't matter much.
      G.train()
      D.train()
      if config['ema']:
        G_ema.train()
      if isinstance(y, (list, tuple)):
        y = torch.cat([yi.unsqueeze(1) for yi in y], dim=1)

      if config['D_fp16']:
        x, y = x.to(device).half(), y.to(device)
      else:
        x, y = x.to(device), y.to(device)
      metrics = train(x, y)
      train_log.log(itr=int(state_dict['itr']), **metrics)
      
      # Every sv_log_interval, log singular values
      if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
        train_log.log(itr=int(state_dict['itr']), 
                      **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

      # If using my progbar, print metrics.
      if config['on_kaggle']:
        if i == len(loaders[0])-1:
          metrics_str = ', '.join(['%s : %+4.3f' % (key, metrics[key]) for key in metrics])
          epoch_time = (time.time()-epoch_start_time) / 60
          total_time = (time.time()-start_time) / 60
          print(f"[{epoch+1}/{config['num_epochs']}][{epoch_time:.1f}min/{total_time:.1f}min] {metrics_str}")
      elif config['pbar'] == 'mine':
        if D.lr_sched is None:
          print(', '.join(['epoch:%d' % (epoch+1),'itr: %d' % state_dict['itr']]
                       + ['%s : %+4.3f' % (key, metrics[key])
                       for key in metrics]), end=' ')
        else:
          print(', '.join(['epoch:%d' % (epoch+1),'lr:%.5f' % D.lr_sched.get_lr()[0] ,'itr: %d' % state_dict['itr']]
                       + ['%s : %+4.3f' % (key, metrics[key])
                       for key in metrics]), end=' ')
      if not by_epoch:
        # Save weights and copies as configured at specified interval
        if not (state_dict['itr'] % config['save_every']) and not config['on_kaggle']:
          if config['G_eval_mode']:
            print('Switching G to eval mode...')
            G.eval()
            if config['ema']:
              G_ema.eval()
          train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                    state_dict, config, experiment_name)

        # Test every specified interval
        if not (state_dict['itr'] % config['test_every']) and not config['on_kaggle']:
          if config['G_eval_mode']:
            print('Switching G to eval mode...')
            G.eval()
          train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                         get_inception_metrics, experiment_name, test_log)

    if by_epoch:
      # Save weights and copies as configured at specified interval
      if not ((epoch+1) % config['save_every']) and not config['on_kaggle']:
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
          if config['ema']:
            G_ema.eval()
        train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                  state_dict, config, experiment_name)

      # Test every specified interval
      if not ((epoch+1) % config['test_every']) and not config['on_kaggle']:
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
        train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log)

      if G_ema is not None and (epoch+1) % config['test_every'] == 0 and not config['on_kaggle']:
        torch.save(G_ema.state_dict(),  '%s/%s/G_ema_epoch_%03d.pth' %
                   (config['weights_root'], config['experiment_name'], epoch+1))
    # Increment epoch counter at end of epoch
    state_dict['epoch'] += 1
    if G.lr_sched is not None:
      G.lr_sched.step()
    if D.lr_sched is not None:
      D.lr_sched.step()
  if config['on_kaggle']:
    train_fns.generate_submission(sample, config, experiment_name)
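
Every run() variant here sizes the dataloader batch as batch_size * num_D_steps * num_D_accumulations, so one loader iteration delivers exactly the data needed for a full D update. Inside the training function that oversized tensor is consumed chunk by chunk; the following is a schematic of the slicing with illustrative names (the real GAN_training_function also interleaves the G update and loss scaling):

def iterate_d_minibatches(x, y, batch_size, num_D_steps, num_D_accumulations):
    # Walk one oversized loader batch in D-sized chunks: the outer loop is
    # one optimizer step, the inner loop accumulates gradients over
    # num_D_accumulations minibatches before that step.
    counter = 0
    for step in range(num_D_steps):
        for accum in range(num_D_accumulations):
            lo = counter * batch_size
            yield step, accum, x[lo:lo + batch_size], y[lo:lo + batch_size]
            counter += 1

Each accumulation's loss is typically divided by num_D_accumulations before backward(), so the accumulated gradient matches what one large batch would produce.
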
Code Example #18
def run(config):
  # Update the config dict as necessary
  # This is for convenience, to add settings derived from the user-specified
  # configuration into the config-dict (e.g. inferring the number of classes
  # and size of the images from the dataset, passing in a pytorch object
  # for the activation specified as a string)
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]

  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  # By default, skip init if resuming training.
  if config['resume']:
    print('Skipping initialization for training resumption...')
    config['skip_init'] = True

  config = utils.update_config_roots(config)
  device = 'cuda'
  
  # Seed RNG
  utils.seed_rng(config['seed'])

  # Prepare root folders if necessary
  utils.prepare_root(config)

  # Import the model--this line allows us to dynamically select different files.
  model = __import__(config['model'])
  experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
  print('Experiment name is %s' % experiment_name)

  # Next, build the model
  G = model.Generator(**config).to(device)
  D = model.Discriminator(**config).to(device)
  
   # If using EMA, prepare it
  if config['ema']:
    print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
    G_ema = model.Generator(**{**config, 'skip_init':True, 
                               'no_optim': True}).to(device)
    ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
  else:
    G_ema, ema = None, None

  GD = model.G_D(G, D)
  print(G)
  print(D)
  print('Number of params in G: {} D: {}'.format(
    *[sum([np.prod(p.shape) for p in net.parameters()]) for net in [G,D]]))
  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # If loading from a pre-trained model, load weights
  if config['resume']:
    print('Loading weights...')
    utils.load_weights(G, D, state_dict,
                       config['weights_root'], experiment_name, 
                       config['load_weights'] if config['load_weights'] else None,
                       G_ema if config['ema'] else None)

  # Prepare loggers for stats; metrics holds test metrics,
  # lmetrics holds any desired training metrics.
  test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                            experiment_name)
  train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
  print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
  #test_log = utils.MetricsLogger(test_metrics_fname, 
  #                               reinitialize=(not config['resume']))
  test_log=LogWriter(logdir='%s/%s_log' % (config['logs_root'],
                                            experiment_name))
  print('Training Metrics will be saved to {}'.format(train_metrics_fname))
  train_log = utils.MyLogger(train_metrics_fname, 
                             reinitialize=(not config['resume']),
                             logstyle=config['logstyle'])
  # Write metadata
  utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
  # Prepare data; the Discriminator's batch size is all that needs to be passed
  # to the dataloader, as G doesn't require dataloading.
  # Note that at every loader iteration we pass in enough data to complete
  # a full D iteration (regardless of number of D steps and accumulations)
  D_batch_size = (config['batch_size'] * config['num_D_steps']
                  * config['num_D_accumulations'])
  loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                      'start_itr': state_dict['itr']})

  # Prepare inception metrics: FID and IS
  get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid'])

  # Prepare noise and randomly sampled label arrays
  # Allow for different batch sizes in G
  G_batch_size = max(config['G_batch_size'], config['batch_size'])
  z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                             device=device, fp16=config['G_fp16'])
  # Prepare a fixed z & y to see individual sample evolution throughout training
  fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                       config['n_classes'], device=device,
                                       fp16=config['G_fp16'])  
  fixed_z.sample_()
  fixed_y.sample_()
  # Loaders are loaded, prepare the training function
  if config['which_train_fn'] == 'GAN':
    train = train_fns.GAN_training_function(G, D, GD, z_, y_, 
                                            ema, state_dict, config)
  # Else, assume debugging and use the dummy train fn
  else:
    train = train_fns.dummy_training_function()
  # Prepare Sample function for use with inception metrics
  sample = functools.partial(utils.sample,
                              G=(G_ema if config['ema'] and config['use_ema']
                                 else G),
                              z_=z_, y_=y_, config=config)

  print('Beginning training at epoch %d...' % state_dict['epoch'])
  # Train for specified number of epochs, although we mostly track G iterations.
  for epoch in range(state_dict['epoch'], config['num_epochs']):    
    # Which progressbar to use? TQDM or my own?
    if config['pbar'] == 'mine':
      pbar = utils.progress(loaders[0],displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
    else:
      pbar = tqdm(loaders[0])
    for i, (x, y) in enumerate(pbar):
      # Increment the iteration counter
      state_dict['itr'] += 1
      # Make sure G and D are in training mode, just in case they got set to eval
      # For D, which typically doesn't have BN, this shouldn't matter much.
      G.train()
      D.train()
      y = y.astype(np.int64)  # special handling: the paddle dataloader yields numpy labels
      if config['ema']:
        G_ema.train()

      metrics = train(x, y)
      train_log.log(itr=int(state_dict['itr']), **metrics)

      for tag in metrics:
        try:
          test_log.add_scalar(step=int(state_dict['itr']), tag="train/" + tag,
                              value=float(metrics[tag]))
        except Exception:
          # Skip metrics that cannot be cast to a float scalar.
          pass

      # Every sv_log_interval, log singular values
      if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
        train_log.log(itr=int(state_dict['itr']), 
                      **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

      # If using my progbar, print metrics.
      if config['pbar'] == 'mine':
          print(', '.join(['itr: %d' % state_dict['itr']] 
                           + ['%s : %+4.3f' % (key, metrics[key])
                           for key in metrics]), end=' ')
      else:
          pbar.set_description(', '.join(['itr: %d' % state_dict['itr']] 
                           + ['%s : %+4.3f' % (key, metrics[key])
                           for key in metrics]))

      # Save weights and copies as configured at specified interval
      if not (state_dict['itr'] % config['save_every']):
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
          if config['ema']:
            G_ema.eval()
        train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, 
                                  state_dict, config, experiment_name)

      # Test every specified interval
      if not (state_dict['itr'] % config['test_every']):
        if config['G_eval_mode']:
          print('Switching G to eval mode...')
          G.eval()
        train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                       get_inception_metrics, experiment_name, test_log)
    # Increment epoch counter at end of epoch
    state_dict['epoch'] += 1
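
prepare_z_y hands back tensors with an in-place sample_() method: z_.sample_() and fixed_z.sample_() refill the same buffers rather than allocating new ones, which is why each loop can create z_ and y_ once and reuse them for every training step and sample sheet. A schematic of that pattern (the real utility wraps torch.Tensor subclasses with extra distribution options; this closure version is only illustrative):

import torch

def make_reusable_z_y(batch_size, dim_z, n_classes, device='cuda'):
    # Allocate the noise and label buffers once ...
    z_ = torch.randn(batch_size, dim_z, device=device)
    y_ = torch.randint(n_classes, (batch_size,), device=device)

    def resample():
        # ... and refill them in place on every call.
        z_.normal_()               # z ~ N(0, I)
        y_.random_(0, n_classes)   # y ~ Uniform{0, ..., n_classes - 1}

    return z_, y_, resample

Keeping the buffers fixed also pins their device and dtype, so any fp16 cast only has to happen once, at creation time.
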
Code Example #19
File: train.py Project: fagan2888/fairgen
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    # config['n_classes'] = utils.nclass_dict[config['dataset']]

    # NOTE: setting n_classes to 1 except in conditional case to train as unconditional model
    config['n_classes'] = 1
    if config['conditional']:
        config['n_classes'] = 2
    print('n classes: {}'.format(config['n_classes']))
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{
            **config, 'skip_init': True,
            'no_optim': True
        }).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(
        G, D,
        config['conditional'])  # check if labels are 0's if "unconditional"
    print(G)
    print(D)
    print('Number of params in G: {} D: {}'.format(*[
        sum([p.data.nelement() for p in net.parameters()])
        for net in [G, D]
    ]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num_fair': 0,
        'save_best_num_fid': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'best_fair_d': 999999,
        'config': config
    }

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(
            G, D, state_dict, config['weights_root'], experiment_name,
            config['load_weights'] if config['load_weights'] else None,
            G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.json' % (config['logs_root'],
                                             experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'], experiment_name, config,
                         state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] *
                    config['num_D_accumulations'])
    loaders = utils.get_data_loaders(
        config, **{
            **config, 'batch_size': D_batch_size,
            'start_itr': state_dict['itr']
        })

    # Prepare inception metrics: FID and IS
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               true_prop=config['true_prop'])
    # Prepare a fixed z & y to see individual sample evolution throughout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size,
                                         G.dim_z,
                                         config['n_classes'],
                                         device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()

    # NOTE: "unconditional" GAN
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()

    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema,
                                                state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(
        utils.sample,
        G=(G_ema if config['ema'] and config['use_ema'] else G),
        z_=z_,
        y_=y_,
        config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0],
                                  displaytype='s1k' if
                                  config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])

        # iterate through the dataloaders
        for i, (x, y, ratio) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y, ratio = x.to(device).half(), y.to(device), ratio.to(
                    device)
            else:
                x, y, ratio = x.to(device), y.to(device), ratio.to(device)
            metrics = train(x, y, ratio)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (
                    not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{
                                  **utils.get_SVs(G, 'G'),
                                  **utils.get_SVs(D, 'D')
                              })

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(
                    ['itr: %d' % state_dict['itr']] +
                    ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]),
                      end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z,
                                          fixed_y, state_dict, config,
                                          experiment_name)

        # Test every epoch (not specified interval)
        if (epoch >= config['start_eval']):
            # First, find correct inception moments
            data_moments = '../../fid_stats/unbiased_all_gender_fid_stats.npz'
            if config['multi']:
                data_moments = '../../fid_stats/unbiased_all_multi_fid_stats.npz'
                fid_type = 'multi'
            else:
                fid_type = 'gender'

            # load appropriate moments
            print('Loaded data moments at: {}'.format(data_moments))
            experiment_name = (config['experiment_name']
                               if config['experiment_name'] else
                               utils.name_from_config(config))

            # eval mode for FID computation
            if config['G_eval_mode']:
                print('Switching G to eval mode...')
                G.eval()
                if config['ema']:
                    G_ema.eval()
            utils.sample_inception(
                G_ema if config['ema'] and config['use_ema'] else G, config,
                str(epoch))
            # Get saved sample path
            folder_number = str(epoch)
            sample_moments = '%s/%s/%s/samples.npz' % (
                config['samples_root'], experiment_name, folder_number)
            # Calculate FID
            FID = fid_score.calculate_fid_given_paths(
                [data_moments, sample_moments],
                batch_size=100,
                cuda=True,
                dims=2048)
            print("FID calculated")
            train_fns.update_FID(G, D, G_ema, state_dict, config, FID,
                                 experiment_name, test_log,
                                 epoch)  # added epoch logging
        # Increment epoch counter at end of epoch
        print('Completed epoch {}'.format(epoch))
        state_dict['epoch'] += 1
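
Unlike the other variants, this fairgen loop computes FID from precomputed activation statistics: each side of the comparison is summarized by a mean vector and covariance matrix, and the score is the Fréchet distance between the two implied Gaussians. A worked sketch of that final step, assuming .npz stats files holding 'mu' and 'sigma' in the usual pytorch-fid format:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    # FID between N(mu1, sigma1) and N(mu2, sigma2):
    #   ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Regularize with a small ridge if the product is near-singular.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset),
                                  disp=False)
    covmean = covmean.real
    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2)
            - 2.0 * np.trace(covmean))

Caching the data-side moments (the unbiased_*_fid_stats.npz files referenced above) means only the sample-side statistics have to be recomputed at each evaluation epoch.
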