def main(activation):
    """Train a 200-unit RNN to reproduce sine waves and checkpoint it.

    Checkpoints (``state_dict``) are written every 200 epochs to
    ``trained_model/{activation}/epoch_{epoch}.pth``.

    Args:
        activation: Activation name forwarded to ``RecurrentNeuralNetwork``;
            also used as the checkpoint sub-directory name.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    # makedirs also creates the parent 'trained_model' directory.
    save_path = f'trained_model/{activation}'
    os.makedirs(save_path, exist_ok=True)

    n_hid = 200
    model = RecurrentNeuralNetwork(n_in=1,
                                   n_out=1,
                                   n_hid=n_hid,
                                   device=device,
                                   activation=activation,
                                   sigma=0,
                                   use_bias=True).to(device)

    train_dataset = SineWave(freq_range=51, time_length=40)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=50,
        num_workers=2,
        shuffle=True,
        # Reseed NumPy's global RNG in each worker process (from OS entropy).
        worker_init_fn=lambda x: np.random.seed())

    print(model)

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=0.001,
                           weight_decay=0.0001)
    criterion = torch.nn.MSELoss()

    for epoch in range(2001):
        model.train()
        for inputs, target in train_dataloader:
            inputs = inputs.float().to(device)
            target = target.float().to(device)

            # Size the initial state from the actual batch: DataLoader keeps a
            # smaller final batch (drop_last=False) when the dataset length
            # is not a multiple of 50.
            hidden = torch.zeros(inputs.size(0), n_hid, device=device)

            optimizer.zero_grad()
            hidden_list, output, hidden = model(inputs, hidden)

            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        if epoch > 0 and epoch % 200 == 0:
            # Loss/output/target shown are those of the epoch's last batch.
            print(f'Train Epoch: {epoch}, Loss: {loss.item():.6f}')
            print('output', output[0, :, 0].cpu().detach().numpy())
            print('target', target[0, :, 0].cpu().detach().numpy())
            torch.save(model.state_dict(),
                       os.path.join(save_path, f'epoch_{epoch}.pth'))
Example #2
0
def main(config_path):
    """Train the two-output frequency RNN described by a YAML config.

    The config is copied next to the checkpoints in
    ``trained_model/freq/<config file stem>``. Training minimises
    cross-entropy on the last output step plus ``TRAIN.ACTIVATION_LAMBDA``
    times the mean squared hidden activity, and a ``state_dict`` is saved
    every ``TRAIN.NUM_SAVE_EPOCH`` epochs (epoch 0 excluded).

    Args:
        config_path: Path to the YAML configuration file.
    """
    # hyper-parameter
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)

    model_name = os.path.splitext(os.path.basename(config_path))[0]

    # save path (makedirs also creates the intermediate directories)
    save_path = f'trained_model/freq/{model_name}'
    os.makedirs(save_path, exist_ok=True)

    # copy config file so the checkpoint directory records its settings
    shutil.copyfile(config_path,
                    os.path.join(save_path, os.path.basename(config_path)))

    use_cuda = cfg['MACHINE']['CUDA'] and torch.cuda.is_available()
    torch.manual_seed(cfg['MACHINE']['SEED'])
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(device)

    cfg['MODEL'].setdefault('ALPHA', 0.25)

    model = RecurrentNeuralNetwork(
        n_in=1,
        n_out=2,
        n_hid=cfg['MODEL']['SIZE'],
        device=device,
        alpha_time_scale=cfg['MODEL']['ALPHA'],
        beta_time_scale=cfg['MODEL']['BETA'],
        activation=cfg['MODEL']['ACTIVATION'],
        sigma_neu=cfg['MODEL']['SIGMA_NEU'],
        sigma_syn=cfg['MODEL']['SIGMA_SYN'],
        use_bias=cfg['MODEL']['USE_BIAS'],
        anti_hebbian=cfg['MODEL']['ANTI_HEBB']).to(device)

    train_dataset = FreqDataset(
        time_length=cfg['DATALOADER']['TIME_LENGTH'],
        time_scale=cfg['MODEL']['ALPHA'],
        freq_min=cfg['DATALOADER']['FREQ_MIN'],
        freq_max=cfg['DATALOADER']['FREQ_MAX'],
        min_interval=cfg['DATALOADER']['MIN_INTERVAL'],
        signal_length=cfg['DATALOADER']['SIGNAL_LENGTH'],
        variable_signal_length=cfg['DATALOADER']['VARIABLE_SIGNAL_LENGTH'],
        sigma_in=cfg['DATALOADER']['SIGMA_IN'],
        delay_variable=cfg['DATALOADER']['VARIABLE_DELAY'])

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg['TRAIN']['BATCHSIZE'],
        num_workers=2,
        shuffle=True,
        # Reseed NumPy's global RNG in each worker process (from OS entropy).
        worker_init_fn=lambda x: np.random.seed())

    print(model)
    print('Epoch Loss Acc')

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=cfg['TRAIN']['LR'],
                           weight_decay=cfg['TRAIN']['WEIGHT_DECAY'])
    ce_loss = torch.nn.CrossEntropyLoss()
    mse_loss = torch.nn.MSELoss()

    correct = 0
    num_data = 0
    for epoch in range(cfg['TRAIN']['NUM_EPOCH'] + 1):
        model.train()
        for inputs, target in train_dataloader:
            inputs = inputs.float().to(device)
            target = target.long().to(device)

            # Size the initial state from the actual batch: the final batch
            # is smaller than BATCHSIZE unless the dataset length divides it.
            hidden = torch.zeros(inputs.size(0), cfg['MODEL']['SIZE'],
                                 device=device)

            optimizer.zero_grad()
            hidden_list, output, hidden, new_j = model(inputs, hidden)

            # The class decision is read out at the final time step only.
            loss = ce_loss(output[:, -1], target)
            # MSE against zeros == mean squared hidden activity; zeros_like
            # keeps the shapes consistent for any batch/sequence length.
            active_norm = mse_loss(hidden_list, torch.zeros_like(hidden_list))

            loss += cfg['TRAIN']['ACTIVATION_LAMBDA'] * active_norm
            loss.backward()
            optimizer.step()

            predicted = output[:, -1].argmax(dim=1)
            correct += (predicted == target).sum().item()
            num_data += target.size(0)

        if epoch % cfg['TRAIN']['DISPLAY_EPOCH'] == 0:
            # Accuracy over every batch seen since the previous display.
            acc = correct / num_data
            print(f'{epoch}, {loss.item():.6f}, {acc:.6f}')
            print(active_norm)
            correct = 0
            num_data = 0
        if epoch > 0 and epoch % cfg['TRAIN']['NUM_SAVE_EPOCH'] == 0:
            torch.save(model.state_dict(),
                       os.path.join(save_path, f'epoch_{epoch}.pth'))
def main(config_path):
    """Train the single-output static-input RNN described by a YAML config.

    The loss is the MSE between the output and the target at a randomly
    chosen step among the last ``DATALOADER.CHECK_TIMING`` steps, plus the
    ``DATALOADER.FIXED_DURATION`` steps preceding it (if configured). It also
    adds ``TRAIN.ACTIVATION_LAMBDA`` times the mean squared hidden activity.
    Checkpoints go to ``trained_model/static_input/<config file stem>``.

    Optional ``DATALOADER`` keys:
        CHECK_TIMING (default 5), FIXED_DURATION (default 0),
        RANDOM_START (default True: initial state ~ N(0, 0.5); False: zeros).

    Args:
        config_path: Path to the YAML configuration file.
    """
    # hyper-parameter
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)

    loader_cfg = cfg['DATALOADER']
    loader_cfg.setdefault('CHECK_TIMING', 5)
    fixed_duration = loader_cfg.get('FIXED_DURATION', 0)
    random_start = loader_cfg.get('RANDOM_START', True)

    model_name = os.path.splitext(os.path.basename(config_path))[0]

    # save path (makedirs also creates the intermediate directories)
    save_path = f'trained_model/static_input/{model_name}'
    os.makedirs(save_path, exist_ok=True)

    # copy config file so the checkpoint directory records its settings
    shutil.copyfile(config_path, os.path.join(save_path, os.path.basename(config_path)))

    use_cuda = cfg['MACHINE']['CUDA'] and torch.cuda.is_available()
    torch.manual_seed(cfg['MACHINE']['SEED'])
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(device)

    model = RecurrentNeuralNetwork(n_in=1, n_out=1, n_hid=cfg['MODEL']['SIZE'], device=device,
                                   alpha_time_scale=cfg['MODEL']['ALPHA'],
                                   activation=cfg['MODEL']['ACTIVATION'],
                                   sigma_neu=cfg['MODEL']['SIGMA_NEU'],
                                   use_bias=cfg['MODEL']['USE_BIAS']).to(device)

    train_dataset = StaticInput(time_length=loader_cfg['TIME_LENGTH'],
                                time_scale=cfg['MODEL']['ALPHA'],
                                value_min=loader_cfg['VALUE_MIN'],
                                value_max=loader_cfg['VALUE_MAX'],
                                signal_length=loader_cfg['SIGNAL_LENGTH'],
                                variable_signal_length=loader_cfg['VARIABLE_SIGNAL_LENGTH'],
                                sigma_in=loader_cfg['SIGMA_IN'])

    # worker_init_fn reseeds NumPy's global RNG in each worker (from OS entropy).
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg['TRAIN']['BATCHSIZE'],
                                                   num_workers=2, shuffle=True,
                                                   worker_init_fn=lambda x: np.random.seed())

    print(model)
    print('Epoch Loss')

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=cfg['TRAIN']['LR'], weight_decay=cfg['TRAIN']['WEIGHT_DECAY'])
    criterion = torch.nn.MSELoss()
    n_hid = cfg['MODEL']['SIZE']

    for epoch in range(cfg['TRAIN']['NUM_EPOCH'] + 1):
        model.train()
        for inputs, target in train_dataloader:
            inputs = inputs.float().to(device)
            target = target.float().to(device)
            # Use the real batch size: the final batch can be smaller than BATCHSIZE.
            batch_size = inputs.size(0)

            if random_start:
                hidden_np = np.random.normal(0, 0.5, size=(batch_size, n_hid))
            else:
                hidden_np = np.zeros((batch_size, n_hid))
            hidden = torch.from_numpy(hidden_np).float().to(device)

            optimizer.zero_grad()
            hidden_list, output, hidden = model(inputs, hidden)

            # Negative index: one of the last CHECK_TIMING time steps.
            check_timing = np.random.randint(-loader_cfg['CHECK_TIMING'], 0)
            loss = criterion(output[:, check_timing], target)
            for j in range(1, fixed_duration + 1):
                loss += criterion(output[:, check_timing - j], target)
            # MSE against zeros == mean squared hidden activity.
            active_norm = criterion(hidden_list, torch.zeros_like(hidden_list))

            loss += cfg['TRAIN']['ACTIVATION_LAMBDA'] * active_norm
            loss.backward()
            optimizer.step()

        if epoch % cfg['TRAIN']['DISPLAY_EPOCH'] == 0:
            print(f'{epoch}, {loss.item():.4f}')
            # Steps check_timing - fixed_duration .. check_timing (inclusive):
            # exactly the steps the last batch's loss was computed on.
            window = output[0, check_timing - fixed_duration:, 0][:fixed_duration + 1]
            print('output: ', window.cpu().detach().numpy())
            print('target: ', target[0, 0].cpu().detach().numpy())
        if epoch > 0 and epoch % cfg['TRAIN']['NUM_SAVE_EPOCH'] == 0:
            torch.save(model.state_dict(), os.path.join(save_path, f'epoch_{epoch}.pth'))
def main(config_path):
    """Train the frequency RNN with a staged (curriculum) time-scale schedule.

    Training starts from the YAML config. Whenever a batch loss falls below
    the threshold for the next phase, the model's alpha and the dataset's
    TIME_LENGTH / SIGNAL_LENGTH / VARIABLE_DELAY are switched to that phase's
    values and the DataLoader is rebuilt. The rest of that epoch is skipped.
    Phases are entered strictly in order; ``TRAIN.PHASE_TRANSIT`` (default
    ``[0.5, 0.45, 0.4, 0.3, 0.2]``) gives one threshold per phase. A shorter
    list schedules only that many phases.

    Checkpoints go to ``trained_model/freq_schedule2/<config file stem>``.

    Args:
        config_path: Path to the YAML configuration file.
    """
    # hyper-parameter
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)

    model_name = os.path.splitext(os.path.basename(config_path))[0]

    # save path (makedirs also creates the intermediate directories)
    save_path = f'trained_model/freq_schedule2/{model_name}'
    os.makedirs(save_path, exist_ok=True)

    # copy config file so the checkpoint directory records its settings
    shutil.copyfile(config_path,
                    os.path.join(save_path, os.path.basename(config_path)))

    use_cuda = cfg['MACHINE']['CUDA'] and torch.cuda.is_available()
    torch.manual_seed(cfg['MACHINE']['SEED'])
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(device)

    cfg['MODEL'].setdefault('ALPHA', 0.25)

    model = RecurrentNeuralNetwork(
        n_in=1,
        n_out=2,
        n_hid=cfg['MODEL']['SIZE'],
        device=device,
        alpha_time_scale=cfg['MODEL']['ALPHA'],
        beta_time_scale=cfg['MODEL']['BETA'],
        activation=cfg['MODEL']['ACTIVATION'],
        sigma_neu=cfg['MODEL']['SIGMA_NEU'],
        sigma_syn=cfg['MODEL']['SIGMA_SYN'],
        use_bias=cfg['MODEL']['USE_BIAS'],
        anti_hebbian=cfg['MODEL']['ANTI_HEBB']).to(device)

    def build_dataloader():
        """Build a FreqDataset DataLoader from the *current* ``cfg`` values."""
        dataset = FreqDataset(
            time_length=cfg['DATALOADER']['TIME_LENGTH'],
            time_scale=cfg['MODEL']['ALPHA'],
            freq_min=cfg['DATALOADER']['FREQ_MIN'],
            freq_max=cfg['DATALOADER']['FREQ_MAX'],
            min_interval=cfg['DATALOADER']['MIN_INTERVAL'],
            signal_length=cfg['DATALOADER']['SIGNAL_LENGTH'],
            variable_signal_length=cfg['DATALOADER']['VARIABLE_SIGNAL_LENGTH'],
            sigma_in=cfg['DATALOADER']['SIGMA_IN'],
            delay_variable=cfg['DATALOADER']['VARIABLE_DELAY'])
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg['TRAIN']['BATCHSIZE'],
            num_workers=2,
            shuffle=True,
            # Reseed NumPy's global RNG in each worker process (from OS entropy).
            worker_init_fn=lambda x: np.random.seed())

    train_dataloader = build_dataloader()

    print(model)
    print('Epoch Loss Acc')

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=cfg['TRAIN']['LR'],
                           weight_decay=cfg['TRAIN']['WEIGHT_DECAY'])
    ce_loss = torch.nn.CrossEntropyLoss()
    mse_loss = torch.nn.MSELoss()

    # Curriculum for phases 2..6 (phase 1 is the config as loaded):
    # (alpha, TIME_LENGTH, SIGNAL_LENGTH, VARIABLE_DELAY).
    phase_schedule = [
        (0.2, 95, 20, 6),
        (0.175, 110, 22, 7),
        (0.15, 120, 25, 8),
        (0.10, 160, 30, 10),
        (0.075, 267, 50, 15),
    ]
    phase_transition_criteria = cfg['TRAIN'].get(
        'PHASE_TRANSIT', [0.5, 0.45, 0.4, 0.3, 0.2])
    # zip pairs each phase with its loss threshold (truncates to the shorter).
    phases = list(zip(phase_schedule, phase_transition_criteria))
    next_phase = 0

    correct = 0
    num_data = 0
    for epoch in range(cfg['TRAIN']['NUM_EPOCH'] + 1):
        model.train()
        for inputs, target in train_dataloader:
            inputs = inputs.float().to(device)
            target = target.long().to(device)

            # Size the initial state from the actual batch: the final batch
            # is smaller than BATCHSIZE unless the dataset length divides it.
            hidden = torch.zeros(inputs.size(0), cfg['MODEL']['SIZE'],
                                 device=device)

            optimizer.zero_grad()
            hidden_list, output, hidden, new_j = model(inputs, hidden)

            # The class decision is read out at the final time step only.
            loss = ce_loss(output[:, -1], target)
            # Penalise mean squared hidden activity only on steps
            # [SIGNAL_LENGTH, TIME_LENGTH - SIGNAL_LENGTH) -- presumably the
            # delay between the two signals; confirm against FreqDataset.
            signal_length = cfg['DATALOADER']['SIGNAL_LENGTH']
            time_length = cfg['DATALOADER']['TIME_LENGTH']
            delay_activity = hidden_list[:, signal_length:
                                         time_length - signal_length, :]
            active_norm = mse_loss(delay_activity,
                                   torch.zeros_like(delay_activity))

            loss += cfg['TRAIN']['ACTIVATION_LAMBDA'] * active_norm
            loss.backward()
            optimizer.step()

            predicted = output[:, -1].argmax(dim=1)
            correct += (predicted == target).sum().item()
            num_data += target.size(0)

            if next_phase < len(phases):
                (new_alpha, new_time_length, new_signal_length,
                 new_delay), threshold = phases[next_phase]
                if loss.item() < threshold:
                    cfg['MODEL']['ALPHA'] = new_alpha
                    cfg['DATALOADER']['TIME_LENGTH'] = new_time_length
                    cfg['DATALOADER']['SIGNAL_LENGTH'] = new_signal_length
                    cfg['DATALOADER']['VARIABLE_DELAY'] = new_delay

                    # Schedule index 0 is "phase2".
                    print(f"phase{next_phase + 2} start! "
                          f"cfg['MODEL']['ALPHA'] = {new_alpha}")
                    next_phase += 1
                    model.change_alpha(new_alpha)
                    train_dataloader = build_dataloader()
                    # Abandon the old loader; the next epoch uses the new one.
                    break

        if epoch % cfg['TRAIN']['DISPLAY_EPOCH'] == 0:
            # Accuracy over every batch seen since the previous display.
            acc = correct / num_data
            print(f'{epoch}, {loss.item():.6f}, {acc:.6f}')
            print(f'activation norm: {active_norm.item():.4f}, time scale: '
                  f'{model.alpha.detach().cpu().numpy()[0]:.3f}')
            correct = 0
            num_data = 0
        if epoch > 0 and epoch % cfg['TRAIN']['NUM_SAVE_EPOCH'] == 0:
            torch.save(model.state_dict(),
                       os.path.join(save_path, f'epoch_{epoch}.pth'))