    def evaluate(i_epoch):
        """ Evaluates the model on the test set """
        model.eval()

        with torch.no_grad():

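            # evaluate one full cloud per batch; graph_collate assembles the graph fields unpacked in the loop below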
            loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, collate_fn=graph_collate, num_workers=args.nworkers)

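            # wrap the loader in a progress bar unless debug-level logging is enabled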
            if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
                loader = tqdm(loader, ncols=100)

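            # running averages over the test set: loss, predicted cluster count,
            # boundary recall/precision, and a per-class confusion matrix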
            loss_meter = tnt.meter.AverageValueMeter()
            n_clusters_meter = tnt.meter.AverageValueMeter()
            BR_meter = tnt.meter.AverageValueMeter()
            BP_meter = tnt.meter.AverageValueMeter()
            CM_classes = metrics.ConfusionMatrix(dbinfo['classes'])

            # iterate over dataset in batches
            for bidx, (fname, edg_source, edg_target, is_transition, labels, objects, clouds_data, xyz) in enumerate(loader):

                if args.cuda:
                    is_transition = is_transition.to('cuda', non_blocking=True)
                    # labels = torch.from_numpy(labels).cuda()
                    objects = objects.to('cuda', non_blocking=True)
                    clouds, clouds_global, nei = clouds_data
                    clouds_data = (clouds.to('cuda', non_blocking=True), clouds_global.to('cuda', non_blocking=True), nei)

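                # compute per-point embeddings for the whole cloud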
                embeddings = ptnCloudEmbedder.run_batch(model, *clouds_data, xyz)

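                # embedding distance along every edge of the graph (metric set by args.dist_type)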
                diff = compute_dist(embeddings, edg_source, edg_target, args.dist_type)

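                # the loss and metrics only make sense when the cloud has more than one labelled edge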
                if len(is_transition) > 1:
                    weights_loss, pred_components, pred_in_component = compute_weight_loss(args, embeddings, objects, edg_source, edg_target,
                                                                                           is_transition, diff, True, xyz)
                    loss1, loss2 = compute_loss(args, diff, is_transition, weights_loss)
                    loss = (loss1 + loss2) / weights_loss.shape[0]
                    pred_transition = pred_in_component[edg_source] != pred_in_component[edg_target]
                    per_pred = perfect_prediction(pred_components, labels)
                    CM_classes.count_predicted_batch(labels[:, 1:], per_pred)
                    loss_meter.add(loss.item())  # /weights_loss.sum().item())
                    is_transition = is_transition.cpu().numpy()
                    n_clusters_meter.add(len(pred_components))
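                    # weight BR/BP by the number of positive edges so the meters average per edge;
                    # relax_edge_binary appears to allow a slack of args.BR_tolerance edges around boundaries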
                    BR_meter.add((is_transition.sum()) * compute_boundary_recall(is_transition, relax_edge_binary(pred_transition, edg_source,
                                                                                                                  edg_target, xyz.shape[0],
                                                                                                                  args.BR_tolerance)),
                                 n=is_transition.sum())
                    BP_meter.add((pred_transition.sum()) * compute_boundary_precision(relax_edge_binary(is_transition, edg_source,
                                                                                                        edg_target, xyz.shape[0],
                                                                                                        args.BR_tolerance), pred_transition),
                                 n=pred_transition.sum())
        CM = CM_classes.confusion_matrix
        return loss_meter.value()[0], n_clusters_meter.value()[0], 100 * CM.trace() / CM.sum(), BR_meter.value()[0], BP_meter.value()[0]
    def train(i_epoch):
        """ Trains for one epoch """
        model.train()
        loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=graph_collate, num_workers=args.nworkers, shuffle=True, drop_last=True)
        
        if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
            loader = tqdm(loader, ncols=100)
    
        loss_meter = tnt.meter.AverageValueMeter()
        n_clusters_meter = tnt.meter.AverageValueMeter()

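        # t0 alternates between timing the data loader and timing the optimization step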
        t0 = time.time()
    
        for bidx, (fname, edg_source, edg_target, is_transition, labels, objects, clouds_data, xyz) in enumerate(loader):
            
            if args.cuda:
                is_transition = is_transition.to('cuda', non_blocking=True)
                # labels = torch.from_numpy(labels).cuda()
                objects = objects.to('cuda', non_blocking=True)
                clouds, clouds_global, nei = clouds_data
                clouds_data = (clouds.to('cuda', non_blocking=True), clouds_global.to('cuda', non_blocking=True), nei)

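            # milliseconds spent waiting on the data loader for this batch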
            t_loader = 1000 * (time.time() - t0)
            optimizer.zero_grad()
            t0 = time.time()

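            # forward pass: per-point embeddings, then per-edge embedding distances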
            embeddings = ptnCloudEmbedder.run_batch(model, *clouds_data, xyz)
            
            diff = compute_dist(embeddings, edg_source, edg_target, args.dist_type)
            
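            # edge weights for the loss, plus the predicted partition (components and per-point component index)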
            weights_loss, pred_comp, in_comp = compute_weight_loss(args, embeddings, objects, edg_source, edg_target, is_transition, diff, True, xyz)
            
            loss1, loss2 = compute_loss(args, diff, is_transition, weights_loss)
            
            factor = 1000  # scale the loss for better use of float precision

            loss = (loss1 + loss2) / weights_loss.shape[0] * factor
            
            loss.backward()
            
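            # clamp raw gradients element-wise; the threshold is scaled by the same factor as the loss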
            if args.grad_clip > 0:
                for p in model.parameters():
                    if p.grad is not None:  # parameters without gradients are skipped
                        p.grad.data.clamp_(-args.grad_clip * factor, args.grad_clip * factor)
                    
            optimizer.step()

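            # milliseconds spent in the forward/backward pass and optimizer step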
            t_trainer = 1000 * (time.time() - t0)
            loss_meter.add(loss.item() / factor)  # /weights_loss.mean().item())
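            # average component size (points per predicted cluster); note that evaluate() logs the raw cluster count instead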
            n_clusters_meter.add(embeddings.shape[0] / len(pred_comp))
            
            logging.debug('Batch loss %f, Loader time %f ms, Trainer time %f ms.', loss.item() / factor, t_loader, t_trainer)
            t0 = time.time()
            
        return loss_meter.value()[0], n_clusters_meter.value()[0]
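
    # Sketch of how these two functions are typically driven (hypothetical loop;
    # the actual epoch loop lives in the enclosing scope of this script):
    #   for epoch in range(args.epochs):
    #       train_loss, avg_cluster_size = train(epoch)
    #       test_loss, n_clusters, accuracy, BR, BP = evaluate(epoch)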