예제 #1
0
    plt.clf()

    if config.gen:
        z = to_data(images['ZY'])
        fig, ax = plt.subplots()
        scatter_ax(ax, x=z, y=x, fx=y, c_x='m', c_y='g', c_l='0.5', data_range=data_range)
        plt.savefig(os.path.join(dir, 'gz_%06d.png' % (step)), bbox_inches='tight')
        plt.clf()
    return

# Plotting driver: parse the training options, override a few of them for
# plotting, and derive the experiment directory used by the training script.
gen = 1  # copied onto config.gen; also part of the w2 directory name

config = Options().parse()
config.batch_size = DISPLAY_NUM  # plot DISPLAY_NUM samples per batch
config.gen = gen
# Options are echoed *before* the solver/data/trial overrides below.
utils.print_opts(config)

# Solver under inspection (other choices: 'w1', 'w2').
config.solver = 'bary_ot'
# Dataset to plot (other choice: 'our_checkerboard').
plot_dataset = '8gaussians'
config.data = plot_dataset
config.trial = 3

if config.solver == 'w2':
    # w2 runs additionally encode the generator flag in the folder name.
    dir_string = './{0}_gen{2}_{1}/trial_{3}/'.format(
        config.solver, config.data, config.gen, config.trial)
else:
    dir_string = './{0}_{1}/trial_{2}/'.format(
        config.solver, config.data, config.trial)

exp_dir = dir_string
print(exp_dir)
예제 #2
0
파일: main.py 프로젝트: jshe/wasserstein-2
def main():
    """Train a W1 / W2 / barycentric-OT model, save it, then plot a test batch.

    All settings come from ``Options().parse()``. Outputs are written under
    ``os.path.join(config.exp_dir, config.exp_name)``: networks go to
    ``models/`` and plots to ``images/``. For the 'bary_ot' solver a dual
    stage of ``config.dual_iters`` iterations runs first, followed by
    ``config.map_iters`` map-stage iterations. The other solvers run
    ``config.train_iters`` iterations. Finally the inputs stored in
    ``./mvg_test/data.pkl`` are pushed through ``model.g`` and plotted.

    Raises:
        ValueError: if ``config.solver`` is not 'w1', 'w2' or 'bary_ot'.
    """
    ## parse flags
    config = Options().parse()
    utils.print_opts(config)

    ## set up folders
    exp_dir = os.path.join(config.exp_dir, config.exp_name)
    model_dir = os.path.join(exp_dir, 'models')
    img_dir = os.path.join(exp_dir, 'images')
    # exist_ok replaces the exists()/makedirs() pairs (and their race window).
    for path in (exp_dir, model_dir, img_dir):
        os.makedirs(path, exist_ok=True)

    if config.use_tbx:
        # Remove *every* old tensorboardX event file, not only the first glob
        # match, so curves from earlier runs are not mixed into this one.
        for old_log in glob.glob(os.path.join(exp_dir, 'events.out.tfevents.*')):
            os.remove(old_log)
        tbx_writer = SummaryWriter(exp_dir)
    else:
        tbx_writer = None

    ## initialize data loaders/generators & model
    r_loader, z_loader = get_loader(config)
    if config.solver == 'w1':
        model = W1(config, r_loader, z_loader)
    elif config.solver == 'w2':
        model = W2(config, r_loader, z_loader)
    elif config.solver == 'bary_ot':
        model = BaryOT(config, r_loader, z_loader)
    else:
        # Without this branch `model` stays unbound and the script crashes
        # later with an unhelpful UnboundLocalError.
        raise ValueError(
            "unknown solver %r (expected 'w1', 'w2' or 'bary_ot')" % config.solver)
    cudnn.benchmark = True
    networks = model.get_networks()
    utils.print_networks(networks)

    ## training
    ## stage 1 (dual stage) of bary_ot
    start_time = time.time()
    if config.solver == 'bary_ot':
        print("Starting: dual stage for %d iters." % config.dual_iters)
        for step in range(config.dual_iters):
            model.train_diter_only(config)
            if ((step + 1) % 100) == 0:
                stats = model.get_stats(config)
                end_time = time.time()
                stats['disp_time'] = (end_time - start_time) / 60.
                start_time = end_time
                utils.print_out(stats, step + 1, config.dual_iters, tbx_writer)
        print("dual stage iterations complete.")

    ## main training loop of w1 / w2 or stage 2 (map stage) of bary-ot
    map_iters = config.map_iters if config.solver == 'bary_ot' else config.train_iters
    if config.solver == 'bary_ot':
        print("Starting: map stage for %d iters." % map_iters)
    else:
        print("Starting training...")
    for step in range(map_iters):
        model.train_iter(config)
        if ((step + 1) % 100) == 0:
            stats = model.get_stats(config)
            end_time = time.time()
            stats['disp_time'] = (end_time - start_time) / 60.
            start_time = end_time
            utils.print_out(stats, step + 1, map_iters, tbx_writer)
        if ((step + 1) % 500) == 0:
            images = model.get_visuals(config)
            utils.visualize_iter(images, img_dir, step + 1, config)
    print("Training complete.")
    if tbx_writer is not None:
        tbx_writer.close()  # the original never closed the event writer
    networks = model.get_networks()
    utils.save_networks(networks, model_dir)

    ## testing
    root = "./mvg_test"
    # NOTE: pickle.load can execute arbitrary code; only use a trusted file here.
    # `with` guarantees the file is closed even if unpickling fails.
    with open(os.path.join(root, "data.pkl"), "rb") as data_file:
        fixed_z = pickle.load(data_file)
    fixed_z = utils.to_var(fixed_z)
    fixed_gz = model.g(fixed_z).view(*fixed_z.size())
    utils.visualize_single(fixed_gz, os.path.join(img_dir, 'test.png'), config)
예제 #3
0
    parser.add_argument(
        "-c",
        "--config",
        help="Config file to use",
        default="shared/feature_pixelDA.yml",
    )
    args = Dict(vars(parser.parse_args()))

    # --------------------------
    # -----  Load Options  -----
    # --------------------------
    root = Path(__file__).parent.resolve()
    opts = load_opts(path=root / args.config, default=root / "shared/defaults.yml")
    opts = set_mode("train", opts)
    flats = flatten_opts(opts)
    print_opts(flats)

    # ------------------------------------
    # -----  Start Comet Experiment  -----
    # ------------------------------------
    wsp = args.get("workspace") or opts.comet.workspace
    prn = args.get("project_name") or opts.comet.project_name
    comet_exp = Experiment(workspace=wsp, project_name=prn)
    comet_exp.log_asset(file_data=str(root / args.config), file_name=root / args.config)
    comet_exp.log_parameters(flats)

    # ----------------------------
    # -----  Create loaders  -----
    # ----------------------------
    print("Creating loaders:")
    # ! important to do test first
# Hyper-parameters for a transformer encoder/decoder training run.
args = AttrDict()
args_dict = {
    'cuda': False,
    'nepochs': 100,
    'checkpoint_dir': "checkpoints",
    # NOTE(original author): "INCREASE BY AN ORDER OF MAGNITUDE" -- confirm
    # whether 0.005 already reflects that increase.
    'learning_rate': 0.005,
    'lr_decay': 0.99,
    'batch_size': 64,
    'hidden_size': 20,
    'encoder_type': 'transformer',
    'decoder_type': 'transformer',  # options: rnn / rnn_attention / transformer
    'num_transformer_layers': 3,
}
args.update(args_dict)

print_opts(args)
transformer_encoder, transformer_decoder = train(args)

# NOTE(review): TEST_SENTENCE must be defined earlier in the script (not visible here).
translated = translate_sentence(TEST_SENTENCE, transformer_encoder, transformer_decoder, None, args)
print("source:\t\t{} \ntranslated:\t{}".format(TEST_SENTENCE, translated))

# Try translating different sentences by changing the variable TEST_SENTENCE.
# Identify two distinct failure modes and briefly describe them.

# NOTE(review): test_cases is defined elsewhere; confirm translate_sentence
# accepts its type.
TEST_SENTENCE = test_cases
translated = translate_sentence(TEST_SENTENCE, transformer_encoder, transformer_decoder, None, args)
print("source:\t\t{} \ntranslated:\t{}".format(TEST_SENTENCE, translated))

# Stop the script here so the sections below are not executed. The builtin
# `exit()` is injected by the `site` module for interactive use only (it is
# missing under `python -S`); raising SystemExit is what sys.exit() does.
raise SystemExit

"""# Attention Visualizations

One of the benefits of using attention is that it allows us to gain insight into the inner workings of the model.
예제 #5
0
파일: main.py 프로젝트: lufeng22/OT-ICNN
def main():
    """Train a W1 / W2 / barycentric-OT solver on a 2-D dataset and save it.

    All settings come from ``Options().parse()``. Outputs go under
    ``./<solver>_<data>/trial_<trial>/``; for the w2 solver the path is
    ``./<solver>_gen<gen>_<data>/trial_<trial>/``. Saved networks go in
    ``models/`` and plots in ``images/``. Unless ``config.no_benchmark`` is
    set, ``utils.solve_assignment`` is run on the fixed samples. Its result is
    then used as a reference: the L2 distance to it is logged every 10
    iterations.

    Raises:
        ValueError: if ``config.solver`` is not 'w1', 'w2' or 'bary_ot'.
    """
    config = Options().parse()
    utils.print_opts(config)

    ## set up folders
    if config.solver == 'w2':
        # w2 runs also encode the generator flag in the folder name
        exp_dir = './{0}_gen{2}_{1}/trial_{3}/'.format(
            config.solver, config.data, config.gen, config.trial)
    else:
        exp_dir = './{0}_{1}/trial_{2}/'.format(
            config.solver, config.data, config.trial)
    model_dir = os.path.join(exp_dir, 'models')
    img_dir = os.path.join(exp_dir, 'images')
    # exist_ok replaces the exists()/makedirs() pairs (and their race window).
    for path in (exp_dir, model_dir, img_dir):
        os.makedirs(path, exist_ok=True)

    if config.use_tbx:
        # Remove *every* old tensorboardX event file, not only the first glob
        # match, so curves from earlier runs are not mixed into this one.
        for old_log in glob.glob(os.path.join(exp_dir, 'events.out.tfevents.*')):
            os.remove(old_log)
        tbx_writer = SummaryWriter(exp_dir)
    else:
        tbx_writer = None

    ## initialize data loaders & model
    r_loader, z_loader = get_loader(config)
    if config.solver == 'w1':
        model = W1(config, r_loader, z_loader)
    elif config.solver == 'w2':
        model = W2(config, r_loader, z_loader)
    elif config.solver == 'bary_ot':
        model = BaryOT(config, r_loader, z_loader)
    else:
        # Without this branch `model` stays unbound and the script crashes
        # later with an unhelpful UnboundLocalError.
        raise ValueError(
            "unknown solver %r (expected 'w1', 'w2' or 'bary_ot')" % config.solver)
    cudnn.benchmark = True
    networks = model.get_networks(config)
    utils.print_networks(networks)

    # Plot window shared by the data plot and the per-iteration plots.
    data_range = (-12, 12) if config.data == '8gaussians' else (-6, 6)

    fixed_r, fixed_z = model.get_fixed_data()
    utils.visualize_single(utils.to_data(fixed_z),
                           utils.to_data(fixed_r),
                           None,
                           os.path.join(img_dir, 'data.png'),
                           data_range=data_range)
    discrete_tz = None  # only computed when the benchmark is enabled
    if not config.no_benchmark:
        print('computing discrete-OT benchmark...')
        start_time = time.time()
        cost = model.get_cost()
        discrete_tz = utils.solve_assignment(fixed_z, fixed_r, cost,
                                             fixed_r.size(0))
        print('Done in %.4f seconds.' % (time.time() - start_time))
        # NOTE(review): unlike data.png, no data_range is passed for this
        # plot -- confirm that is intended.
        utils.visualize_single(utils.to_data(fixed_z), utils.to_data(fixed_r),
                               utils.to_data(discrete_tz),
                               os.path.join(img_dir, 'assignment.png'))

    ## training
    ## stage 1 (dual stage) of bary_ot
    start_time = time.time()
    if config.solver == 'bary_ot':
        print("Starting: dual stage for %d iters." % config.dual_iters)
        for step in range(config.dual_iters):
            model.train_diter_only(config)
            if ((step + 1) % 10) == 0:
                stats = model.get_stats(config)
                end_time = time.time()
                stats['disp_time'] = (end_time - start_time) / 60.
                start_time = end_time
                utils.print_out(stats, step + 1, config.dual_iters, tbx_writer)
        print("dual stage complete.")

    ## main training loop of w1 / w2 or stage 2 (map stage) of bary-ot
    map_iters = config.map_iters if config.solver == 'bary_ot' else config.train_iters
    if config.solver == 'bary_ot':
        print("Starting: map stage for %d iters." % map_iters)
    else:
        print("Starting training...")
    for step in range(map_iters):
        model.train_iter(config)
        if ((step + 1) % 10) == 0:
            stats = model.get_stats(config)
            end_time = time.time()
            stats['disp_time'] = (end_time - start_time) / 60.
            start_time = end_time
            if not config.no_benchmark:
                # distance between the learned map and the discrete assignment
                if config.gen:
                    stats['l2_dist/discrete_T_x--G_x'] = losses.calc_l2(
                        fixed_z, model.g(fixed_z), discrete_tz).data.item()
                else:
                    stats['l2_dist/discrete_T_x--T_x'] = losses.calc_l2(
                        fixed_z, model.get_tx(fixed_z, reverse=True),
                        discrete_tz).data.item()
            utils.print_out(stats, step + 1, map_iters, tbx_writer)
        if ((step + 1) % 10000) == 0 or step == 0:
            images = model.get_visuals(config)
            utils.visualize_iter(
                images,
                img_dir,
                step + 1,
                config,
                data_range=data_range)
    print("Training complete.")
    if tbx_writer is not None:
        tbx_writer.close()  # the original never closed the event writer
    networks = model.get_networks(config)
    utils.save_networks(networks, model_dir)