Example #1
    def train(self, rank, start_time, return_dict):

        rns = torch.randint(0, 2**32, torch.Size([10]))
        best_qual = -np.inf
        best_seed = None
        for i, rn in enumerate(rns):
            writer = SummaryWriter(
                logdir=os.path.join(self.save_dir, 'logs',
                                    str(i) + '_' + str(rn.item())))
            self.global_count.reset()
            self.global_writer_loss_count.reset()
            self.global_writer_quality_count.reset()
            self.global_win_event_count.reset()
            self.action_stats_count.reset()

            set_seed_everywhere(rn.item())
            qual = self.train_step(rank, start_time, return_dict, writer)
            if qual > best_qual:  # higher quality is better; best_qual starts at -inf
                best_qual = qual
                best_seed = rn.item()  # store the scalar seed, not the whole tensor

        res = 'best seed is: ' + str(best_seed) + " with a qual of: " + str(
            best_qual)
        print(res)

        with open(os.path.join(self.save_dir, 'result.txt'), "w") as info:
            info.write(res)
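
Note: every example on this page calls `set_seed_everywhere` from `utils.general`, whose body is not shown here. Below is a minimal sketch of what such a helper typically looks like, assuming it only needs to seed Python's `random`, NumPy, and PyTorch (CPU and CUDA); the project's actual implementation may differ.

import random

import numpy as np
import torch


def set_seed_everywhere(seed):
    # Seed every common RNG source so a run is reproducible from a single integer.
    random.seed(seed)
    np.random.seed(seed)  # NumPy expects a seed in [0, 2**32 - 1]
    torch.manual_seed(seed)  # CPU generator
    torch.cuda.manual_seed_all(seed)  # all visible GPUs; no-op without CUDA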
Example #2
File: sac.py Project: paulhfu/RLForSeg
    def train_and_explore(self, rn):
        self.global_count.reset()

        set_seed_everywhere(rn)
        wandb.config.random_seed = rn
        if self.cfg.verbose:
            print('###### start training ######')
            print('Running on device: ', self.device)
            print('found ', self.train_dset.length, " training data patches")
            print('found ', self.val_dset.length, "validation data patches")
            print('training with seed: ' + str(rn))
        explorers = []
        for i in range(self.cfg.n_explorers):
            explorers.append(threading.Thread(target=self.explore))
        for explorer in explorers:
            explorer.start()

        self.memory.is_full_event.wait()
        trainer = threading.Thread(target=self.train_until_finished)
        trainer.start()

        trainer.join()
        self.global_count.set(self.cfg.T_max + self.cfg.mem_size + 4)
        for explorer in explorers:
            explorer.join()
        self.memory.clear()
        del self.memory
        # torch.save(self.model.state_dict(), os.path.join(wandb.run.dir, "last_checkpoint_agent.pth"))
        if self.cfg.verbose:
            print('\n\n###### training finished ######')
        return
Example #3
    def validate_seeds(self):
        rns = torch.randint(0, 2**32, torch.Size([10]))
        best_qual = np.inf
        best_seed = -1
        with open(os.path.join(self.save_dir, 'result.txt'), "w") as info:
            for i, rn in enumerate(rns):
                set_seed_everywhere(rn.item())
                abs_diffs, rel_diffs, mean_size, mean_n_larger_thresh = self.validate()
                qual = sum(rel_diffs) / len(rel_diffs)
                if qual < best_qual:
                    best_qual = qual
                    best_seed = rn.item()

                res = '\nseed is: ' + str(rn.item()) + " with abs_diffs of: " \
                      + str(sum(abs_diffs)) + ' :: and rel_diffs of: ' + str(qual) + " Num of diffs larger than 0.5: " \
                      + str(mean_n_larger_thresh) + "/" + str(mean_size)
                print(res)
                info.write(res)

        return best_seed, best_qual
Example #4
    def setup(self, rank, world_size):
        # BLAS setup
        os.environ['OMP_NUM_THREADS'] = '1'
        os.environ['MKL_NUM_THREADS'] = '1'

        # os.environ["CUDA_VISIBLE_DEVICES"] = "6"
        assert torch.cuda.device_count() == 1
        torch.set_default_tensor_type('torch.FloatTensor')
        # Detect if we have a GPU available
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)

        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = self.args.master_port
        # os.environ['GLOO_SOCKET_IFNAME'] = 'eno1'

        # initialize the process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Explicitly set the seed to make sure that models created in two processes
        # start from the same random weights and biases.
        seed = torch.randint(0, 2**32, torch.Size([5])).median()
        set_seed_everywhere(seed.item())
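
The comment in the snippet above relies on every process ending up with the same seed, yet each process draws its own random seed, so the values may differ across ranks. A hedged sketch (not part of the original project) of one way to guarantee an identical seed on all ranks is to draw it on rank 0 and broadcast it after `dist.init_process_group`:

import torch
import torch.distributed as dist


def shared_seed(rank):
    # Hypothetical helper: rank 0 draws a seed, every other rank receives that value.
    seed = torch.randint(0, 2**32, (1,)) if rank == 0 else torch.zeros(1, dtype=torch.int64)
    dist.broadcast(seed, src=0)  # after this call, all ranks hold rank 0's tensor
    return seed.item()  # e.g. set_seed_everywhere(shared_seed(rank))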
Example #5
    def train(self, rank, start_time, return_dict, rn):

        self.log_dir = os.path.join(self.save_dir, 'logs', '_' + str(rn))
        writer = None
        if rank == 0:
            writer = SummaryWriter(logdir=self.log_dir)
            writer.add_text("config", self.cfg.pretty(), 0)
            copyfile(os.path.join(self.save_dir, 'runtime_cfg.yaml'),
                     os.path.join(self.log_dir, 'runtime_cfg.yaml'))

            self.global_count.reset()
            self.global_writer_loss_count.reset()
            self.global_writer_quality_count.reset()
            self.action_stats_count.reset()
            self.global_writer_count.reset()

        set_seed_everywhere(rn)
        if rank == 0:
            print('training with seed: ' + str(rn))
        score = self.train_step(rank, writer)
        if rank == 0:
            return_dict['score'] = score
            del self.memory
        return
Example #6
from affogato.affinities import compute_affinities
from affogato.segmentation.mws import get_valid_edges
from skimage import draw
from skimage.filters import gaussian
import elf
import nifty

from data.mtx_wtsd import get_sp_graph
from mutex_watershed import compute_mws_segmentation_cstm
from utils.affinities import get_naive_affinities, get_edge_features_1d
from utils.general import calculate_gt_edge_costs, set_seed_everywhere
from data.spg_dset import SpgDset
import matplotlib.pyplot as plt
from matplotlib import cm

set_seed_everywhere(10)


def get_pix_data(length=50000, shape=(128, 128), radius=72):
    dim = (256, 256)
    edge_offsets = [
        [0, -1],
        [-1, 0],
        # direct 3d nhood for attractive edges
        # [0, -1], [-1, 0]]
        [-3, 0],
        [0, -3],
        [-6, 0],
        [0, -6]
    ]
    sep_chnl = 2
Example #7
import pandas as pd
from tqdm import notebook
import importlib
import pprint
import nltk
import datetime
import os
from argparse import Namespace
import re
from collections import Counter

import utils.general as general_utils
import utils.trac2020 as trac_utils
import utils.transformer.data as transformer_data_utils
import utils.transformer.general as transformer_general_utils
general_utils.set_seed_everywhere()  # set the seed for reproducibility

import logging
logging.basicConfig(level=logging.INFO)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Import RAdam and Lookahead
from radam.radam import RAdam
from lookahead.optimizer import Lookahead

from transformers import XLMRobertaTokenizer, XLMRobertaModel