Example 1
def part_two(ip):
    # the real signal is the puzzle input repeated 10,000 times
    signal = ip * 10000
    # the first seven digits encode the offset of the eight-digit message
    offset = int("".join(map(str, signal[:7])))
    empty = [0] * offset
    outp = []
    for i in range(100):
        print_progress(100, i)
        outp = get_phase_output(signal, offset)
        # only the tail from `offset` onward matters; re-pad the head with zeros
        signal = empty + outp

    return "".join(map(str, outp[:8]))
Example 2
def main():
    args = get_args()

    if not args.silent:
        save_path = os.path.abspath(script_path + args.save_path)
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        save_path = os.path.abspath(save_path + "/" + args.name)
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        preview_path = os.path.abspath(save_path + "/preview")
        if not os.path.exists(preview_path):
            os.mkdir(preview_path)

    dataset = Dataset(args)
    
    if args.max_epoch is not None:
        epoch_iter = dataset.train_data_len // args.batch_size
        if dataset.train_data_len % args.batch_size != 0:
            epoch_iter += 1
        args.max_iter = args.max_epoch * epoch_iter

    progress = print_progress(args.max_iter, args.batch_size, dataset.train_data_len)

    if args.gpu_num != 0:
        cuda.get_device_from_array(xp.array([i for i in range(args.gpu_num)])).use()
    model = make_model(args, dataset)
    netG_opt = make_optimizer(model.netG_0, args.adam_alpha, args.adam_beta1, args.adam_beta2)
    netD_opt = make_optimizer(model.netD_0, args.adam_alpha, args.adam_beta1, args.adam_beta2)

    updater = Updater(model, netG_opt, netD_opt, args.n_dis, args.batch_size,
                      args.gpu_num, args.KL_loss_iter, args.KL_loss_conf,
                      args.epoch_decay, args.max_iter)

    print("==========================================")
    print("Info:start train")
    start = time.time()
    for i in range(args.max_iter):

        data = toGPU(dataset.next(), args.gpu_num)
        updater.update(data, dataset.now_epoch)

        if dataset.now_iter % args.display_interval == 0:
            elapsed = time.time() - start
            progress(elapsed, dataset.get_state)
            # save_path is only defined when not running silently
            if not args.silent:
                np.save(save_path + "/loss_hist.npy", updater.loss_hist)
            start = time.time()
        
        if dataset.now_iter % args.snapshot_interval == 0 and not args.silent:
            data = dataset.sampling(args.sample_size)
            sample = sample_generate(model.netG_0, data, args.noise_dim, args.noise_dist)
            Image.fromarray(sample).save(preview_path + f"/image_{dataset.now_iter:08d}.png")
            serializers.save_npz(save_path + f"/Generator_{dataset.now_iter:08d}.npz",model.netG_0)
            serializers.save_npz(save_path + f"/Discriminator_{dataset.now_iter:08d}.npz",model.netD_0)
                    
    if not args.silent:
        data = dataset.sampling(args.sample_size)
        sample = sample_generate(model.netG_0, data, args.noise_dim, args.noise_dist)
        Image.fromarray(sample).save(preview_path + f"/image_{dataset.now_iter:08d}.png")
        serializers.save_npz(save_path + f"/Generator_{dataset.now_iter:08d}.npz",model.netG_0)
        serializers.save_npz(save_path + f"/Discriminator_{dataset.now_iter:08d}.npz",model.netD_0)
    print("\n\n\n\n==========================================")
    print("Info:finish train")
Example 3
def main():
    conf = get_args()

    if not conf.silent:
        save_path = os.path.abspath(script_path + conf.save_path)
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        preview_path = os.path.abspath(save_path + "/preview")
        if not os.path.exists(preview_path):
            os.mkdir(preview_path)

    dataset = MSCOCO(conf)
    netG, netD = build_models(conf)
    optimizerG, optimizerD = build_optimizer(netG, netD, conf.adam_lr,
                                             conf.adam_beta1, conf.adam_beta2)
    pprog = print_progress(conf.max_epoch,
                           conf.batch_size,
                           dataset.train_data_len,
                           use_epoch=True)

    updater = Updater(netG, netD, optimizerG, optimizerD, conf)

    print("==========================================")
    print("Info:start train")

    val_times = dataset.val_data_len // dataset.batch_size  # (computed but unused below)
    if dataset.val_data_len % dataset.batch_size != 0:
        val_times += 1
    for i in range(conf.max_epoch):
        train_loss = np.array([0., 0.], dtype="float32")
        start = time.time()
        for data in dataset.get_data():
            data = toGPU(data, conf.gpu_num)
            updater.update(data, i)

            if dataset.now_iter % conf.display_interval == 0:
                elapsed = time.time() - start
                pprog(elapsed, dataset.get_state)
                start = time.time()

        if i % conf.snapshot_interval == 0 and not conf.silent:
            data = dataset.sampling(conf.sample_size)
            sample = sample_generate(netG, data, conf)
            Image.fromarray(sample).save(preview_path + f"/image_{i:04d}.png")

    print("\n\n\n\n==========================================")
    print("Info:finish train")
Example 4
def batch_predict(net: NetShiva, dss: List[DataSource],
                  train_config: TrainConfig, eval_config: EvalConfig):
    data_ranges = []
    data_points = 0
    for ds in dss:
        r = ds.get_data_range(eval_config.BEG, eval_config.END)
        b, e = r
        data_ranges.append(r)
        if b is not None and e is not None:
            if (e - b) > data_points:
                data_points = e - b

    batch_size = len(dss)
    # reset state
    curr_progress = 0
    processed = 0

    num_features = len(train_config.FEATURE_FUNCTIONS)
    inputs = np.zeros([batch_size, train_config.BPTT_STEPS, num_features])  # avoid shadowing builtin input
    mask = np.ones([batch_size, train_config.BPTT_STEPS])
    labels = np.zeros([batch_size, train_config.BPTT_STEPS])
    seq_len = np.zeros([batch_size], dtype=np.int32)

    # ceil(data_points / BPTT_STEPS)
    batches = (data_points + eval_config.BPTT_STEPS - 1) // eval_config.BPTT_STEPS

    state = net.zero_state(batch_size=batch_size)

    predictions_history = np.zeros(
        [batch_size, batches * eval_config.BPTT_STEPS])

    total_seq_len = np.zeros([batch_size], dtype=np.int32)
    for ds_idx in range(len(dss)):
        beg_data_idx, end_data_idx = data_ranges[ds_idx]
        if beg_data_idx is None or end_data_idx is None:
            continue
        t_s_l = end_data_idx - beg_data_idx
        total_seq_len[ds_idx] = t_s_l

    for b in range(batches):

        for ds_idx in range(len(dss)):
            ds = dss[ds_idx]
            beg_data_idx, end_data_idx = data_ranges[ds_idx]
            if beg_data_idx is None or end_data_idx is None:
                continue

            b_d_i = beg_data_idx + b * eval_config.BPTT_STEPS
            e_d_i = beg_data_idx + (b + 1) * eval_config.BPTT_STEPS
            b_d_i = min(b_d_i, end_data_idx)
            e_d_i = min(e_d_i, end_data_idx)

            s_l = e_d_i - b_d_i
            seq_len[ds_idx] = s_l

            for f in range(num_features):
                inputs[ds_idx, :s_l, f] = train_config.FEATURE_FUNCTIONS[f](
                    ds, b_d_i, e_d_i)

            labels[ds_idx, :s_l] = train_config.LABEL_FUNCTION(
                ds, b_d_i, e_d_i)

        state, sse, predictions = net.eval(state, inputs, labels,
                                           mask.astype(np.float32), seq_len)

        predictions_history[:, b * eval_config.BPTT_STEPS:(b + 1) *
                            eval_config.BPTT_STEPS] = predictions[:, :, 0]

        if math.isnan(sse):
            # raising a plain string is a TypeError in Python 3
            raise ValueError("NaN encountered in sse")

        # TODO: not absolutely correct
        processed += eval_config.BPTT_STEPS
        curr_progress = progress.print_progress(curr_progress, processed,
                                                data_points)

    weak_predictions = np.zeros([batch_size])

    for j in range(batch_size):
        weak_predictions[j] = predictions_history[j, total_seq_len[j] - 1]

    progress.print_progess_end()
    return weak_predictions
Example 5
def stacking_net_predict(net: NetShiva, dss: List[DataSource],
                         train_config: TrainConfig, eval_config: EvalConfig,
                         weak_predictors):
    data_ranges = []
    total_length = 0
    for ds in dss:
        r = ds.get_data_range(eval_config.BEG, eval_config.END)
        b, e = r
        data_ranges.append(r)
        if b is not None and e is not None:
            total_length += e - b

    batch_size = 1
    # reset state
    curr_progress = 0
    processed = 0
    total_sse = 0
    total_sse_members = 0

    num_features = len(train_config.FEATURE_FUNCTIONS)
    inputs = np.zeros([batch_size, train_config.BPTT_STEPS, num_features])  # avoid shadowing builtin input
    mask = np.ones([batch_size, train_config.BPTT_STEPS])
    labels = np.zeros([batch_size, train_config.BPTT_STEPS])
    ts = np.zeros([batch_size, train_config.BPTT_STEPS])

    predictions_history = []

    for ds_idx in range(len(dss)):
        ds = dss[ds_idx]
        state = net.zero_state(batch_size=batch_size)
        beg_data_idx, end_data_idx = data_ranges[ds_idx]
        if beg_data_idx is None or end_data_idx is None:
            predictions_history.append(None)
            continue

        ds.load_weak_predictions(weak_predictors, beg_data_idx, end_data_idx)

        data_points = end_data_idx - beg_data_idx

        p_h = np.zeros([data_points, 3])

        # ceil(data_points / BPTT_STEPS)
        batches = (data_points + eval_config.BPTT_STEPS - 1) // eval_config.BPTT_STEPS
        for b in range(batches):
            b_d_i = beg_data_idx + b * eval_config.BPTT_STEPS
            e_d_i = beg_data_idx + (b + 1) * eval_config.BPTT_STEPS
            e_d_i = min(e_d_i, end_data_idx)

            seq_len = e_d_i - b_d_i

            # for f in range(num_features):
            #     inputs[0, :seq_len, f] = train_config.FEATURE_FUNCTIONS[f](ds, b_d_i, e_d_i)
            inputs[0, :seq_len, :] = np.transpose(
                ds.get_weak_predictions(b_d_i, e_d_i))

            labels[0, :seq_len] = train_config.LABEL_FUNCTION(ds, b_d_i, e_d_i)

            ts[0, :seq_len] = ds.get_ts(b_d_i, e_d_i)

            if seq_len < eval_config.BPTT_STEPS:
                _input = inputs[:, :seq_len, :]
                _labels = labels[:, :seq_len]
                _mask = mask[:, :seq_len]
                _ts = ts[:, :seq_len]
            else:
                _input = inputs
                _labels = labels
                _mask = mask
                _ts = ts

            state, sse, predictions = net.eval(state, _input, _labels,
                                               _mask.astype(np.float32))

            b_i = b_d_i - beg_data_idx
            e_i = e_d_i - beg_data_idx

            p_h[b_i:e_i, 0] = _ts
            p_h[b_i:e_i, 1] = predictions.reshape([-1])
            p_h[b_i:e_i, 2] = _labels

            if math.isnan(sse):
                raise ValueError("NaN encountered in sse")
            total_sse += sse
            total_sse_members += np.sum(_mask)
            processed += seq_len
            curr_progress = progress.print_progress(curr_progress, processed,
                                                    total_length)

        ds.unload_weak_predictions()
        predictions_history.append(p_h)

    progress.print_progess_end()
    # RMSE over all masked timesteps
    avg_loss = math.sqrt(total_sse / total_sse_members)
    return avg_loss, predictions_history
Example 6
def main():
    io = Moon(Point(4, 12, 13), Velocity(0, 0, 0))
    europa = Moon(Point(-9, 14, -3), Velocity(0, 0, 0))
    ganymede = Moon(Point(-7, -1, 2), Velocity(0, 0, 0))
    callisto = Moon(Point(-11, 17, -1), Velocity(0, 0, 0))

    # Prasanna
    # io = Moon(Point(-19, -4, 2), Velocity(0, 0, 0))
    # europa = Moon(Point(-9, 8, -16), Velocity(0, 0, 0))
    # ganymede = Moon(Point(-4, 5, -11), Velocity(0, 0, 0))
    # callisto = Moon(Point(1, 9, -13), Velocity(0, 0, 0))

    # example
    # io = Moon(Point(-1, 0, 2), Velocity(0, 0, 0))
    # europa = Moon(Point(2, -10, -7), Velocity(0, 0, 0))
    # ganymede = Moon(Point(4, -8, 8), Velocity(0, 0, 0))
    # callisto = Moon(Point(3, 5, -1), Velocity(0, 0, 0))

    # example 2
    # io = Moon(Point(-8, -10, 0), Velocity(0, 0, 0))
    # europa = Moon(Point(5, 5, 10), Velocity(0, 0, 0))
    # ganymede = Moon(Point(2, -7, 3), Velocity(0, 0, 0))
    # callisto = Moon(Point(9, -8, -3), Velocity(0, 0, 0))

    moons = [io, europa, ganymede, callisto]
    initial_state = deepcopy(moons)

    count = 1
    repeat = {'x': 0, 'y': 0, 'z': 0}
    progress_counter = 0

    while True:
        print_progress(3, progress_counter)

        for i in range(len(moons)):
            for j in range(len(moons)):
                if i != j:
                    moons[i].calc_velocity(moons[j])

        for m in moons:
            m.move()

        # each axis evolves independently: record the first step at which
        # every moon's position and velocity on that axis match the start
        for axis in ('x', 'y', 'z'):
            if repeat[axis] == 0:
                if all(getattr(initial_state[i], f"compare_{axis}")(moons[i])
                       for i in range(4)):
                    repeat[axis] = count
                    progress_counter += 1

        # stop once all three axis periods are known
        if all(repeat.values()):
            break

        count += 1

    x = repeat['x']
    y = repeat['y']
    z = repeat['z']

    ans = int(lcm(x, y, z))

    return ans
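The lcm(x, y, z) call above needs a three-argument lcm; math.lcm accepts multiple arguments only on Python 3.9+. A minimal fallback sketch for older interpreters, reducing pairwise via gcd (not part of the original listing):

from functools import reduce
from math import gcd

def lcm(*values):
    # lcm(a, b) == a * b // gcd(a, b); fold the identity across all arguments
    return reduce(lambda a, b: a * b // gcd(a, b), values)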
Example 7
def main():
    conf = get_args()
    dataset = MSCOCO(conf)
    VOC_SIZE = dataset.jp_voc_size if conf.use_lang == "jp" else dataset.en_voc_size
    SAMPLE_SIZE = conf.sample_size // conf.gpu_num if conf.gpu_num > 1 else conf.sample_size
    SEQ_LEN = conf.seq_len_jp if conf.use_lang == "jp" else conf.seq_len_en
    index2tok = dataset.jp_index2tok if conf.use_lang == "jp" else dataset.en_index2tok

    # save_path is also needed in silent mode (to look up pretrained
    # parameters below), so compute the paths unconditionally and only
    # create the directories when saving is enabled
    save_path = os.path.abspath(script_path + conf.save_path)
    if not conf.silent and not os.path.exists(save_path):
        os.mkdir(save_path)
    save_path = os.path.abspath(save_path + f"/{conf.use_lang}")
    if not conf.silent and not os.path.exists(save_path):
        os.mkdir(save_path)
    preview_path = os.path.abspath(save_path + "/preview")
    if not conf.silent and not os.path.exists(preview_path):
        os.mkdir(preview_path)

    netG, netD = build_models(conf, VOC_SIZE, SEQ_LEN)
    optimizerG, optimizerD = build_optimizer(netG, netD, conf.adam_lr,
                                             conf.adam_beta1, conf.adam_beta2)
    pprog = print_progress(conf.pre_gen_max_epoch, conf.batch_size,
                           dataset.train_data_len)

    updater = Updater(netG, netD, optimizerG, optimizerD, conf)

    def pretrain_generator():
        print("==========================================")
        print("Info:start generator pre train")
        pre_gen_loss_hist = np.zeros((1, conf.pre_gen_max_epoch),
                                     dtype="float32")
        for i in range(conf.pre_gen_max_epoch):
            count = 0
            total_loss = 0
            start = time.time()
            for data in dataset.get_data():
                data = toGPU(data, conf.gpu_num)
                loss = updater.update_pre_gen(data)

                total_loss += loss.data.cpu().numpy()

                count += 1
                if dataset.now_iter % conf.display_interval == 0:
                    elapsed = time.time() - start
                    pprog(elapsed, dataset.get_state)
                    start = time.time()

            pre_gen_loss_hist[0, i] = total_loss / count

        if not conf.silent:
            data = dataset.sample(conf.sample_size)
            sample_generate(netG, data, SAMPLE_SIZE, index2tok, conf.gpu_num,
                            conf.noise_dim, preview_path + "/sample_text_pretrain.txt")
            np.save(save_path + "/pre_gen_loss_hist", pre_gen_loss_hist)
            torch.save(netG.state_dict(), save_path + "/pretrain_gen_params")
        print("\n\n\n\n==========================================")

    def pretrain_discriminator():
        print("==========================================")
        print("Info:start discriminator pre train")
        dataset.clear_state()
        pprog.max_iter = conf.pre_dis_max_epoch
        pre_dis_hist = np.zeros((4, conf.pre_dis_max_epoch), dtype="float32")
        for i in range(conf.pre_dis_max_epoch):
            count = 0
            total_loss = 0
            total_real_acc = 0
            total_fake_acc = 0
            total_wrong_acc = 0
            start = time.time()
            for data in dataset.get_data():
                data = toGPU(data, conf.gpu_num)
                loss, real_acc, fake_acc, wrong_acc = updater.update_dis(data)

                total_loss += loss.data.cpu().numpy()
                total_real_acc += real_acc.data.cpu().numpy()
                total_fake_acc += fake_acc.data.cpu().numpy()
                total_wrong_acc += wrong_acc.data.cpu().numpy()

                count += 1
                if dataset.now_iter % conf.display_interval == 0:
                    elapsed = time.time() - start
                    pprog(elapsed, dataset.get_state)
                    start = time.time()

            pre_dis_hist[0, i] = total_loss / count
            pre_dis_hist[1, i] = total_real_acc / count
            pre_dis_hist[2, i] = total_fake_acc / count
            pre_dis_hist[3, i] = total_wrong_acc / count

        if not conf.silent:
            np.save(save_path + "/pre_dis_hist", pre_dis_hist)
            torch.save(netD.state_dict(), save_path + "/pretrain_dis_params")
        print("\n\n\n\n==========================================")

    if os.path.exists(save_path + "/pretrain_gen_params"):
        netG.load_state_dict(torch.load(save_path + "/pretrain_gen_params"))
    else:
        pretrain_generator()

    if os.path.exists(save_path + "/pretrain_dis_params"):
        netD.load_state_dict(torch.load(save_path + "/pretrain_dis_params"))
    else:
        pretrain_discriminator()

    print("==========================================")
    print("Info:start main train")
    dataset.clear_state()
    pprog.max_iter = conf.max_epoch
    train_loss_hist = np.zeros((5, conf.max_epoch), dtype="float32")
    val_loss_hist = np.zeros((5, conf.max_epoch), dtype="float32")
    val_count = dataset.val_data_len // conf.batch_size
    if dataset.val_data_len % conf.batch_size != 0:
        val_count += 1
    for i in range(conf.max_epoch):
        # train loop
        count = 1
        total_g_loss = 0
        total_d_loss = 0
        total_real_acc = 0
        total_fake_acc = 0
        total_wrong_acc = 0
        start = time.time()

        for p in netG.parameters():
            p.requires_grad = True
        for p in netD.parameters():
            p.requires_grad = True
        for data in dataset.get_data():
            data = toGPU(data, conf.gpu_num)

            if count % conf.n_dis == 0:
                loss = updater.update_PG(data)
                total_g_loss += loss.data.cpu().numpy()

            loss, real_acc, fake_acc, wrong_acc = updater.update_dis(data)

            total_d_loss += loss.data.cpu().numpy()
            total_real_acc += real_acc.data.cpu().numpy()
            total_fake_acc += fake_acc.data.cpu().numpy()
            total_wrong_acc += wrong_acc.data.cpu().numpy()

            count += 1
            if dataset.now_iter % conf.display_interval == 0:
                elapsed = time.time() - start
                pprog(elapsed, dataset.get_state)
                start = time.time()

        train_loss_hist[0, i] = total_d_loss / count
        train_loss_hist[1, i] = total_real_acc / count
        train_loss_hist[2, i] = total_fake_acc / count
        train_loss_hist[3, i] = total_wrong_acc / count
        # the generator is updated once every n_dis iterations
        train_loss_hist[4, i] = total_g_loss / (count // conf.n_dis)
        print("\n\n\n")

        # val loop
        print(f"Validation {i+1} / {conf.max_epoch}")
        count = 0
        total_g_loss = 0
        total_d_loss = 0
        total_real_acc = 0
        total_fake_acc = 0
        total_wrong_acc = 0
        start = time.time()
        for p in netG.parameters():
            p.requires_grad = False
        for p in netD.parameters():
            p.requires_grad = False
        for data in dataset.get_data(is_val=True):
            data = toGPU(data, conf.gpu_num)

            g_loss, d_loss, real_acc, fake_acc, wrong_acc = updater.evaluate(
                data)

            # accumulate so the epoch averages below are meaningful
            total_g_loss += g_loss.data.cpu().numpy()
            total_d_loss += d_loss.data.cpu().numpy()
            total_real_acc += real_acc.data.cpu().numpy()
            total_fake_acc += fake_acc.data.cpu().numpy()
            total_wrong_acc += wrong_acc.data.cpu().numpy()

            count += 1
            elapsed = time.time() - start
            if dataset.now_iter % conf.display_interval == 0:
                progress(count + 1, val_count, elapsed)

        progress(count, val_count, elapsed)
        val_loss_hist[0, i] = total_d_loss / count
        val_loss_hist[1, i] = total_real_acc / count
        val_loss_hist[2, i] = total_fake_acc / count
        val_loss_hist[3, i] = total_wrong_acc / count
        val_loss_hist[4, i] = total_g_loss / count  # evaluate returns g_loss every batch
        print("\u001B[5A", end="")

        if (i + 1) % conf.snapshot_interval == 0 and not conf.silent:
            data = dataset.sample(conf.sample_size)
            sample_generate(netG, data, SAMPLE_SIZE, index2tok, conf.gpu_num,
                            conf.noise_dim, preview_path + f"/sample_text_{i+1:04d}.txt")
            np.save(save_path + "/train_loss_hist", train_loss_hist)
            np.save(save_path + "/val_loss_hist", val_loss_hist)
            torch.save(netG.state_dict(),
                       save_path + f"/gen_params_{i+1:04d}.pth")
            torch.save(netD.state_dict(),
                       save_path + f"/dis_params_{i+1:04d}.pth")

    if not conf.silent:
        np.save(save_path + "/train_loss_hist", train_loss_hist)
        np.save(save_path + "/val_loss_hist", val_loss_hist)
        data = dataset.sample(conf.sample_size)
        sample_generate(netG, data, SAMPLE_SIZE, index2tok, conf.gpu_num,
                        conf.noise_dim, preview_path + "/sample_text.txt")
        torch.save(netG.state_dict(), save_path + "/gen_params.pth")
        torch.save(netD.state_dict(), save_path + "/dis_params.pth")
    print("\n\n\n\n==========================================")
    print("Info:finish train")
Example 8
def main():
    conf = get_args()
    dataset = MSCOCO(conf)
    VOC_SIZE = dataset.jp_voc_size if conf.use_lang == "jp" else dataset.en_voc_size
    SAMPLE_SIZE = conf.sample_size // conf.gpu_num if conf.gpu_num > 1 else conf.sample_size
    SEQ_LEN = conf.seq_len_jp if conf.use_lang == "jp" else conf.seq_len_en
    index2tok = dataset.jp_index2tok if conf.use_lang == "jp" else dataset.en_index2tok

    if not conf.silent:
        save_path = os.path.abspath(script_path + conf.save_path)
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        save_path = os.path.abspath(save_path + f"/{conf.use_lang}")
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        preview_path = os.path.abspath(save_path + "/preview")
        if not os.path.exists(preview_path):
            os.mkdir(preview_path)

    netG, netD = build_models(conf, VOC_SIZE, SEQ_LEN)
    optimizerG, optimizerD = build_optimizer(netG, netD, conf.adam_lr,
                                             conf.adam_beta1, conf.adam_beta2)
    pprog = print_progress(conf.pre_gen_max_epoch, conf.batch_size,
                           dataset.train_data_len)

    updater = Updater(netG, netD, optimizerG, optimizerD, conf)
    """
    print("==========================================")
    print("Info:start genarator pre train")
    pre_gen_loss_hist = np.zeros((1, conf.pre_gen_max_epoch), dtype="float32")
    for i in range(conf.pre_gen_max_epoch):
        count = 0
        total_loss = np.array([0.], dtype="float32")
        start = time.time()
        for data in dataset.get_data():
            break
            data = toGPU(data, conf.gpu_num)
            loss = updater.update_pre_gen(data)

            total_loss += loss.data.cpu().numpy()

            count += 1
            if dataset.now_iter % conf.display_interval == 0:
                elapsed = time.time() - start
                pprog(elapsed, dataset.get_state)
                start = time.time()

        pre_gen_loss_hist[0,i] = total_loss / count

    if not conf.silent:
        sample_generate(netG, SAMPLE_SIZE, index2tok, preview_path + f"/sample_text_pretrain.txt")
        np.save(save_path + "/pre_gen_loss_hist", pre_gen_loss_hist)
        torch.save(netG.state_dict(), save_path + "/pretrain_gen_params")
    print("\n\n\n\n==========================================")


    print("==========================================")
    print("Info:start discriminator pre train")
    dataset.clear_state()
    pprog.max_iter = conf.pre_dis_max_epoch
    pre_dis_hist = np.zeros((2, conf.pre_dis_max_epoch), dtype="float32")
    for i in range(conf.pre_dis_max_epoch):
        count = 0
        total_loss = np.array([0.], dtype="float32")
        total_acc = np.array([0.], dtype="float32")
        start = time.time()
        for data in dataset.get_data():
            data = toGPU(data, conf.gpu_num)
            loss, acc = updater.update_dis(data)

            total_loss += loss.data.cpu().numpy()
            total_acc += acc.data.cpu().numpy()

            count += 1
            if dataset.now_iter % conf.display_interval == 0:
                elapsed = time.time() - start
                pprog(elapsed, dataset.get_state)
                start = time.time()

        pre_dis_hist[0,i] = total_loss / count
        pre_dis_hist[1,i] = total_acc / count

    if not conf.silent:
        np.save(save_path + "/pre_dis_hist", pre_dis_hist)
        torch.save(netD.state_dict(), save_path + "/pretrain_dis_params")
    print("\n\n\n\n==========================================")
    """

    print("==========================================")
    print("Info:start main train")
    dataset.clear_state()
    pprog.max_iter = conf.max_epoch
    loss_hist = np.zeros((3, conf.max_epoch), dtype="float32")
    for i in range(conf.max_epoch):
        count = 0
        total_g_loss = np.array([0.], dtype="float32")
        total_d_loss = np.array([0.], dtype="float32")
        total_acc = np.array([0.], dtype="float32")
        start = time.time()
        for data in dataset.get_data():
            data = toGPU(data, conf.gpu_num)

            if count % conf.n_dis == 0:
                loss = updater.update_PG(data)
                total_g_loss += loss.data.cpu().numpy()

            loss, acc = updater.update_dis(data)
            total_d_loss += loss.data.cpu().numpy()
            total_acc += acc.data.cpu().numpy()  # was accumulating loss twice

            count += 1
            if dataset.now_iter % conf.display_interval == 0:
                elapsed = time.time() - start
                pprog(elapsed, dataset.get_state)
                start = time.time()

        loss_hist[0, i] = total_d_loss / count
        loss_hist[1, i] = total_acc / count
        loss_hist[2, i] = total_g_loss / (count // conf.n_dis)

        if i % conf.snapshot_interval == 0 and not conf.silent:
            sample_generate(netG, SAMPLE_SIZE, index2tok,
                            preview_path + f"/sample_text_{i:04d}.txt")
            np.save(save_path + "/loss_hist", loss_hist)
            torch.save(netG.state_dict(),
                       save_path + f"/gen_params_{i:04d}.pth")
            torch.save(netD.state_dict(),
                       save_path + f"/dis_params_{i:04d}.pth")

    if not conf.silent:
        np.save(save_path + "/loss_hist", loss_hist)
        sample_generate(netG, SAMPLE_SIZE, index2tok,
                        preview_path + "/sample_text.txt")
        torch.save(netG.state_dict(), save_path + "/gen_params.pth")
        torch.save(netD.state_dict(), save_path + "/dis_params.pth")
    print("\n\n\n\n==========================================")
    print("Info:finish train")