Example #1
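Training script for a recurrent neural network on a frequency-comparison task (inferred from FreqDataset and the two-class readout): hyper-parameters come from a YAML config, the model is trained with Adam plus an activity-norm penalty, and checkpoints are written to trained_model/freq/. The snippets below assume imports along these lines; the project-local module paths are assumptions based on the class names, not confirmed by the source:

import os
import shutil

import numpy as np
import torch
import torch.optim as optim
import yaml

from model import RecurrentNeuralNetwork  # assumed local module path
from dataset import FreqDataset           # assumed local module path
from fixed_point import FixedPoint        # assumed; used in Example #3
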
def main(config_path):
    # load hyper-parameters from the YAML config
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)

    model_name = os.path.splitext(os.path.basename(config_path))[0]

    # save path
    os.makedirs('trained_model', exist_ok=True)
    os.makedirs('trained_model/freq', exist_ok=True)
    save_path = f'trained_model/freq/{model_name}'
    os.makedirs(save_path, exist_ok=True)

    # copy config file
    shutil.copyfile(config_path,
                    os.path.join(save_path, os.path.basename(config_path)))

    use_cuda = cfg['MACHINE']['CUDA'] and torch.cuda.is_available()
    torch.manual_seed(cfg['MACHINE']['SEED'])
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(device)

    if 'ALPHA' not in cfg['MODEL'].keys():
        cfg['MODEL']['ALPHA'] = 0.25

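    # RNN constructor arguments (interpretation inferred from the names, not
    # confirmed by the source): alpha/beta look like neuronal and synaptic time
    # scales, sigma_neu/sigma_syn the corresponding noise levels, with optional
    # bias and anti-Hebbian synaptic dynamics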
    model = RecurrentNeuralNetwork(
        n_in=1,
        n_out=2,
        n_hid=cfg['MODEL']['SIZE'],
        device=device,
        alpha_time_scale=cfg['MODEL']['ALPHA'],
        beta_time_scale=cfg['MODEL']['BETA'],
        activation=cfg['MODEL']['ACTIVATION'],
        sigma_neu=cfg['MODEL']['SIGMA_NEU'],
        sigma_syn=cfg['MODEL']['SIGMA_SYN'],
        use_bias=cfg['MODEL']['USE_BIAS'],
        anti_hebbian=cfg['MODEL']['ANTI_HEBB']).to(device)

    train_dataset = FreqDataset(
        time_length=cfg['DATALOADER']['TIME_LENGTH'],
        time_scale=cfg['MODEL']['ALPHA'],
        freq_min=cfg['DATALOADER']['FREQ_MIN'],
        freq_max=cfg['DATALOADER']['FREQ_MAX'],
        min_interval=cfg['DATALOADER']['MIN_INTERVAL'],
        signal_length=cfg['DATALOADER']['SIGNAL_LENGTH'],
        variable_signal_length=cfg['DATALOADER']['VARIABLE_SIGNAL_LENGTH'],
        sigma_in=cfg['DATALOADER']['SIGMA_IN'],
        delay_variable=cfg['DATALOADER']['VARIABLE_DELAY'])

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg['TRAIN']['BATCHSIZE'],
        num_workers=2,
        shuffle=True,
        worker_init_fn=lambda x: np.random.seed())

    print(model)
    print('Epoch, Loss, Acc')

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=cfg['TRAIN']['LR'],
                           weight_decay=cfg['TRAIN']['WEIGHT_DECAY'])
    correct = 0
    num_data = 0
    for epoch in range(cfg['TRAIN']['NUM_EPOCH'] + 1):
        model.train()
        for i, data in enumerate(train_dataloader):
            inputs, target = data
            # print(inputs.shape)
            inputs = inputs.float().to(device)
            target = target.long().to(device)

            hidden = torch.zeros(cfg['TRAIN']['BATCHSIZE'],
                                 cfg['MODEL']['SIZE'])
            hidden = hidden.to(device)

            optimizer.zero_grad()
            hidden = hidden.detach()
            hidden_list, output, hidden, new_j = model(inputs, hidden)
            # print(output)

            # cross-entropy on the readout at the final time step
            loss = torch.nn.CrossEntropyLoss()(output[:, -1], target)
            # activity regularization: MSE against an all-zero target, i.e. the
            # mean squared hidden activity over the whole trial
            dummy_zero = torch.zeros([
                cfg['TRAIN']['BATCHSIZE'],
                cfg['DATALOADER']['TIME_LENGTH'] + 1, cfg['MODEL']['SIZE']
            ]).float().to(device)
            active_norm = torch.nn.MSELoss()(hidden_list, dummy_zero)

            loss += cfg['TRAIN']['ACTIVATION_LAMBDA'] * active_norm
            loss.backward()
            optimizer.step()
            correct += (np.argmax(
                output[:, -1].cpu().detach().numpy(),
                axis=1) == target.cpu().detach().numpy()).sum().item()
            num_data += target.cpu().detach().numpy().shape[0]

        if epoch % cfg['TRAIN']['DISPLAY_EPOCH'] == 0:
            acc = correct / num_data
            print(f'{epoch}, {loss.item():.6f}, {acc:.6f}')
            print(f'activation norm: {active_norm.item():.6f}')
            # print('w_hh: ', model.w_hh.weight.cpu().detach().numpy()[:4, :4])
            # print('new_j: ', new_j.cpu().detach().numpy()[0, :4, :4])
            correct = 0
            num_data = 0
        if epoch > 0 and epoch % cfg['TRAIN']['NUM_SAVE_EPOCH'] == 0:
            torch.save(model.state_dict(),
                       os.path.join(save_path, f'epoch_{epoch}.pth'))
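
For reference, a config sketch covering the keys these scripts read; all values are illustrative assumptions, not settings from the source:

MACHINE:
  CUDA: true
  SEED: 0
MODEL:
  SIZE: 256
  ALPHA: 0.25
  BETA: 0.1
  ACTIVATION: tanh
  SIGMA_NEU: 0.05
  SIGMA_SYN: 0.002
  USE_BIAS: true
  ANTI_HEBB: true
DATALOADER:
  TIME_LENGTH: 80
  FREQ_MIN: 10
  FREQ_MAX: 30
  MIN_INTERVAL: 5
  SIGNAL_LENGTH: 15
  VARIABLE_SIGNAL_LENGTH: 0
  SIGMA_IN: 0.05
  VARIABLE_DELAY: 5
TRAIN:
  BATCHSIZE: 50
  LR: 0.001
  WEIGHT_DECAY: 0.0001
  NUM_EPOCH: 2000
  DISPLAY_EPOCH: 10
  NUM_SAVE_EPOCH: 500
  ACTIVATION_LAMBDA: 0.01
  # PHASE_TRANSIT: [0.5, 0.45, 0.4, 0.3, 0.2]  # optional; used by Example #4

A minimal command-line entry point for main (the argument handling is a sketch, not from the source):

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('config_path')
    args = parser.parse_args()
    main(args.config_path)
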
Example #2
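A DEEP R-style variant of the training loop (the save path freq_deepr suggests the deep-rewiring method): optimization is a hand-written noisy SGD step rather than Adam, and the recurrent connectivity is kept sparse by pruning connections whose parameter theta turns non-positive and regrowing an equal number of dormant ones at a small magnitude. Checkpoints plus the sign, connectivity-mask, and magnitude arrays are saved under trained_model/freq_deepr/.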
def main(config_path):
    # load hyper-parameters from the YAML config
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)

    model_name = os.path.splitext(os.path.basename(config_path))[0]

    # save path
    os.makedirs('trained_model', exist_ok=True)
    os.makedirs('trained_model/freq_deepr', exist_ok=True)
    save_path = f'trained_model/freq_deepr/{model_name}'
    os.makedirs(save_path, exist_ok=True)

    # copy config file
    shutil.copyfile(config_path,
                    os.path.join(save_path, os.path.basename(config_path)))

    use_cuda = cfg['MACHINE']['CUDA'] and torch.cuda.is_available()
    torch.manual_seed(cfg['MACHINE']['SEED'])
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(device)

    if 'ALPHA' not in cfg['MODEL'].keys():
        cfg['MODEL']['ALPHA'] = 0.25

    model = RecurrentNeuralNetwork(
        n_in=1,
        n_out=2,
        n_hid=cfg['MODEL']['SIZE'],
        device=device,
        alpha_time_scale=cfg['MODEL']['ALPHA'],
        activation=cfg['MODEL']['ACTIVATION'],
        sigma_neu=cfg['MODEL']['SIGMA_NEU'],
    ).to(device)

    train_dataset = FreqDataset(
        time_length=cfg['DATALOADER']['TIME_LENGTH'],
        time_scale=cfg['MODEL']['ALPHA'],
        freq_min=cfg['DATALOADER']['FREQ_MIN'],
        freq_max=cfg['DATALOADER']['FREQ_MAX'],
        min_interval=cfg['DATALOADER']['MIN_INTERVAL'],
        signal_length=cfg['DATALOADER']['SIGNAL_LENGTH'],
        variable_signal_length=cfg['DATALOADER']['VARIABLE_SIGNAL_LENGTH'],
        sigma_in=cfg['DATALOADER']['SIGMA_IN'],
        delay_variable=cfg['DATALOADER']['VARIABLE_DELAY'],
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg['TRAIN']['BATCHSIZE'],
        num_workers=2,
        shuffle=True,
        worker_init_fn=lambda x: np.random.seed(),
    )

    print(model)
    # print(model.state_dict())
    print('Epoch, Loss, Acc')

    correct = 0
    num_data = 0
    num_connection = 13000  # target number of active recurrent connections
    b = 1  # scales the exploratory noise in the manual weight update
    flag1 = True  # one-shot triggers for the schedule in the display block
    flag2 = True
    for epoch in range(cfg['TRAIN']['NUM_EPOCH'] + 1):
        model.train()
        for i, data in enumerate(train_dataloader):
            # print(i)
            inputs, target = data
            inputs = inputs.float().to(device)
            target = target.long().to(device)

            hidden = torch.zeros(cfg['TRAIN']['BATCHSIZE'],
                                 cfg['MODEL']['SIZE'])
            hidden = hidden.to(device)

            hidden = hidden.detach()
            hidden_list, output, hidden = model(inputs, hidden)

            loss = torch.nn.CrossEntropyLoss()(output[:, -1], target)
            # activity regularization, as in Example #1
            dummy_zero = torch.zeros([
                cfg['TRAIN']['BATCHSIZE'],
                cfg['DATALOADER']['TIME_LENGTH'] + 1, cfg['MODEL']['SIZE']
            ]).float().to(device)
            active_norm = torch.nn.MSELoss()(hidden_list, dummy_zero)

            loss += cfg['TRAIN']['ACTIVATION_LAMBDA'] * active_norm
            model.zero_grad()  # without this, gradients accumulate across batches
            loss.backward()
            # hand-rolled SGD step over all parameters
            for param in model.parameters():
                if param.grad is not None:
                    param.data -= cfg['TRAIN']['LR'] * param.grad.data
            # extra update for the weight magnitudes: a second gradient step
            # plus exploratory noise (scaled by b) and a weight-decay term
            model.abs_w_0.data = (
                model.abs_w_0.data
                - cfg['TRAIN']['LR'] * model.abs_w_0.grad.data
                + torch.randn_like(model.abs_w_0) * (cfg['TRAIN']['LR'] * 0.1 * b)
                - cfg['TRAIN']['LR'] * model.abs_w_0.data * (cfg['TRAIN']['LR'] * 5))
            correct += (np.argmax(
                output[:, -1].cpu().detach().numpy(),
                axis=1) == target.cpu().detach().numpy()).sum().item()
            num_data += target.cpu().detach().numpy().shape[0]
            # DEEP R-style rewiring: prune connections whose theta has gone
            # non-positive, then regrow the same number of dormant connections
            num_reconnect = num_connection - np.count_nonzero(
                model.theta.detach().cpu().numpy() > 0)
            if num_reconnect > 0:
                nonzero_index = model.theta.detach().cpu().numpy() > 0
                # drop pruned connections from the connectivity mask
                model.tensor_is_con_0 *= torch.from_numpy(
                    nonzero_index.astype(np.float32))
                # reactivate random currently-inactive connections at a tiny weight
                candidate_connection = list(
                    zip(*np.where(model.theta.detach().cpu().numpy() <= 0)))
                np.random.shuffle(candidate_connection)
                for j in range(num_reconnect):
                    model.tensor_is_con_0[candidate_connection[j]] = 1
                    model.abs_w_0.data[candidate_connection[j]] = 1e-5

        if epoch % cfg['TRAIN']['DISPLAY_EPOCH'] == 0:
            print(model.w_hh.data[:10, :10])
            acc = correct / num_data
            print(f'{epoch}, {loss.item():.6f}, {acc:.6f}')
            print(f'activation norm: {active_norm.item():.6f}')
            if active_norm.item() > 0.05 and flag1:
                # first crossing: damp the learning rate, strengthen the penalty
                cfg['TRAIN']['LR'] *= 0.1
                cfg['TRAIN']['ACTIVATION_LAMBDA'] *= 5
                flag1 = False
            if active_norm.item() > 0.1 and flag2:
                cfg['TRAIN']['LR'] *= 0.1
                cfg['TRAIN']['ACTIVATION_LAMBDA'] *= 5
                b = 0  # switch off the exploratory noise in the weight update
                flag2 = False
            correct = 0
            num_data = 0
        if epoch > 0 and epoch % cfg['TRAIN']['NUM_SAVE_EPOCH'] == 0:
            torch.save(model.state_dict(),
                       os.path.join(save_path, f'epoch_{epoch}.pth'))
            np.save(os.path.join(save_path, f'w_sign_{epoch}.npy'),
                    model.w_sign.detach().cpu().numpy())
            np.save(os.path.join(save_path, f'tensor_is_con_{epoch}.npy'),
                    model.tensor_is_con_0.detach().cpu().numpy())
            np.save(os.path.join(save_path, f'abs_w_0_{epoch}.npy'),
                    model.abs_w_0.detach().cpu().numpy())
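
If the effective recurrent weight is composed elementwise as sign × connectivity mask × magnitude (an assumption suggested by the three arrays saved above, not confirmed by the source), a checkpointed sparse matrix can be rebuilt from the saved files:

epoch = 100  # hypothetical checkpoint epoch
w_sign = np.load(os.path.join(save_path, f'w_sign_{epoch}.npy'))
is_con = np.load(os.path.join(save_path, f'tensor_is_con_{epoch}.npy'))
abs_w = np.load(os.path.join(save_path, f'abs_w_0_{epoch}.npy'))
w_rec = w_sign * is_con * abs_w  # zero wherever the connection mask is off
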
Example #3
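Fixed-point analysis of a trained model: the script reloads a checkpoint, drives the network with evaluation trials, and uses a FixedPoint helper to search for fixed points of the hidden dynamics starting from states sampled mid-trial; each result is written to fixed_points/freq/.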
def main(config_path, model_epoch):
    # load hyper-parameters from the YAML config
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)

    model_name = os.path.splitext(os.path.basename(config_path))[0]

    # save path
    os.makedirs('fixed_points', exist_ok=True)
    os.makedirs('fixed_points/freq', exist_ok=True)
    save_path = f'fixed_points/freq/{model_name}_{model_epoch}'
    os.makedirs(save_path, exist_ok=True)

    # copy config file
    shutil.copyfile(config_path,
                    os.path.join(save_path, os.path.basename(config_path)))

    use_cuda = cfg['MACHINE']['CUDA'] and torch.cuda.is_available()
    torch.manual_seed(cfg['MACHINE']['SEED'])
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(device)

    if 'ALPHA' not in cfg['MODEL'].keys():
        cfg['MODEL']['ALPHA'] = 0.25

    # cfg['DATALOADER']['TIME_LENGTH'] = 200
    # cfg['DATALOADER']['SIGNAL_LENGTH'] = 50
    cfg['DATALOADER']['VARIABLE_DELAY'] = 0  # evaluate with a fixed delay

    # model load
    model = RecurrentNeuralNetwork(
        n_in=1,
        n_out=2,
        n_hid=cfg['MODEL']['SIZE'],
        device=device,
        alpha_time_scale=cfg['MODEL']['ALPHA'],
        beta_time_scale=cfg['MODEL']['BETA'],
        activation=cfg['MODEL']['ACTIVATION'],
        sigma_neu=cfg['MODEL']['SIGMA_NEU'],
        sigma_syn=cfg['MODEL']['SIGMA_SYN'],
        use_bias=cfg['MODEL']['USE_BIAS'],
        anti_hebbian=cfg['MODEL']['ANTI_HEBB']).to(device)

    model_path = f'trained_model/freq/{model_name}/epoch_{model_epoch}.pth'
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    eval_dataset = FreqDataset(
        time_length=cfg['DATALOADER']['TIME_LENGTH'],
        time_scale=cfg['MODEL']['ALPHA'],
        freq_min=cfg['DATALOADER']['FREQ_MIN'],
        freq_max=cfg['DATALOADER']['FREQ_MAX'],
        min_interval=cfg['DATALOADER']['MIN_INTERVAL'],
        signal_length=cfg['DATALOADER']['SIGNAL_LENGTH'],
        variable_signal_length=cfg['DATALOADER']['VARIABLE_SIGNAL_LENGTH'],
        sigma_in=cfg['DATALOADER']['SIGMA_IN'],
        delay_variable=cfg['DATALOADER']['VARIABLE_DELAY'])

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=cfg['TRAIN']['BATCHSIZE'],
        num_workers=2,
        shuffle=True,
        worker_init_fn=lambda x: np.random.seed())

    analyzer = FixedPoint(model=model,
                          device=device,
                          alpha=cfg['MODEL']['ALPHA'],
                          max_epochs=140000)

    for trial in range(50):
        for i, data in enumerate(eval_dataloader):
            inputs, target = data
            # print(inputs.shape)
            inputs = inputs.float().to(device)
            target = target.long().to(device)

            hidden = torch.zeros(cfg['TRAIN']['BATCHSIZE'],
                                 cfg['MODEL']['SIZE'])
            hidden = hidden.to(device)

            hidden = hidden.detach()
            hidden_list, output, hidden, _ = model(inputs, hidden)

            # const_signal = torch.tensor([0] * 1)
            # const_signal = const_signal.float().to(device)

            # start the search from a hidden state sampled mid-trial
            reference_time_point = np.random.randint(35, 55)
            fixed_point, result_ok = analyzer.find_fixed_point(
                hidden_list[0, reference_time_point], view=True)

            fixed_point = fixed_point.detach().cpu().numpy()

            # print(fixed_point)
            # fixed_point_tensor = torch.from_numpy(fixed_point).float()
            # jacobian = analyzer.calc_jacobian(fixed_point_tensor, const_signal)

            # print(np.dot(model.w_out.weight.detach().cpu().numpy(), fixed_point))

            # w, v = np.linalg.eig(jacobian)
            # print('eigenvalues', w)

            np.savetxt(os.path.join(save_path, f'fixed_point_{trial}_{i}.txt'),
                       fixed_point)
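
The saved fixed points can be read back with np.loadtxt for further analysis, for example:

fixed_point = np.loadtxt(os.path.join(save_path, 'fixed_point_0_0.txt'))

Example #4

Curriculum-style training saved under trained_model/freq_schedule2/: whenever the loss falls below the next threshold, the network time constant alpha is lowered and the trials are made longer, over up to five phases.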
def main(config_path):
    # load hyper-parameters from the YAML config
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)

    model_name = os.path.splitext(os.path.basename(config_path))[0]

    # save path
    os.makedirs('trained_model', exist_ok=True)
    os.makedirs('trained_model/freq_schedule2', exist_ok=True)
    save_path = f'trained_model/freq_schedule2/{model_name}'
    os.makedirs(save_path, exist_ok=True)

    # copy config file
    shutil.copyfile(config_path,
                    os.path.join(save_path, os.path.basename(config_path)))

    use_cuda = cfg['MACHINE']['CUDA'] and torch.cuda.is_available()
    torch.manual_seed(cfg['MACHINE']['SEED'])
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(device)

    if 'ALPHA' not in cfg['MODEL'].keys():
        cfg['MODEL']['ALPHA'] = 0.25

    model = RecurrentNeuralNetwork(
        n_in=1,
        n_out=2,
        n_hid=cfg['MODEL']['SIZE'],
        device=device,
        alpha_time_scale=cfg['MODEL']['ALPHA'],
        beta_time_scale=cfg['MODEL']['BETA'],
        activation=cfg['MODEL']['ACTIVATION'],
        sigma_neu=cfg['MODEL']['SIGMA_NEU'],
        sigma_syn=cfg['MODEL']['SIGMA_SYN'],
        use_bias=cfg['MODEL']['USE_BIAS'],
        anti_hebbian=cfg['MODEL']['ANTI_HEBB']).to(device)

    train_dataset = FreqDataset(
        time_length=cfg['DATALOADER']['TIME_LENGTH'],
        time_scale=cfg['MODEL']['ALPHA'],
        freq_min=cfg['DATALOADER']['FREQ_MIN'],
        freq_max=cfg['DATALOADER']['FREQ_MAX'],
        min_interval=cfg['DATALOADER']['MIN_INTERVAL'],
        signal_length=cfg['DATALOADER']['SIGNAL_LENGTH'],
        variable_signal_length=cfg['DATALOADER']['VARIABLE_SIGNAL_LENGTH'],
        sigma_in=cfg['DATALOADER']['SIGMA_IN'],
        delay_variable=cfg['DATALOADER']['VARIABLE_DELAY'])

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg['TRAIN']['BATCHSIZE'],
        num_workers=2,
        shuffle=True,
        worker_init_fn=lambda x: np.random.seed())

    print(model)
    print('Epoch, Loss, Acc')

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=cfg['TRAIN']['LR'],
                           weight_decay=cfg['TRAIN']['WEIGHT_DECAY'])
    correct = 0
    num_data = 0
    # curriculum phases: (alpha, time_length, signal_length, variable_delay)
    phase_schedule = [
        (0.2, 95, 20, 6),
        (0.175, 110, 22, 7),
        (0.15, 120, 25, 8),
        (0.10, 160, 30, 10),
        (0.075, 267, 50, 15),
    ]
    next_phase = 0
    if 'PHASE_TRANSIT' in cfg['TRAIN'].keys():
        phase_transition_criteria = cfg['TRAIN']['PHASE_TRANSIT']
    else:
        phase_transition_criteria = [0.5, 0.45, 0.4, 0.3, 0.2]
    for epoch in range(cfg['TRAIN']['NUM_EPOCH'] + 1):
        model.train()
        for i, data in enumerate(train_dataloader):
            inputs, target = data
            # print(inputs.shape)
            inputs = inputs.float().to(device)
            target = target.long().to(device)

            hidden = torch.zeros(cfg['TRAIN']['BATCHSIZE'],
                                 cfg['MODEL']['SIZE'])
            hidden = hidden.to(device)

            optimizer.zero_grad()
            hidden = hidden.detach()
            hidden_list, output, hidden, new_j = model(inputs, hidden)
            # print(output)

            loss = torch.nn.CrossEntropyLoss()(output[:, -1], target)
            # activity regularization restricted to the delay period between
            # the two signals
            dummy_zero = torch.zeros([
                cfg['TRAIN']['BATCHSIZE'],
                int(cfg['DATALOADER']['TIME_LENGTH'] -
                    2 * cfg['DATALOADER']['SIGNAL_LENGTH']),
                cfg['MODEL']['SIZE']
            ]).float().to(device)
            active_norm = torch.nn.MSELoss()(
                hidden_list[:, cfg['DATALOADER']['SIGNAL_LENGTH']:
                            cfg['DATALOADER']['TIME_LENGTH'] -
                            cfg['DATALOADER']['SIGNAL_LENGTH'], :], dummy_zero)

            loss += cfg['TRAIN']['ACTIVATION_LAMBDA'] * active_norm
            loss.backward()
            optimizer.step()
            correct += (np.argmax(
                output[:, -1].cpu().detach().numpy(),
                axis=1) == target.cpu().detach().numpy()).sum().item()
            num_data += target.cpu().detach().numpy().shape[0]

            # curriculum: when the loss falls below the next criterion, slow
            # the network time constant and lengthen the trials, then rebuild
            # the dataset and restart the epoch
            if next_phase < len(phase_schedule) and \
                    loss.item() < phase_transition_criteria[next_phase]:
                alpha, time_length, signal_length, delay = \
                    phase_schedule[next_phase]
                cfg['MODEL']['ALPHA'] = alpha
                cfg['DATALOADER']['TIME_LENGTH'] = time_length
                cfg['DATALOADER']['SIGNAL_LENGTH'] = signal_length
                cfg['DATALOADER']['VARIABLE_DELAY'] = delay
                next_phase += 1

                print(f"phase{next_phase + 1} start! "
                      f"cfg['MODEL']['ALPHA'] = {alpha}")
                model.change_alpha(alpha)
                train_dataset = FreqDataset(
                    time_length=time_length,
                    time_scale=alpha,
                    freq_min=cfg['DATALOADER']['FREQ_MIN'],
                    freq_max=cfg['DATALOADER']['FREQ_MAX'],
                    min_interval=cfg['DATALOADER']['MIN_INTERVAL'],
                    signal_length=signal_length,
                    variable_signal_length=cfg['DATALOADER']
                    ['VARIABLE_SIGNAL_LENGTH'],
                    sigma_in=cfg['DATALOADER']['SIGMA_IN'],
                    delay_variable=delay)
                train_dataloader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=cfg['TRAIN']['BATCHSIZE'],
                    num_workers=2,
                    shuffle=True,
                    worker_init_fn=lambda x: np.random.seed())
                break

        if epoch % cfg['TRAIN']['DISPLAY_EPOCH'] == 0:
            acc = correct / num_data
            print(f'{epoch}, {loss.item():.6f}, {acc:.6f}')
            print(f'activation norm: {active_norm.item():.4f}, '
                  f'time scale: {model.alpha.detach().cpu().numpy()[0]:.3f}')
            correct = 0
            num_data = 0
        if epoch > 0 and epoch % cfg['TRAIN']['NUM_SAVE_EPOCH'] == 0:
            torch.save(model.state_dict(),
                       os.path.join(save_path, f'epoch_{epoch}.pth'))