Example #1
0
def train_ccn(net, data, task, criterion, optimizer, cuda, mean, std):
    """Run one training epoch of ``net`` over ``data``, one sample per step.

    Args:
        net: model, called as ``net(X, A)``.
        data: indexable collection; every entry unpacks into exactly seven
            items, of which only the first three (X, A, targets) are used.
        task: index into ``targets`` selecting the target to fit.
        criterion: loss applied to ``(output, y)``.
        optimizer: optimizer zeroed and stepped once per sample.
        cuda: when truthy, X, A and y are moved to the GPU.
        mean, std: target statistics. ``mean == 0`` marks generated data:
            the target is cast to ``torch.LongTensor`` and the output is
            reshaped to ``(1, -1)``. Otherwise the target is standardised
            as ``(y - mean) / (std + 1e-8)``.

    Returns:
        ``(losses.val, error.val)`` of two ``utils.RunningAverage``
        accumulators. The error accumulator is only fed when ``mean != 0``.
    """
    eps = 10**-8
    generated = mean == 0  # generated data
    loss_meter = utils.RunningAverage()
    err_meter = utils.RunningAverage()

    for idx in range(len(data)):
        optimizer.zero_grad()

        feats, adj, targets, _, _, _, _ = data[idx]
        # Add the identity to A (self-loops, assuming A is an adjacency matrix).
        adj = adj + torch.eye(adj.shape[0])
        y = targets[task].view(1)

        if generated:
            y = y.type(torch.LongTensor)
        else:
            y = (y - mean) / (std + eps)

        feats.requires_grad = True
        adj.requires_grad = True

        if cuda:
            feats, adj, y = feats.cuda(), adj.cuda(), y.cuda()

        output = net(feats, adj)

        if generated:
            output = output.view(1, -1)
        else:
            err_meter.update(utils.evaluation(output, y).item())

        loss = criterion(output, y)
        loss_meter.update(loss.item())

        # Backpropagate and take one optimizer step for this sample.
        loss.backward()
        optimizer.step()

    return loss_meter.val, err_meter.val
Example #2
0
def main_ccn():
    """Train and/or evaluate a CCN model configured from the command line.

    Parses the module-level ``parser`` into the global ``args``. Then, as
    requested by ``args.train`` / ``args.val`` / ``args.test``, it:

    * loads the pickled train / validation / test sets;
    * builds a ``model_ccn.CCN_1D`` or ``CCN_2D``, or loads one with
      ``torch.load(args.model_path)``;
    * trains it for ``args.max_epoch`` epochs with ``train_ccn.train_ccn``,
      multiplying the learning rate by ``args.lrdamping`` every
      ``args.epoch_step`` epochs;
    * evaluates it with ``test_ccn.test_ccn`` on the validation / test sets.

    Per-task target statistics (mean, std, accuracy) are read from a fixed
    pickle path and passed to the train/test routines. Progress is reported
    through ``log`` and a ``logs.Logger``.

    Returns:
        ``(test_error, ratio_test)`` when ``args.test`` is set, else ``None``.

    Raises:
        ValueError: if a stored model is required (``args.train`` is false)
            but ``args.model_path`` is not set.
    """
    global args
    args = parser.parse_args()

    if args.log_path is None:
        log_path = ('log/ccn/k_' + str(args.order) + '_ep_' + str(args.max_epoch) + '_st_' + str(args.epoch_step)
                    + '_op_' + str(args.optim) + '_lr_' + str(args.lr) + '_da_' + str(args.lrdamping)
                    + '_L_' + str(args.layers) + '_h_' + str(args.hidden_size) + '_ta_'
                    + str(args.task) + '_' + str(time.time())[-3:] + '.pickle'
        )
        args.log_path = log_path
    # Read from args so a user-supplied --log_path works too (the local
    # ``log_path`` only exists when the path is generated above).
    log.info("Log path : " + args.log_path)

    # logger
    logger = logs.Logger(args.log_path)
    logger.write_settings(args)

    # Check if CUDA is enabled; otherwise force CPU mode.
    if args.cuda == True and torch.cuda.is_available():
        log.info('Working on GPU')
    else:
        log.info('Working on CPU')
        args.cuda = False

    # Load training, validation and test datasets.
    if args.train == True:
        with open(args.train_path, 'rb') as f:
            train_set = pickle.load(f)
        Ntrain = len(train_set)
        log.info("Number of training instances : " + str(Ntrain))
        logger.add_info('Training set size : ' + str(Ntrain))
    if args.val == True:
        with open(args.valid_path, 'rb') as f:
            valid_set = pickle.load(f)
        Nvalid = len(valid_set)
        log.info("Number of validation instances : " + str(Nvalid))
        logger.add_info('Validation set size : ' + str(Nvalid))
    if args.test == True:
        with open(args.test_path, 'rb') as f:
            test_set = pickle.load(f)
        Ntest = len(test_set)
        log.info("Number of test instances : " + str(Ntest))
        logger.add_info('Test set size : ' + str(Ntest))

    # The input width is read from the first training sample, so it is only
    # available when training. It is only needed to build a new model below,
    # and that path also requires training.
    if args.train == True:
        dim_input = train_set[0][0].size()[1]
        logger.add_info('Number of features of the inputs : ' + str(dim_input))

    # Create or load the model.
    if args.train == False or args.model_path is not None:
        if args.model_path is None:
            raise ValueError('args.model_path must be set when not training')
        ccn = torch.load(args.model_path)
        log.info('Network loaded')
    else:
        if args.order == 1:
            ccn = model_ccn.CCN_1D(dim_input, 1, args.hidden_size, args.layers, args.cuda)
            logger.add_model('first order CCN')
            log.info('First-order CCN created')
        elif args.order == 2:
            ccn = model_ccn.CCN_2D(dim_input, 1, args.hidden_size, args.layers, args.cuda)
            logger.add_model('second order CCN')
            log.info('Second-order CCN created')
        else:
            log.info('Order not implemented yet, second-order CCN will be created')
            ccn = model_ccn.CCN_2D(dim_input, 1, args.hidden_size, args.layers, args.cuda)

    # Target statistics, indexed by task.
    stats_path = '/misc/vlgscratch4/BrunaGroup/sulem/chem/data/tensors/target_stat.pickle'
    with open(stats_path, 'rb') as f:
        M, S, A = pickle.load(f)
    mean = M[args.task].item()
    std = S[args.task].item()
    accuracy = A[args.task].item()

    criterion = nn.MSELoss()

    # Move the model before constructing the optimizer, as the torch.optim
    # documentation recommends.
    if args.cuda == True:
        ccn = ccn.cuda()
        criterion = criterion.cuda()

    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(ccn.parameters(), lr=args.lr,
                                    momentum=args.momentum)
    elif args.optim == 'adamax':
        optimizer = torch.optim.Adamax(ccn.parameters(), lr=args.lr)
    else:
        optimizer = torch.optim.Adam(ccn.parameters(), lr=args.lr)

    # Training
    if args.train == True:
        ccn.train()

        log.info('Training the CCN...')
        logger.add_res('Training phase')

        run_loss = utils.RunningAverage()
        run_error = utils.RunningAverage()

        for epoch in range(args.max_epoch):
            t0 = time.time()

            # Step-wise damping, applied at the start of the epoch.
            if epoch != 0 and epoch % args.epoch_step == 0:
                args.lr = args.lr * args.lrdamping
                for param_group in optimizer.param_groups:
                    param_group['lr'] = args.lr

            loss, error = train_ccn.train_ccn(ccn, train_set, args.task, criterion,
                                              optimizer, args.cuda, mean, std)

            dur = int(time.time() - t0)

            run_loss.update(loss)
            run_error.update(error)

            logger.add_epoch_info(epoch + 1, run_loss.val, run_error.val, dur)
            log.info('Epoch {} : Avg Error {:.3f}; Average Loss {:.3f} Time : {}'
                     .format(epoch + 1, run_error.val, run_loss.val, dur))

        training_time = sum(logger.time_epoch)
        ratio = run_error.val / accuracy

        # NOTE(review): ``ratio`` is only logged, not passed to the logger;
        # confirm the expected add_train_info argument order in logs.Logger.
        logger.add_train_info(run_loss.val, run_error.val, training_time, run_error.val)
        log.info('Training finished : Duration {} secs, Avg Loss {:.3f}, Mean Average Error {:.3f}, Error ratio {:.3f}'
                 .format(training_time, run_loss.val, run_error.val, ratio))

        logger.save_model(ccn)

    # Validating
    if args.val == True:
        log.info('Evaluating on the validation set...')
        logger.add_res('Validation phase')
        val_loss, val_error, dur = test_ccn.test_ccn(ccn, valid_set, args.task,
                                                     criterion, args.cuda,
                                                     mean, std, logger)
        ratio_val = val_error / accuracy
        log.info('Validation finished : Avg loss {:.3f}, Mean Average Error {:.3f}, Error ratio {:.3f}, Duration : {} seconds'
                 .format(val_loss, val_error, ratio_val, dur))
        logger.add_test_perf(val_loss, val_error, ratio_val)

        logger.plot_train_logs()
        logger.plot_test_logs()

    # Testing
    if args.test == True:
        log.info('Evaluating on the test set...')
        logger.add_res('Test phase')
        test_loss, test_error, dur = test_ccn.test_ccn(ccn, test_set, args.task,
                                                       criterion, args.cuda, mean, std, logger)
        ratio_test = test_error / accuracy
        log.info('Test finished : Avg loss {:.3f}, Mean Average Error {:.3f}, Error ratio {:.3f}, Duration : {} seconds'
                 .format(test_loss, test_error, ratio_test, dur))
        logger.add_test_perf(test_loss, test_error, ratio_test)

        logger.plot_train_logs()

        return test_error, ratio_test
Example #3
0
def main():
    """Train and/or evaluate a minibatch GNN (``model_mnb``) configured from the CLI.

    The function:

    * parses the module-level ``parser`` into the global ``args``;
    * loads one pickled dataset from ``args.data_path`` and splits it with
      ``loading.prepare_experiment_sets``;
    * builds a ``GNN_simple`` or ``GNN_lg`` model, or loads one from
      ``args.model_path``;
    * trains it with Adamax for ``args.max_epoch`` epochs using
      ``train_mnb.train_with_mnb``;
    * evaluates it on the validation / test splits with
      ``test_mnb.test_with_mnb``.

    Every ``args.epoch_step`` epochs the learning rate is multiplied by
    ``args.lrdamping``.

    Returns:
        ``(test_error, ratio_test)`` when ``args.test`` is set, else ``None``.
    """
    global args
    args = parser.parse_args()

    # Setting log path
    if args.log_path is None:
        log_path = ('log/qm9/lg_' + str(args.lg) + '_up_' + str(args.update) + '_gru_' + str(args.gru) + '_bs_'
                    + str(args.batch_size) + '_ep_' + str(args.max_epoch) + '_st_' + str(args.epoch_step)
                    + '_op_' + str(args.optim) + '_lr_' + str(args.lr) + '_da_' + str(args.lrdamping)
                    + '_L_' + str(args.layers) + '_h_' + str(args.nfeatures) + '_ta_' + str(args.task)
                    + '_' + str(time.time())[-3:] + '.pickle'
        )
        args.log_path = log_path
    # Read from args so a user-supplied --log_path works too.
    log.info("Log path : " + args.log_path)

    # Initializing logger
    logger = logs.Logger(args.log_path)
    logger.write_settings(args)

    # Check if CUDA is enabled; otherwise force CPU mode.
    if args.cuda == True and torch.cuda.is_available():
        log.info('Working on GPU')
    else:
        log.info('Working on CPU')
        args.cuda = False

    # Loading population statistics for the task
    stats_path = '/misc/vlgscratch4/BrunaGroup/sulem/chem/data/tensors/target_stat.pickle'
    with open(stats_path, 'rb') as f:
        M, S, A = pickle.load(f)
    mean = M[args.task].item()
    std = S[args.task].item()
    accuracy = A[args.task].item()

    # Loading experiment sets
    logging.info("Loading data...")
    with open(args.data_path, 'rb') as f:
        data_set = pickle.load(f)

    train_set, valid_set, test_set = loading.prepare_experiment_sets(data_set,
                                                                     args.shuffle)
    if args.train == True:
        Ntrain = len(train_set)
        log.info("Number of training instances : " + str(Ntrain))
        logger.add_info('Training set size : ' + str(Ntrain))

    if args.val == True:
        Nvalid = len(valid_set)
        log.info("Number of validation instances : " + str(Nvalid))
        logger.add_info('Validation set size : ' + str(Nvalid))

    if args.test == True:
        Ntest = len(test_set)
        log.info("Number of test instances : " + str(Ntest))
        logger.add_info('Test set size : ' + str(Ntest))

    # Creating or loading model
    if args.model_path is not None:
        gnn = torch.load(args.model_path)
        log.info('Network loaded')
    else:
        if args.lg == False:
            gnn = model_mnb.GNN_simple(args.task, args.nfeatures, args.layers,
                                       args.dim_input, 1, args.J, args.gru)
            logger.add_model('gnn simple')
        else:
            # NOTE(review): ``args.J`` and the output size ``1`` are passed in
            # the opposite order to the GNN_simple call above; confirm
            # against the GNN_lg signature.
            gnn = model_mnb.GNN_lg(args.task, args.nfeatures, args.layers,
                                   args.dim_input, args.J, 1, args.update)
            logger.add_model('gnn with LG')
        log.info('Network created')

    criterion = nn.MSELoss()

    if args.cuda == True:
        gnn = gnn.cuda()
        criterion = criterion.cuda()

    # Training
    if args.train == True:
        gnn.train()

        log.info('Training the GNN')
        logger.add_res('Training phase')

        run_loss = utils.RunningAverage()
        run_error = utils.RunningAverage()

        # Build the optimizer once so Adamax keeps its internal state across
        # epochs. Learning-rate damping is pushed into its param_groups
        # instead of re-creating it.
        optimizer = torch.optim.Adamax(gnn.parameters(), lr=args.lr)

        for epoch in range(args.max_epoch):
            t0 = time.time()

            loss, error = train_mnb.train_with_mnb(gnn, train_set, args.task, criterion,
                                                   optimizer, args.cuda, args.batch_size, mean, std)

            dur = int(time.time() - t0)

            run_loss.update(loss)
            run_error.update(error)

            # Damp after this epoch; the new rate applies from the next one.
            if epoch != 0 and epoch % args.epoch_step == 0:
                args.lr = args.lr * args.lrdamping
                for param_group in optimizer.param_groups:
                    param_group['lr'] = args.lr

            log.info('Epoch {} : Train loss {:.3f} error {:.3f} Time : {}'
                     .format(epoch + 1, run_loss.val, run_error.val, dur))

        # NOTE(review): nothing in this function records per-epoch times on
        # the logger; confirm logger.time_epoch is populated elsewhere.
        training_time = sum(logger.time_epoch) // 60
        ratio = run_error.val / accuracy

        logger.add_train_info(run_loss.val, run_error.val, ratio, training_time)
        log.info('Training finished : Duration {} minutes, Loss {:.3f}, MAE {:.3f}, Error ratio {:.3f}'
                 .format(training_time, run_loss.val, run_error.val, ratio))

        logger.save_model(gnn)

    # Validating
    if args.val == True:
        log.info('Evaluating on the validation set...')
        logger.add_res('Validation phase')
        val_loss, val_error = test_mnb.test_with_mnb(gnn, valid_set, args.task,
                                                     criterion, args.cuda, args.batch_size,
                                                     mean, std, logger)
        ratio_val = val_error / accuracy
        log.info('Validation finished : Avg loss {:.3f}, Mean Average Error {:.3f}, Error ratio {:.3f}'
                 .format(val_loss, val_error, ratio_val))
        logger.add_test_perf(val_loss, val_error, ratio_val)

        logger.plot_train_logs()
        logger.plot_test_logs()

    # Testing
    if args.test == True:
        log.info('Evaluating on the test set...')
        logger.add_res('Test phase')
        test_loss, test_error = test_mnb.test_with_mnb(gnn, test_set, args.task, criterion,
                                                       args.cuda, args.batch_size,
                                                       mean, std, logger)
        ratio_test = test_error / accuracy
        log.info('Test finished : Avg loss {:.3f}, Mean Average Error {:.3f}, Error ratio {:.3f}'
                 .format(test_loss, test_error, ratio_test))
        logger.add_test_perf(test_loss, test_error, ratio_test)

        logger.plot_train_logs()

        return test_error, ratio_test
Example #4
0
def main_ccn():

    global args
    args = parser.parse_args()

    # setting logger path
    if args.log_path == None:
        log_path = ('log/qm9/ccn_k_' + str(args.order) + '_ep_' +
                    str(args.max_epoch) + '_st_' + str(args.epoch_step) +
                    '_op_' + str(args.optim) + '_lr_' + str(args.lr) + '_da_' +
                    str(args.lrdamping) + '_L_' + str(args.layers) + '_h_' +
                    str(args.hidden_size) + '_ta_' + str(args.task) + '_' +
                    str(time.time())[-3:] + '.pickle')
        args.log_path = log_path
    log.info("Log path : " + log_path)

    # initializing logger
    logger = logs.Logger(args.log_path)
    logger.write_settings(args)

    # Check if CUDA is enabled
    if args.cuda == True and torch.cuda.is_available():
        log.info('Working on GPU')
        #torch.cuda.manual_seed(0)

    else:
        log.info('Working on CPU')
        args.cuda = False
        #torch.manual_seed(0)

    # Loading experiment sets
    logging.info("Loading data...")
    with open(args.data_path, 'rb') as file:
        data_set = pickle.load(file)
    train_set, valid_set, test_set = loading.prepare_experiment_sets(
        data_set, args.shuffle)
    if args.train == True:
        Ntrain = len(train_set)
        log.info("Number of training instances : " + str(Ntrain))
        logger.add_info('Training set size : ' + str(Ntrain))

    if args.val == True:
        Nvalid = len(valid_set)
        log.info("Number of validation instances : " + str(Nvalid))
        logger.add_info('Validation set size : ' + str(Nvalid))

    if args.test == True:
        Ntest = len(test_set)
        log.info("Number of test instances : " + str(Ntest))
        logger.add_info('Test set size : ' + str(Ntest))

    # Creates or loads model
    if args.train == False or args.model_path != None:
        ccn = torch.load(args.model_path)
        log.info('Network loaded')
    else:
        if args.order == 1:
            ccn = model_ccn.CCN_1D(args.dim_input, 1, args.hidden_size,
                                   args.layers, args.cuda)
            logger.add_model('first order CCN')
            log.info('First-order CCN created')

        elif args.order == 2:
            ccn = model_ccn.CCN_2D(args.dim_input, 1, args.hidden_size,
                                   args.layers, args.cuda)
            logger.add_model('second order CCN')
            log.info('Second-order CCN created')
        else:
            log.info(
                'Order not implemented yet, second-order CNN will be created')
            ccn = model_ccn.CCN_2D(args.dim_input, 1, args.hidden_size,
                                   args.layers, args.cuda)

    # Target stats
    stats_path = '/misc/vlgscratch4/BrunaGroup/sulem/chem/data/tensors/target_stat.pickle'
    with open(stats_path, 'rb') as file:
        M, S, A = pickle.load(file)
    mean = M[args.task].item()
    std = S[args.task].item()
    accuracy = A[args.task].item()

    # Criterion and optimizer
    criterion = nn.MSELoss()

    if args.cuda == True:
        ccn = ccn.cuda()
        criterion = criterion.cuda()

    # Training

    if args.train == True:
        ccn.train()

        log.info('Training the CCN...')
        logger.add_res('Training phase')

        run_loss = utils.RunningAverage()
        run_error = utils.RunningAverage()

        for epoch in range(args.max_epoch):

            t0 = time.time()

            optimizer = torch.optim.Adamax(ccn.parameters(), lr=args.lr)

            loss, error = train_ccn.train_ccn(ccn, train_set, args.task,
                                              criterion, optimizer, args.cuda,
                                              mean, std)

            v_loss, v_error, _ = test_ccn.test_ccn(ccn, valid_set, args.task,
                                                   criterion, args.cuda, mean,
                                                   std, logger)

            t_loss, t_error, _ = test_ccn.test_ccn(ccn, test_set, args.task,
                                                   criterion, args.cuda, mean,
                                                   std, logger)

            dur = int(time.time() - t0)

            run_loss.update(loss)
            run_error.update(error)

            if epoch != 0 and epoch % args.epoch_step == 0:
                args.lr = args.lr * args.lrdamping

            logger.add_epoch_logs(epoch + 1, run_loss.val, run_error.val,
                                  v_loss, v_error, t_loss, t_error, dur)
            log.info(
                'Epoch {} : Train loss {:.3f} error {:.3f} Time : {}'.format(
                    epoch + 1, run_error.val, run_loss.val, dur))
            log.info('Validation loss {:.3f} error {:.3f}'.format(
                v_loss, v_error))
            log.info('Test loss {:.3f} error {:.3f}'.format(t_loss, t_error))

        training_time = sum(logger.time_epoch)
        ratio = run_error.val / accuracy

        logger.add_train_info(run_loss.val, run_error.val, ratio,
                              training_time)
        logger.add_valid_perf(v_loss, v_error, v_error / accuracy)
        logger.add_test_perf(t_loss, t_error, t_error / accuracy)
        log.info(
            'Training finished : Duration {} minutes, Loss {:.3f}, MAE {:.3f}, Error ratio {:.3f}'
            .format(training_time, run_loss.val, run_error.val, ratio))
        log.info('Validation loss {:.3f} error {:.3f}'.format(v_loss, v_error))
        log.info('Test loss {:.3f} error {:.3f}'.format(t_loss, t_error))

        logger.plot_loss()
        logger.plot_error()

        logger.save_model(ccn)
    """    
Example #5
0
def train_with_mnb(model, data, task, criterion, optimizer, cuda, bs, mean, std):
    """Train ``model`` for one epoch over ``data`` using minibatches.

    Args:
        model: network exposing ``dual`` and ``J`` attributes. It is called as
            ``model([X, W], N_batch, mask)`` when ``model.dual`` is false and
            as ``model([X, XL, W, WL, Pm, Pd], N_batch, mask, E_batch,
            mask_lg)`` otherwise.
        data: indexable collection of samples. Batches of indices come from
            ``batching.get_batches`` and are assembled by
            ``batching.prepare_batch``.
        task: target index forwarded to ``batching.prepare_batch``.
        criterion: loss applied to ``(output, T)``.
        optimizer: optimizer zeroed and stepped once per minibatch.
        cuda: when ``True``, every batch tensor is moved to the GPU.
        bs: minibatch size passed to ``batching.get_batches``.
        mean, std: target statistics. ``mean == 0`` marks generated data,
            whose targets become a flat ``torch.LongTensor`` of labels.
            Otherwise targets are normalised with ``utils.normalize_data``.

    Returns:
        ``(losses.val, error.val)`` of two ``utils.RunningAverage``
        accumulators. The error accumulator is only fed when ``mean != 0``.
    """
    dual = model.dual
    J = model.J

    losses = utils.RunningAverage()
    error = utils.RunningAverage()

    model.train()

    batch_idx = batching.get_batches(len(data), bs, data, False, False)

    for b in batch_idx:

        optimizer.zero_grad()

        batch = [data[j] for j in b]

        X, W, T, XL, WL, Pm, Pd, mask, mask_lg, N_batch, E_batch = batching.prepare_batch(batch, task, J)

        if mean == 0:  # generated data
            # Flatten to one label per sample. Unlike squeeze(), reshape(-1)
            # keeps the batch dimension when a batch holds a single sample, so
            # the target never becomes 0-d. Assumes one label per sample
            # (TODO confirm prepare_batch's T shape).
            T = T.reshape(-1).type(torch.LongTensor)
        else:
            T = utils.normalize_data(T, mean, std)

        X.requires_grad = True
        W.requires_grad = True

        if cuda == True:
            X, W, T, XL, WL, Pm, Pd, mask, mask_lg, N_batch, E_batch = (
                X.cuda(), W.cuda(), T.cuda(), XL.cuda(), WL.cuda(), Pm.cuda(),
                Pd.cuda(), mask.cuda(), mask_lg.cuda(), N_batch.cuda(),
                E_batch.cuda())

        if dual == False:
            # NOTE(review): XL, WL, Pm and Pd are marked as requiring grad here
            # but are only fed to the model in the ``else`` branch. Confirm
            # which branch this was meant for.
            XL.requires_grad = True
            WL.requires_grad = True
            Pm.requires_grad = True
            Pd.requires_grad = True
            output = model([X, W], N_batch, mask)
        else:
            output = model([X, XL, W, WL, Pm, Pd], N_batch, mask, E_batch, mask_lg)

        train_loss = criterion(output, T)
        losses.update(train_loss.item())

        if mean != 0:
            error.update(utils.evaluation(output, T).item())

        # compute gradient and do SGD step
        train_loss.backward()
        optimizer.step()

    return losses.val, error.val
Example #6
0
def main():
    """Train and/or evaluate a minibatch GNN on simulated data.

    The function:

    * parses the module-level ``parser`` into the global ``args``;
    * loads the pickled train / validation / test sets that ``args`` enables;
    * builds a ``model_mnb.GNN_simple`` or ``GNN_lg`` with two output units,
      or loads one from ``args.model_path``;
    * trains it with ``nn.CrossEntropyLoss`` via ``train_mnb.train_with_mnb``,
      passing ``mean=0, std=1``;
    * evaluates it with ``test_mnb.test_with_mnb``.

    Every ``args.epoch_step`` epochs the learning rate is multiplied by
    ``args.lrdamping``.

    Returns:
        ``test_loss`` when ``args.test`` is set, else ``None``.

    Raises:
        ValueError: if a stored model is required (``args.train`` is false)
            but ``args.model_path`` is not set.
    """

    global args
    args = parser.parse_args()
    if args.log_path is None:
        log_path = ('log/simul_data/lg_' + str(args.lg) + '_up_' +
                    str(args.update) + '_bs_' + str(args.batch_size) + '_ep_' +
                    str(args.max_epoch) + '_st_' + str(args.epoch_step) +
                    '_op_' + str(args.optim) + '_lr_' + str(args.lr) + '_da_' +
                    str(args.lrdamping) + '_L_' + str(args.layers) + '_h_' +
                    str(args.nfeatures) + '_ta_' + str(args.task) + '_' +
                    str(time.time())[-3:] + '.pickle')
        args.log_path = log_path
    # Read from args so a user-supplied --log_path works too.
    log.info("Log path : " + args.log_path)

    # logger
    logger = logs.Logger(args.log_path)
    logger.write_settings(args)

    # Check if CUDA is enabled; otherwise force CPU mode.
    if args.cuda == True and torch.cuda.is_available():
        log.info('Working on GPU')
    else:
        log.info('Working on CPU')
        args.cuda = False

    # Load training, validation and test datasets.
    if args.train == True:
        with open(args.train_path, 'rb') as f:
            train_set = pickle.load(f)
        Ntrain = len(train_set)
        log.info("Number of training instances : " + str(Ntrain))
        logger.add_info('Training set size : ' + str(Ntrain))
    if args.val == True:
        with open(args.valid_path, 'rb') as f:
            valid_set = pickle.load(f)
        Nvalid = len(valid_set)
        log.info("Number of validation instances : " + str(Nvalid))
        logger.add_info('Validation set size : ' + str(Nvalid))
    if args.test == True:
        with open(args.test_path, 'rb') as f:
            test_set = pickle.load(f)
        Ntest = len(test_set)
        log.info("Number of test instances : " + str(Ntest))
        logger.add_info('Test set size : ' + str(Ntest))

    # The input width comes from the training data, which only exists when
    # training. It is only needed to build a new model, which also requires
    # training.
    if args.train == True:
        dim_input = train_set[0][0].size()[1]
        logger.add_info('Number of features of the inputs : ' + str(dim_input))

    # Creates or loads model
    if args.train == False or args.model_path is not None:
        if args.model_path is None:
            raise ValueError('args.model_path must be set when not training')
        gnn = torch.load(args.model_path)
        log.info('Network loaded')
    else:
        if args.lg == False:
            gnn = model_mnb.GNN_simple(args.task, args.nfeatures, args.layers,
                                       dim_input, 2, args.J)
            logger.add_model('gnn simple')
        else:
            gnn = model_mnb.GNN_lg(args.task, args.nfeatures, args.layers,
                                   dim_input, 2, args.J, args.update)
            logger.add_model('gnn with LG')
        log.info('Network created')

    criterion = nn.CrossEntropyLoss()

    # Move the model before constructing the optimizer, as the torch.optim
    # documentation recommends.
    if args.cuda == True:
        gnn = gnn.cuda()

    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(gnn.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    elif args.optim == 'adamax':
        optimizer = torch.optim.Adamax(gnn.parameters(), lr=args.lr)
    else:
        optimizer = torch.optim.Adam(gnn.parameters(), lr=args.lr)

    # Training
    if args.train == True:
        gnn.train()

        log.info('Training the GNN')
        logger.add_res('Training phase')

        run_loss = utils.RunningAverage()

        for epoch in range(args.max_epoch):

            t0 = time.time()

            # Step-wise damping, applied at the start of the epoch.
            if epoch != 0 and epoch % args.epoch_step == 0:
                args.lr = args.lr * args.lrdamping
                for param_group in optimizer.param_groups:
                    param_group['lr'] = args.lr

            # mean=0, std=1 selects the generated-data (label) path.
            loss, _ = train_mnb.train_with_mnb(gnn, train_set, args.task,
                                               criterion, optimizer, args.cuda,
                                               args.batch_size, 0, 1)

            dur = int(time.time() - t0)

            run_loss.update(loss)

            # No error metric for this task: the loss fills the error slot.
            logger.add_epoch_info(epoch + 1, run_loss.val, run_loss.val, dur)
            log.info('Epoch {} : Average Loss {:.3f} Time : {}'.format(
                epoch + 1, run_loss.val, dur))

        training_time = sum(logger.time_epoch)

        logger.add_train_info(run_loss.val, run_loss.val, training_time,
                              run_loss.val)
        log.info(
            'Training finished : Duration {} secs, Avg Loss {:.3f}'.format(
                training_time, run_loss.val))

        logger.save_model(gnn)

    # Validating
    if args.val == True:
        log.info('Evaluating on the validation set...')
        logger.add_res('Validation phase')
        val_loss, _ = test_mnb.test_with_mnb(gnn, valid_set, args.task,
                                             criterion, args.cuda,
                                             args.batch_size, 0, 1, logger)
        log.info('Validation finished : Avg loss {:.3f}'.format(val_loss))
        logger.add_test_perf(val_loss, val_loss, val_loss)

        logger.plot_train_logs()

    # Testing
    if args.test == True:
        log.info('Evaluating on the test set...')
        logger.add_res('Test phase')
        test_loss, _ = test_mnb.test_with_mnb(gnn, test_set, args.task,
                                              criterion, args.cuda,
                                              args.batch_size, 0, 1, logger)
        log.info('Test finished : Avg loss {:.3f}'.format(test_loss))
        logger.add_test_perf(test_loss, test_loss, test_loss)

        logger.plot_train_logs()

        return test_loss