# Example 1 (score: 0)
def get_model(state, args, init_model_name=None):
    """Load a CNN model from a checkpoint, or build and register a new one.

    Args:
        state (dict): mutable experiment state; when a new model is built it
            is updated in place with the model's name, constructor arguments
            and ``state_dict``.
        args: parsed command-line arguments (uses ``agg_time``,
            ``norm_embed``, ``fixed_segment`` and, if present,
            ``conv_dropout``).
        init_model_name (str, optional): checkpoint path. If the file exists
            the model is loaded from it; otherwise a fresh model is created
            and, if a name was given, saved there.

    Returns:
        tuple: ``(model, state)``.
    """
    if init_model_name is not None and os.path.exists(init_model_name):
        model, optimizer, state = load_model(init_model_name,
                                             return_optimizer=True,
                                             return_state=True)
    else:
        # Fall back to the config default when args does not override it.
        if "conv_dropout" in args:
            conv_dropout = args.conv_dropout
        else:
            conv_dropout = cfg.conv_dropout
        # Positional args for CNN: a single input channel.  This was the set
        # literal ``{1}``; a list keeps the unpack order explicit and
        # serializes cleanly inside the saved state.
        cnn_args = [1]

        if args.fixed_segment is not None:
            frames = cfg.frames
        else:
            frames = None

        nb_layers = 4
        cnn_kwargs = {
            "activation": cfg.activation,
            "conv_dropout": conv_dropout,
            "batch_norm": cfg.batch_norm,
            "kernel_size": nb_layers * [3],
            "padding": nb_layers * [1],
            "stride": nb_layers * [1],
            "nb_filters": [16, 16, 32, 65],
            "pooling": [(2, 2), (2, 2), (1, 4), (1, 2)],
            "aggregation": args.agg_time,
            "norm_out": args.norm_embed,
            "frames": frames,
        }
        # The two (2, 2) poolings halve the time axis twice -> // 2**2.
        nb_frames_staying = cfg.frames // (2**2)
        model = CNN(*cnn_args, **cnn_kwargs)
        # model.apply(weights_init)
        state.update({
            'model': {
                "name": model.__class__.__name__,
                'args': cnn_args,
                "kwargs": cnn_kwargs,
                'state_dict': model.state_dict()
            },
            'nb_frames_staying': nb_frames_staying
        })
        if init_model_name is not None:
            save_model(state, init_model_name)
    # Count only trainable parameters for the log message.
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    LOG.info(
        "number of parameters in the model: {}".format(pytorch_total_params))
    return model, state
# Example 2 (score: 0)
class WCRNN(nn.Module):  #, BaseFairseqModel):
    """CRNN on top of a wav2vec encoder: w2v features -> CNN -> BGRU -> dense.

    ``forward`` returns frame-level (strong) and clip-level (weak) class
    probabilities for sound event detection.
    """

    def __init__(self,
                 w2v_cfg,
                 n_in_channel,
                 nclass,
                 attention=False,
                 activation="Relu",
                 dropout=0,
                 train_cnn=True,
                 rnn_type='BGRU',
                 n_RNN_cell=64,
                 n_layers_RNN=1,
                 dropout_recurrent=0,
                 cnn_integration=False,
                 **kwargs):
        """Build the wav2vec encoder, CNN front-end, recurrent layers and
        classification heads.

        Args:
            w2v_cfg: configuration object for the wav2vec encoder.
            n_in_channel (int): number of input channels to the CNN.
            nclass (int): number of output classes.
            attention (bool): if True, use an attention-weighted pooling for
                the weak (clip-level) prediction instead of a plain mean.
            train_cnn (bool): if False, freeze the CNN parameters.
            rnn_type (str): only 'BGRU' is supported.
            cnn_integration (bool): if True, fold channels into the batch
                dimension before the CNN and restore them afterwards.
            **kwargs: forwarded to the CNN constructor.
        """
        super(WCRNN, self).__init__()

        self.w2v = w2v_encoder(w2v_cfg)  #Wav2Vec2Config)
        #self.w2v = Wav2VecEncoder(Wav2Vec2SedConfig, None)
        self.pooling = nn.Sequential(nn.MaxPool2d((1, 4), (1, 4)))

        self.n_in_channel = n_in_channel
        self.attention = attention
        self.cnn_integration = cnn_integration
        n_in_cnn = n_in_channel
        if cnn_integration:
            n_in_cnn = 1
        self.cnn = CNN(n_in_cnn, activation, dropout, **kwargs)
        if not train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False
        self.train_cnn = train_cnn
        if rnn_type == 'BGRU':
            nb_in = self.cnn.nb_filters[-1]
            if self.cnn_integration:
                # self.fc = nn.Linear(nb_in * n_in_channel, nb_in)
                nb_in = nb_in * n_in_channel
            self.rnn = BidirectionalGRU(nb_in,
                                        n_RNN_cell,
                                        dropout=dropout_recurrent,
                                        num_layers=n_layers_RNN)
        else:
            # Previously the exception was created but never raised, silently
            # leaving the model without an RNN.
            raise NotImplementedError("Only BGRU supported for CRNN for now")
        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(n_RNN_cell * 2, nclass)
        self.sigmoid = nn.Sigmoid()
        if self.attention:
            self.dense_softmax = nn.Linear(n_RNN_cell * 2, nclass)
            self.softmax = nn.Softmax(dim=-1)

    def load_cnn(self, state_dict):
        """Load CNN weights and re-freeze them if the CNN is not trainable."""
        self.cnn.load_state_dict(state_dict)
        if not self.train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False

    def load_state_dict(self, state_dict, strict=True):
        """Load weights from the nested dict produced by ``state_dict``."""
        # Fixed typo: was ``load_state_dice``.
        self.w2v.load_state_dict(state_dict["w2v"])
        self.cnn.load_state_dict(state_dict["cnn"])
        self.rnn.load_state_dict(state_dict["rnn"])
        self.dense.load_state_dict(state_dict["dense"])

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return a nested dict of the sub-module state dicts.

        Note: this deliberately deviates from the flat ``nn.Module`` format
        and pairs with this class's ``load_state_dict``.
        """
        state_dict = {
            "w2v":
            self.w2v.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            "cnn":
            self.cnn.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            "rnn":
            self.rnn.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            'dense':
            self.dense.state_dict(destination=destination,
                                  prefix=prefix,
                                  keep_vars=keep_vars)
        }
        return state_dict

    def save(self, filename):
        """Serialize the sub-module weights to ``filename`` via torch.save."""
        parameters = {
            'w2v': self.w2v.state_dict(),
            'cnn': self.cnn.state_dict(),
            'rnn': self.rnn.state_dict(),
            'dense': self.dense.state_dict()
        }
        torch.save(parameters, filename)

    def forward(self, audio):
        """Run the full pipeline on raw audio.

        Returns:
            tuple: ``(strong, weak)`` — frame-level [bs, frames, nclass] and
            clip-level [bs, nclass] probabilities.
        """
        x = audio.squeeze()
        # (Removed leftover ``pdb.set_trace()`` debugging breakpoint.)
        feature = self.w2v(x)
        x = feature['x']
        x = x.transpose(1, 0)
        x = x.unsqueeze(1)

        # input size : (batch_size, n_channels, n_frames, n_freq)
        if self.cnn_integration:
            bs_in, nc_in = x.size(0), x.size(1)
            x = x.view(bs_in * nc_in, 1, *x.shape[2:])

        # conv features
        x = self.cnn(x)
        bs, chan, frames, freq = x.size()
        if self.cnn_integration:
            x = x.reshape(bs_in, chan * nc_in, frames, freq)

        if freq != 1:
            warnings.warn(
                f"Output shape is: {(bs, frames, chan * freq)}, from {freq} staying freq"
            )
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)
            x = x.permute(0, 2, 1)  # [bs, frames, chan]

        # rnn features
        x = self.rnn(x)
        x = self.dropout(x)
        strong = self.dense(x)  # [bs, frames, nclass]
        strong = self.sigmoid(strong)
        if self.attention:
            sof = self.dense_softmax(x)  # [bs, frames, nclass]
            sof = self.softmax(sof)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / (sof.sum(1) + 1e-08)  # [bs, nclass]
        else:
            weak = strong.mean(1)
        return strong, weak
# Example 3 (score: 0)
# Build the DataLoader over the held-out test set (no shuffling, so outputs
# stay aligned with the CSV row order).
test_file = args.data_path + args.test_file
test_data = ArticlesDataset(csv_file=test_file, vocab=vocab, label2id=label2id, max_text_len=args.text_len)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False)

scores_dict = {'f1': [], 'recall': [], 'precision': [], 'confidence': []}

# Evaluate each saved training run of this model on the test set.
for run_num in range(args.num_runs):
    model_run_name = f"{model_name}_run{run_num + 1}"
    print(f"{'-' * 10} Run {run_num + 1} {'-' * 10}")
    print(f"Model name: {model_run_name}")
    print(f"Loading model from {save_path + model_run_name + '.pt'}")

    best_model = CNN(cnn_args=cnn_args, mlp_args=mlp_args).to(device)
    optimizer = torch.optim.Adam(best_model.parameters(), lr=0.005)
    load_checkpoint(save_path + model_run_name + ".pt", best_model, optimizer, device, log_file)

    results = evaluate(best_model, test_loader)
    for metric in ('f1', 'recall', 'precision'):
        scores_dict[metric].append(results[metric])

    # (Optional collection of confidences / labels / content / sentence
    # encodings was disabled here.)


scores_filename = save_path + model_name + "_test_scores.json"
# Example 4 (score: 0)
def run(args, kwargs):
    """Run k-fold cross-validation of the CNN model on the Warwick dataset.

    For each outer test fold: split the remaining data into train/validation
    folds, build a fresh model and optimizer, run ``experiment`` and log the
    per-fold errors.

    Args:
        args: parsed command-line arguments (model/optimizer/fold settings);
            mutated in place (``model_signature``, ``loc_info``, ``self_att``).
        kwargs: extra keyword arguments forwarded to ``experiment``.

    Returns:
        tuple: (mean train error, std train error, mean test error,
        std test error) across the outer folds.
    """
    # Timestamp (to the second) used to uniquely tag this experiment's files.
    args.model_signature = str(datetime.datetime.now())[0:19]

    model_name = '' + args.model_name

    # Any location-based feature implies location info is needed.
    if args.loc_gauss or args.loc_inv_q or args.loc_att:
        args.loc_info = True

    # Any attention variant implies self-attention is enabled.
    if args.att_gauss_abnormal or args.att_inv_q_abnormal or args.att_gauss_spatial or args.att_inv_q_spatial or \
            args.att_module or args.laplace_att:
        args.self_att = True

    print(args)

    with open('experiment_log_' + args.operator + '.txt', 'a') as f:
        print(args, file=f)

    # IMPORT MODEL======================================================================================================
    from models.CNN import CNN as Model

    # START KFOLDS======================================================================================================
    print('\nSTART KFOLDS CROSS VALIDATION\n')
#    print(f'{args.kfold_test} Test-Train folds each has {args.epochs} epochs for a {1.0/args.kfold_val}/{(args.kfold_val - 1.0)/args.kfold_val} Valid-Train split\n')

    train_folds, test_folds = kfold_indices_warwick(args.dataset_size, args.kfold_test, seed=args.seed)

    train_error_folds = []
    test_error_folds = []
    for current_fold in range(1, args.kfold_test + 1):

        print('#################### Train-Test fold: {}/{} ####################'.format(current_fold, args.kfold_test))

        # DIRECTORY FOR SAVING==========================================================================================
        snapshots_path = 'snapshots/'
        # Renamed from ``dir`` to avoid shadowing the builtin.
        save_dir = snapshots_path + model_name + '_' + args.model_signature + '/'
        sw = SummaryWriter(f'tensorboard/{model_name}_{args.model_signature}_fold_{current_fold}')

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # LOAD DATA=====================================================================================================
        # Inner split of this fold's training indices into train/validation.
        train_fold, val_fold = kfold_indices_warwick(len(train_folds[current_fold - 1]), args.kfold_val, seed=args.seed)
        train_fold = [train_folds[current_fold - 1][i] for i in train_fold]
        val_fold = [train_folds[current_fold - 1][i] for i in val_fold]
        loc = bool(args.loc_info or args.out_loc)
        train_set, val_set, test_set = load_breast(train_fold[0], val_fold[0], test_folds[current_fold - 1], loc)

        # CREATE MODEL==================================================================================================
        print('\tcreate models')
        model = Model(args)
        if args.cuda:
            model.cuda()

        # INIT OPTIMIZER================================================================================================
        print('\tinit optimizer')
        if args.optimizer == 'Adam':
            optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=args.reg)
        elif args.optimizer == 'SGD':
            optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.reg, momentum=0.9)
        else:
            raise Exception('Wrong name of the optimizer!')

        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1)

        # PERFORM EXPERIMENT============================================================================================
        print('\tperform experiment\n')

        train_error, test_error = experiment(
            args,
            kwargs,
            current_fold,
            train_set,
            val_set,
            test_set,
            model,
            optimizer,
            scheduler,
            save_dir,
            sw,
        )

        # APPEND FOLD RESULTS===========================================================================================
        train_error_folds.append(train_error)
        test_error_folds.append(test_error)

        with open('final_results_' + args.operator + '.txt', 'a') as f:
            print('RESULT FOR A SINGLE FOLD\n'
                  'SEED: {}\n'
                  'OPERATOR: {}\n'
                  'FOLD: {}\n'
                  'ERROR (TRAIN): {}\n'
                  'ERROR (TEST): {}\n\n'.format(args.seed, args.operator, current_fold, train_error, test_error),
                  file=f)
    # ======================================================================================================================
    print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-')
    with open('experiment_log_' + args.operator + '.txt', 'a') as f:
        print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-\n', file=f)

    return np.mean(train_error_folds), np.std(train_error_folds), np.mean(test_error_folds), np.std(test_error_folds)
# Example 5 (score: 0)
class NEC:
    """Neural Episodic Control agent (Pritzel et al., 2017).

    A CNN embeds observations; one Differentiable Neural Dictionary (DND)
    per action stores (embedding, Q) pairs used for value lookup.
    """

    def __init__(self, env, args, device='cpu'):
        """
        Instantiate an NEC Agent
        ----------
        env: gym.Env
            gym environment to train on
        args: args class from argparser
            args are from from train.py: see train.py for help with each arg
        device: string
            'cpu' or 'cuda:0' depending on use_cuda flag from train.py
        """
        self.environment_type = args.environment_type
        self.env = env
        self.device = device
        # Hyperparameters
        self.epsilon = args.initial_epsilon
        self.final_epsilon = args.final_epsilon
        self.epsilon_decay = args.epsilon_decay
        self.gamma = args.gamma
        self.N = args.N
        # Transition queue and replay memory
        self.transition_queue = []
        self.replay_every = args.replay_every
        self.replay_buffer_size = args.replay_buffer_size
        self.replay_memory = ReplayMemory(self.replay_buffer_size)
        # CNN for state embedding network
        self.frames_to_stack = args.frames_to_stack
        self.embedding_size = args.embedding_size
        self.in_height = args.in_height
        self.in_width = args.in_width
        self.cnn = CNN(self.frames_to_stack, self.embedding_size,
                       self.in_height, self.in_width).to(self.device)
        # Differentiable Neural Dictionary (DND): one for each action
        self.kernel = inverse_distance
        self.num_neighbors = args.num_neighbors
        self.max_memory = args.max_memory
        self.lr = args.lr
        self.dnd_list = []
        for _ in range(env.action_space.n):
            self.dnd_list.append(
                DND(self.kernel, self.num_neighbors, self.max_memory,
                    args.optimizer, self.lr))
        # Optimizer for state embedding CNN
        self.q_lr = args.q_lr
        self.batch_size = args.batch_size
        self.optimizer = get_optimizer(args.optimizer, self.cnn.parameters(),
                                       self.lr)

    def choose_action(self, state_embedding):
        """
        Choose epsilon-greedy policy according to Q-estimates from DNDs
        """
        if random.uniform(0, 1) < self.epsilon:
            # Explore: uniform random action.
            return random.randint(0, self.env.action_space.n - 1)
        else:
            # Exploit: action whose DND lookup gives the highest Q.
            qs = [dnd.lookup(state_embedding) for dnd in self.dnd_list]
            action = torch.argmax(torch.cat(qs))
            return action

    def Q_lookahead(self, t, warmup=False):
        """
        Return the N-step Q-value lookahead from time t in the transition queue
        """
        if warmup or len(self.transition_queue) <= t + self.N:
            # Not enough transitions for a full N-step bootstrap: use the
            # discounted sum of all remaining rewards.
            lookahead = [tr.reward for tr in self.transition_queue[t:]]
            discounted = discount(lookahead, self.gamma)
            Q_N = torch.tensor([discounted], requires_grad=True)
            return Q_N
        else:
            lookahead = [
                tr.reward for tr in self.transition_queue[t:t + self.N]
            ]
            discounted = discount(lookahead, self.gamma)
            # Bootstrap with the max DND Q-value of the state N steps ahead.
            state = self.transition_queue[t + self.N].state
            state = torch.tensor(state).permute(2, 0,
                                                1).unsqueeze(0)  # (N,C,H,W)
            state = state.to(self.device)
            state_embedding = self.cnn(state)
            Q_a = [dnd.lookup(state_embedding) for dnd in self.dnd_list]
            maxQ = torch.cat(Q_a).max()
            Q_N = discounted + (self.gamma**self.N) * maxQ
            Q_N = torch.tensor([Q_N], requires_grad=True)
            return Q_N

    def Q_update(self, Q, Q_N):
        """
        Return the Q-update for DND updates
        """
        return Q + self.q_lr * (Q_N - Q)

    def update(self):
        """
        Iterate through the transition queue and make NEC updates
        """
        # Insert transitions into DNDs
        for t in range(len(self.transition_queue)):
            tr = self.transition_queue[t]
            action = tr.action
            # (Removed a duplicated ``tr = self.transition_queue[t]`` line.)
            state = torch.tensor(tr.state).permute(2, 0, 1)  # (C,H,W)
            state = state.unsqueeze(0).to(self.device)  # (N,C,H,W)
            state_embedding = self.cnn(state)
            dnd = self.dnd_list[action]

            Q_N = self.Q_lookahead(t).to(self.device)
            embedding_index = dnd.get_index(state_embedding)
            if embedding_index is None:
                # Unseen embedding: insert a fresh (key, value) pair.
                dnd.insert(state_embedding.detach(), Q_N.detach().unsqueeze(0))
            else:
                # Known embedding: move its stored Q towards the new target.
                Q = self.Q_update(dnd.values[embedding_index], Q_N)
                dnd.update(Q.detach(), embedding_index)
            Q_N = Q_N.detach().to(self.device)
            self.replay_memory.push(tr.state, action, Q_N)
        # Commit inserts
        for dnd in self.dnd_list:
            dnd.commit_insert()
        # Train CNN on minibatch
        for t in range(len(self.transition_queue)):
            if t % self.replay_every == 0 or t == len(
                    self.transition_queue) - 1:
                # Train on random mini-batch from self.replay_memory
                batch = self.replay_memory.sample(self.batch_size)
                actual_Qs = torch.cat([sample.Q_N for sample in batch])
                predicted_Qs = []
                for sample in batch:
                    state = torch.tensor(sample.state).permute(2, 0,
                                                               1)  # (C,H,W)
                    state = state.unsqueeze(0).to(self.device)  # (N,C,H,W)
                    state_embedding = self.cnn(state)
                    dnd = self.dnd_list[sample.action]
                    predicted_Q = dnd.lookup(state_embedding, update_flag=True)
                    predicted_Qs.append(predicted_Q)
                predicted_Qs = torch.cat(predicted_Qs).to(self.device)
                loss = torch.dist(actual_Qs, predicted_Qs)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                for dnd in self.dnd_list:
                    dnd.update_params()

        # Clear out transition queue
        self.transition_queue = []

    def run_episode(self):
        """
        Train an NEC agent for a single episode:
            Interact with environment
            Append (state, action, reward) transitions to transition queue
            Call update at the end of the episode
        """
        # Decay exploration once per episode until the floor is reached.
        if self.epsilon > self.final_epsilon:
            self.epsilon = self.epsilon * self.epsilon_decay
        state = self.env.reset()
        if self.environment_type == 'fourrooms':
            fewest_steps = self.env.shortest_path_length(self.env.state)
        total_steps = 0
        total_reward = 0
        total_frames = 0
        done = False
        while not done:
            state_embedding = torch.tensor(state).permute(2, 0, 1)  # (C,H,W)
            state_embedding = state_embedding.unsqueeze(0).to(self.device)
            state_embedding = self.cnn(state_embedding)
            action = self.choose_action(state_embedding)
            next_state, reward, done, _ = self.env.step(action)
            self.transition_queue.append(Transition(state, action, reward))
            total_reward += reward
            total_frames += self.env.skip
            total_steps += 1
            state = next_state
        self.update()
        if self.environment_type == 'fourrooms':
            # Report steps beyond the optimal path for this maze layout.
            n_extra_steps = total_steps - fewest_steps
            return n_extra_steps, total_frames, total_reward
        else:
            return total_frames, total_reward

    def warmup(self):
        """
        Warmup the DND with values from an episode with a random policy
        """
        state = self.env.reset()
        total_reward = 0
        total_frames = 0
        done = False
        while not done:
            action = random.randint(0, self.env.action_space.n - 1)
            next_state, reward, done, _ = self.env.step(action)
            total_reward += reward
            total_frames += self.env.skip
            self.transition_queue.append(Transition(state, action, reward))
            state = next_state

        for t in range(len(self.transition_queue)):
            tr = self.transition_queue[t]
            state_embedding = torch.tensor(tr.state).permute(2, 0,
                                                             1)  # (C,H,W)
            state_embedding = state_embedding.unsqueeze(0).to(self.device)
            state_embedding = self.cnn(state_embedding)
            action = tr.action
            dnd = self.dnd_list[action]

            Q_N = self.Q_lookahead(t, True).to(self.device)
            if dnd.keys_to_be_inserted is None and dnd.keys is None:
                # Empty DND: first insert seeds it.
                dnd.insert(state_embedding, Q_N.detach().unsqueeze(0))
            else:
                embedding_index = dnd.get_index(state_embedding)
                if embedding_index is None:
                    state_embedding = state_embedding.detach()
                    dnd.insert(state_embedding, Q_N.detach().unsqueeze(0))
                else:
                    Q = self.Q_update(dnd.values[embedding_index], Q_N)
                    dnd.update(Q.detach(), embedding_index)
            self.replay_memory.push(tr.state, action, Q_N.detach())
        for dnd in self.dnd_list:
            dnd.commit_insert()
        # Clear out transition queue
        self.transition_queue = []
        return total_frames, total_reward
# Example 6 (score: 0)
    # Embedding name = basename of the pretrained-embeddings file without its
    # 4-character extension (e.g. ".txt") — TODO confirm extension length.
    emb_name = args.pretrained_emb.split("/")[-1][:-4]
    # Encode the main hyperparameters into the model name so checkpoints and
    # logs from different configurations do not collide.
    model_name = "cnn_" + str(args.emb_dim) + "embDim_" + \
                 str(args.num_kernel) + "kernels_" + \
                 str(args.stride) + "stride_" + \
                 str(args.mlp_hidden_size) + "MLPhidden_" + \
                 str(args.num_epochs) + "epochs" + "_" + emb_name
    model_name += "_run" + str(run_num + 1)

    print("Model name:", model_name)

    # Per-run log file; closed after training finishes below.
    log_file = open(save_path + model_name + "_logs.txt", 'w')
    model = CNN(cnn_args=cnn_args, mlp_args=mlp_args).to(device)
    print(model)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # BCELoss: assumes the model outputs sigmoid probabilities — TODO confirm.
    loss_fn = nn.BCELoss()

    # NOTE(review): train_loader / valid_loader / save_path / device /
    # num_epochs are defined earlier in the enclosing (not visible) scope.
    train(model=model,
          optimizer=optimizer,
          num_epochs=num_epochs,
          criterion=loss_fn,
          eval_every=args.step_size,
          train_loader=train_loader,
          valid_loader=valid_loader,
          save_path=save_path,
          model_name=model_name)

    print("Done training! Best model saved at", save_path + model_name + ".pt")
    log_file.close()
# Example 7 (score: 0)
class CRNN(nn.Module):
    """CNN + bidirectional GRU for sound event detection.

    ``forward`` returns frame-level (strong) and clip-level (weak) class
    probabilities.
    """

    def __init__(self,
                 n_in_channel,
                 nclass,
                 attention=False,
                 activation="Relu",
                 dropout=0,
                 train_cnn=True,
                 rnn_type='BGRU',
                 n_RNN_cell=64,
                 n_layers_RNN=1,
                 dropout_recurrent=0,
                 **kwargs):
        """Build the CNN front-end, the recurrent layers and the heads.

        Args:
            n_in_channel (int): number of input channels to the CNN.
            nclass (int): number of output classes.
            attention (bool): if True, weak predictions use attention-weighted
                pooling instead of a plain mean over frames.
            train_cnn (bool): if False, freeze the CNN parameters.
            rnn_type (str): only 'BGRU' is supported.
            **kwargs: forwarded to the CNN constructor.
        """
        super(CRNN, self).__init__()
        self.attention = attention
        self.cnn = CNN(n_in_channel, activation, dropout, **kwargs)
        if not train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False
        self.train_cnn = train_cnn
        if rnn_type == 'BGRU':
            self.rnn = BidirectionalGRU(self.cnn.nb_filters[-1],
                                        n_RNN_cell,
                                        dropout=dropout_recurrent,
                                        num_layers=n_layers_RNN)
        else:
            # Previously the exception was created but never raised, silently
            # leaving the model without an RNN.
            raise NotImplementedError("Only BGRU supported for CRNN for now")
        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(n_RNN_cell * 2, nclass)
        self.sigmoid = nn.Sigmoid()
        if self.attention:
            self.dense_softmax = nn.Linear(n_RNN_cell * 2, nclass)
            self.softmax = nn.Softmax(dim=-1)

    def load_cnn(self, parameters):
        """Load CNN weights and re-freeze them if the CNN is not trainable."""
        self.cnn.load(parameters)
        if not self.train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False

    def load(self, filename=None, parameters=None):
        """Load weights from a file or from an in-memory state dict.

        Raises:
            NotImplementedError: if neither argument was supplied.
        """
        if filename is not None:
            parameters = torch.load(filename)
        if parameters is None:
            raise NotImplementedError(
                "load is a filename or a list of parameters (state_dict)")

        self.cnn.load(parameters=parameters["cnn"])
        self.rnn.load_state_dict(parameters["rnn"])
        self.dense.load_state_dict(parameters["dense"])

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return a nested dict of sub-module state dicts.

        Note: this deliberately deviates from the flat ``nn.Module`` format
        and pairs with this class's ``load``.
        """
        state_dict = {
            "cnn":
            self.cnn.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            "rnn":
            self.rnn.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            'dense':
            self.dense.state_dict(destination=destination,
                                  prefix=prefix,
                                  keep_vars=keep_vars)
        }
        return state_dict

    def save(self, filename):
        """Serialize sub-module weights to ``filename`` via torch.save."""
        parameters = {
            'cnn': self.cnn.state_dict(),
            'rnn': self.rnn.state_dict(),
            'dense': self.dense.state_dict()
        }
        torch.save(parameters, filename)

    def forward(self, x):
        """Map a (batch, channels, frames, freq) input to (strong, weak).

        Returns:
            tuple: frame-level [bs, frames, nclass] and clip-level
            [bs, nclass] probabilities.
        """
        # input size : (batch_size, n_channels, n_frames, n_freq)
        # conv features
        x = self.cnn(x)
        bs, chan, frames, freq = x.size()
        if freq != 1:
            # Frequency axis was not fully pooled away: fold it into channels.
            warnings.warn("Output shape is: {}".format(
                (bs, frames, chan * freq)))
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)
            x = x.permute(0, 2, 1)  # [bs, frames, chan]

        # rnn features
        x = self.rnn(x)
        x = self.dropout(x)
        strong = self.dense(x)  # [bs, frames, nclass]
        strong = self.sigmoid(strong)
        if self.attention:
            sof = self.dense_softmax(x)  # [bs, frames, nclass]
            sof = self.softmax(sof)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / sof.sum(1)  # [bs, nclass]
        else:
            weak = strong.mean(1)
        return strong, weak
# Example 8 (score: 0)
def main():
    """Train the note-prediction CNN on MIDI data and generate a MIDI file.

    Loads note sequences, trains for a fixed number of epochs with cosine LR
    annealing, saves per-epoch checkpoints and accuracy/loss plots, then
    writes a generated MIDI file. Requires a CUDA device.
    """
    songs = get_notes()

    # Vocabulary of all distinct notes across every song.
    vocab_set = {note for song in songs for note in song}

    n_in, n_out = prep_sequences(songs, sequence_length=100)
    X_train, X_val, y_train, y_val = train_test_split(n_in,
                                                      n_out,
                                                      test_size=0.2)

    train_ds = MusicDataset(X_train, y_train)
    val_ds = MusicDataset(X_val, y_val)

    train_dataloader = DataLoader(train_ds,
                                  batch_size=512,
                                  shuffle=True,
                                  num_workers=0)
    val_dataloader = DataLoader(val_ds,
                                batch_size=512,
                                shuffle=False,
                                num_workers=0)

    model = CNN(100, len(vocab_set))
    model.cuda()
    epochs = 25
    initial_lr = 0.001
    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    loss_fn = CrossEntropyLoss()

    train_losses = []
    val_losses = []

    train_accuracies = []
    val_accuracies = []

    for epoch in tqdm(range(1, epochs + 1)):

        model.train()
        train_loss_total = 0.0
        num_steps = 0
        correct = 0
        ### Train
        for batch in train_dataloader:
            X, y = batch[0].cuda(), batch[1].cuda()
            train_preds = model(X)

            loss = loss_fn(train_preds, y)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_steps += 1

            train_preds = torch.max(train_preds, 1)[1]
            correct += (train_preds == y).float().sum()

        train_loss_total_avg = train_loss_total / num_steps
        # float(): accuracies were CUDA tensors, which matplotlib cannot
        # convert when plotting below; keep plain Python floats instead.
        train_accuracy = float(correct / len(train_ds))
        train_accuracies.append(train_accuracy)
        train_losses.append(train_loss_total_avg)

        model.eval()
        val_loss_total = 0.0
        num_steps = 0
        correct = 0
        ### Validate
        for batch in val_dataloader:
            with torch.no_grad():
                X, y = batch[0].cuda(), batch[1].cuda()

                val_preds = model(X)
                loss = loss_fn(val_preds, y)
                val_loss_total += loss.item()
                val_preds = torch.max(val_preds, 1)[1]
                correct += (val_preds == y).float().sum()

            num_steps += 1

        val_loss_total_avg = val_loss_total / num_steps
        val_accuracy = float(correct / len(val_ds))
        val_accuracies.append(val_accuracy)
        val_losses.append(val_loss_total_avg)

        scheduler.step()
        print('\nTrain loss: {:.4f}'.format(train_loss_total_avg))
        print('Train accuracy: {:.4f}'.format(train_accuracy))

        print('Val loss: {:.4f}'.format(val_loss_total_avg))
        print('Val accuracy\n: {:.4f}'.format(val_accuracy))

        # Checkpoint model and optimizer every epoch.
        torch.save(model.state_dict(),
                   "weights/model_params_epoch" + str(epoch))
        torch.save(optimizer.state_dict(),
                   "weights/optim_params_epoch" + str(epoch))

    # Accuracy curves (train vs validation).
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.plot(range(1, len(train_accuracies) + 1), train_accuracies)
    plt.plot(range(1, len(val_accuracies) + 1), val_accuracies)
    plt.savefig("plots/accuracies.png")
    plt.close()

    # Loss curves (train vs validation).
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.plot(range(1, len(train_losses) + 1), train_losses)
    plt.plot(range(1, len(val_losses) + 1), val_losses)
    plt.savefig("plots/losses.png")
    plt.close()

    generate_midi(model, val_ds, vocab_set, output_filename="output.mid")