Example #1
def model_network(param_dict):
    """
    This model network consists of a spike source and a neuron (IF_curr_alpha). 
    The spike rate of the source and the weight can be specified in the 
    param_dict. Returns the number of spikes fired during 1000 ms of simulation.
    
    Parameters:
    param_dict - dictionary with keys
                 rate - the rate of the spike source (spikes/second)
                 weight - weight of the connection source -> neuron
                 
    Returns:
    dictionary with keys:
        source_rate - the rate of the spike source
        weight - weight of the connection source -> neuron
        neuron_rate - spike rate of the neuron
    """
    #set up the network
    from retina import Retina
    retina = Retina(param_dict['N'])
    params = retina.params
    params.update(param_dict)  # updates what changed in the dictionary
    # simulate the experiment and get its data
    data = retina.run(params)  #,verbose=False)
    neuron_rate = data['out_ON_DATA'].mean_rate()
    print(neuron_rate)
    # return everything, including the input parameters
    return {
        'snr': param_dict['snr'],
        'kernelseed': param_dict['kernelseed'],
        'neuron_rate': neuron_rate
    }
Example #2
def model_network(param_dict):
    """
    This model network consists of a spike source and a neuron (IF_curr_alpha). 
    The spike rate of the source and the weight can be specified in the 
    param_dict. Returns the number of spikes fired during 1000 ms of simulation.
    
    Parameters:
    param_dict - dictionary with keys
                 rate - the rate of the spike source (spikes/second)
                 weight - weight of the connection source -> neuron
                 
    Returns:
    dictionary with keys:
        source_rate - the rate of the spike source
        weight - weight of the connection source -> neuron
        neuron_rate - spike rate of the neuron
    """ 
    #set up the network
    from retina import Retina
    retina = Retina(param_dict['N'])
    params = retina.params
    params.update(param_dict) # updates what changed in the dictionary
    # simulate the experiment and get its data
    data = retina.run(params)#,verbose=False)
    neuron_rate = data['out_ON_DATA'].mean_rate()
    print(neuron_rate)
    # return everything, including the input parameters
    return {'snr':param_dict['snr'], 
            'kernelseed':param_dict['kernelseed'], 
            'neuron_rate': neuron_rate}
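
A minimal call sketch for the function above (the parameter values are hypothetical; 'N', 'snr' and 'kernelseed' are the keys the function actually reads, and any other entries would override retina parameters):

# Hypothetical parameter values, for illustration only.
result = model_network({'N': 64, 'snr': 2.0, 'kernelseed': 12345})
print(result['neuron_rate'])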
Example #3
def test_view_buffer():
    r = Retina()
    code = """for i in [1, 2, 3, 4]:
    print "The count is", i
    print "Done counting"""

    buf = file_to_text_buffer(StringIO(code))
    s = r.view_buffer(buf, x=7, y=1)
    assert_equal(s, "** ^The count **^")
Example #5
    def __init__(self, weights, classes=['building']):
        self.net = Retina(classes).eval().cuda()
        chkpnt = torch.load(weights)
        self.net.load_state_dict(chkpnt['state_dict'])
        self.transform = transforms.Compose([
            transforms.Resize((300, 300)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
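
The transform above is the standard ImageNet-style preprocessing: resize to 300x300, convert to a tensor, then normalize with the ImageNet channel means and standard deviations. A self-contained sketch of what it produces (the blank stand-in image is purely illustrative):

from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = Image.new('RGB', (640, 480))  # stand-in; a real photo would come from Image.open()
x = transform(img)
print(x.shape)  # torch.Size([3, 300, 300]) -- ready for net(x.unsqueeze(0))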
Example #6
    def __init__(self,
                 text_buffer,
                 retina=None,
                 pos=(0, 0),
                 char_vis=(0.226, 0.404)):
        self.text_buffer = text_buffer
        self.retina = retina
        self.pos = pos
        self.char_vis = char_vis

        if self.retina is None:
            self.retina = Retina()
Example #7
def test_view_string():
    r = Retina()

    # Empty string (whitespace for all retina slots)
    s = r.view_string("")
    assert_equal(s, "".join([Retina.LOW_WHITESPACE] * len(r.slots)))

    # Letters
    s = r.view_string(" Hello World")
    assert_equal(s, " ***lo World     ")

    # Numbers
    s = r.view_string("12 This is a test")
    assert_equal(s, "## *his is a ****")
Example #9
    def __init__(self, weights, classes=['building'], cuda=True):
        chkpnt = torch.load(weights)
        self.config = chkpnt['args']
        self.net = Retina(self.config).eval()
        self.net.load_state_dict(chkpnt['state_dict'])
        self.transform = transforms.Compose([
            transforms.Resize((self.config.model_input_size, self.config.model_input_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.net = self.net.cuda()
        self.net.anchors.anchors = self.net.anchors.anchors.cuda()
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        self.cuda = cuda
Example #10
class Eye:
    TIME_PREP = 150
    TIME_MOTOR = 50
    TIME_SACCADE = 20
    TIME_PER_DEGREE = 2

    def __init__(self,
                 text_buffer,
                 retina=None,
                 pos=(0, 0),
                 char_vis=(0.226, 0.404)):
        self.text_buffer = text_buffer
        self.retina = retina
        self.pos = pos
        self.char_vis = char_vis

        if self.retina is None:
            self.retina = Retina()

    def move_to(self, new_pos):
        x, y = new_pos
        time = Eye.TIME_PREP + Eye.TIME_MOTOR + Eye.TIME_SACCADE

        # Distance in degrees of visual angle
        dist = math.sqrt(((self.pos[0] - x) * self.char_vis[0])**2 +
                         ((self.pos[1] - y) * self.char_vis[1])**2)

        time += (dist * Eye.TIME_PER_DEGREE)
        self.pos = (x, y)

        return time

    def view(self):
        x, y = self.pos
        return self.retina.view_buffer(self.text_buffer, x, y)
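
move_to charges a fixed preparation, motor and saccade overhead plus a per-degree travel cost, with the distance converted from characters to degrees of visual angle via char_vis. A self-contained sketch of that arithmetic using the constants above (a purely illustrative 10-character horizontal move):

import math

char_vis = (0.226, 0.404)                      # degrees of visual angle per character (x, y)
dist = math.sqrt(((0 - 10) * char_vis[0])**2 +
                 ((0 - 0) * char_vis[1])**2)   # 2.26 degrees
time = 150 + 50 + 20 + 2 * dist                # TIME_PREP + TIME_MOTOR + TIME_SACCADE + TIME_PER_DEGREE * dist
print(time)                                    # 224.52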
Example #11
    def __init__(self, content, retina=False):
        self.CAMERA_INITIAL_ANGLE_V = deg2rad(10.0)
        # TODO: allow other environments to be specified as well
        self.content = content
        self.env = Environment(self.content)
        self.egocentric_images = None
        self.allocentric_images = None
        self.retina = Retina() if retina else None
Example #12
    def __init__(self, text_buffer=None, retina=None, pos=(0, 0),
            char_vis=(0.226, 0.404)):
        self.busy = False
        self.text_buffer = text_buffer
        self.retina = retina
        self.pos = pos
        self.char_vis = char_vis

        if self.retina is None:
            self.retina = Retina()
Example #13
class MrChipsEye(ccm.Model):
    TIME_PREP = 0.150
    TIME_MOTOR = 0.050
    TIME_SACCADE = 0.020
    TIME_PER_DEGREE = 0.002

    def __init__(self,
                 text_buffer=None,
                 retina=None,
                 pos=(0, 0),
                 char_vis=(0.226, 0.404)):
        self.busy = False
        self.text_buffer = text_buffer
        self.retina = retina
        self.pos = pos
        self.char_vis = char_vis

        if self.retina is None:
            self.retina = Retina()

    def move_to(self, new_pos):
        if self.busy:
            return

        self.busy = True
        x, y = new_pos

        self.log._ = "Eye movement preparation"
        yield MrChipsEye.TIME_PREP

        self.log._ = "Eye movement motor preparation"
        yield MrChipsEye.TIME_MOTOR

        # Distance in degrees of visual angle
        dist = math.sqrt(((self.pos[0] - x) * self.char_vis[0])**2 +
                         ((self.pos[1] - y) * self.char_vis[1])**2)

        saccade_time = MrChipsEye.TIME_SACCADE + (dist *
                                                  MrChipsEye.TIME_PER_DEGREE)

        self.log._ = "Eye movement saccade"
        yield saccade_time
        self.log._ = "Finished eye movement"

        self.pos = (x, y)
        self.busy = False

    def view(self):
        assert self.text_buffer is not None
        x, y = self.pos
        return self.retina.view_buffer(self.text_buffer, x, y)
Example #14
class MrChipsEye(ccm.Model):
    TIME_PREP       = 0.150
    TIME_MOTOR      = 0.050
    TIME_SACCADE    = 0.020
    TIME_PER_DEGREE = 0.002

    def __init__(self, text_buffer=None, retina=None, pos=(0, 0),
            char_vis=(0.226, 0.404)):
        self.busy = False
        self.text_buffer = text_buffer
        self.retina = retina
        self.pos = pos
        self.char_vis = char_vis

        if self.retina is None:
            self.retina = Retina()

    def move_to(self, new_pos):
        if self.busy:
            return

        self.busy = True
        x, y = new_pos

        self.log._ = "Eye movement preparation"
        yield MrChipsEye.TIME_PREP

        self.log._ = "Eye movement motor preparation"
        yield MrChipsEye.TIME_MOTOR

        # Distance in degrees of visual angle
        dist = math.sqrt(((self.pos[0] - x) * self.char_vis[0])**2 +
                         ((self.pos[1] - y) * self.char_vis[1])**2)

        saccade_time = MrChipsEye.TIME_SACCADE + (dist * MrChipsEye.TIME_PER_DEGREE)

        self.log._ = "Eye movement saccade"
        yield saccade_time
        self.log._ = "Finished eye movement"

        self.pos = (x, y)
        self.busy = False

    def view(self):
        assert self.text_buffer is not None
        x, y = self.pos
        return self.retina.view_buffer(self.text_buffer, x, y)
Example #15
chip = nsetup.chips['mn256r1']
nsetup.mapper._init_fpga_mapper()
p = pyNCS.Population('', '')
p.populate_all(nsetup, 'mn256r1', 'excitatory')

# we use the first 128 neurons to receive input from the retina
inputpop = pyNCS.Population('','')
inputpop.populate_by_id(nsetup,'mn256r1', 'excitatory', np.linspace(0,255,256))  
#reset multiplexer
chip.configurator._set_multiplexer(0)
#init class synapses learning
#sl = SynapsesLearning(p, 'learning')
#matrix_learning = np.ones([256,256])

#init retina
ret = Retina(inputpop)
#program all excitatory synapses for the programmable syn
matrix_exc = np.ones([256,256])
nsetup.mapper._program_onchip_exc_inh(matrix_exc)
#set to zeros recurrent and learning synapses
matrix_off = np.zeros([256,256])
matrix_rec_off = np.zeros([256,512])
nsetup.mapper._program_onchip_programmable_connections(matrix_off)
nsetup.mapper._program_onchip_recurrent(matrix_rec_off)
matrix_weights = np.ones([256,256])
nsetup.mapper._program_onchip_weight_matrix_programmable(matrix_weights)

#ret.map_retina_to_mn256r1()
#ret.map_retina_to_mn256r1(inputpop) #this function map retina output pixels to mn256r1 programmable syn inputpop.synapses['programmable'].paddr[1::2]
#we first init the mappings
ret._init_fpga_mapper()
Example #16
                        default=300,
                        type=int,
                        help='Input dimensions for SSD')
    args = parser.parse_args()

    if 'VOC' in args.train_data:
        dataset = VOC(args.train_data, transform=Transform(args.ssd_size))
    else:
        dataset = SpaceNet(args.train_data, transform=Transform(args.ssd_size))

    args.checkpoint_dir = os.path.join(args.save_folder,
                                       'ssd_%s' % datetime.now().isoformat())
    args.means = (104, 117, 123)  # only support voc now
    args.num_classes = len(dataset.classes) + 1
    args.stepvalues = (20, 50, 70)
    args.start_iter = 0
    args.writer = SummaryWriter()

    os.makedirs(args.save_folder, exist_ok=True)

    default_type = 'torch.cuda.FloatTensor' if args.cuda else 'torch.FloatTensor'
    torch.set_default_tensor_type(default_type)

    net = Retina(dataset.classes, args.ssd_size)

    if args.cuda:
        net = net.cuda()

    load_checkpoint(net, args)
    train(net, dataset, args)
Example #17
        perceptron_pop.populate_by_id(nsetup, 'mn256r1', 'excitatory', perceptron_neu)

        net = Perceptrons(perceptron_pop,feature_pop) 
        net.matrix_learning_pot[:] = 0
        net.upload_config()
        #test
        #syn = feature_pop.synapses['programmable'][::16]
        #stim = syn.spiketrains_poisson(10)
        #nsetup.stimulate(stim,send_reset_event=False)

        #set up filters and connect retina
        inputpop = pyNCS.Population('','')
        inputpop.populate_by_id(nsetup,'mn256r1', 'excitatory', np.linspace(0,255,256))  
        #reset multiplexer
        chip.configurator._set_multiplexer(0)
        ret = Retina(inputpop)
        ret._init_fpga_mapper()
        pre_teach, post_teach, pre_address, post_address = ret.map_retina_to_mn256r1_randomproj()
        nsetup.chips['mn256r1'].load_parameters('biases/biases_wijlearning_ret_perceptrons_1.biases')
        
        #two different biases for teacher and inputs
        #matrix_w = np.zeros([256,256])
        #matrix_w[:,0:128]  = 2
        #matrix_w[:,128:256]  = 1
        #nsetup.mapper._program_onchip_programmable_connections(matrix_w)
        
        #retina, pre_address, post_address  = ret.map_retina_to_mn256r1_macro_pixels(syntype='learning')
        #on off retina nsetup.mapper._program_detail_mapping(2**6) on -> 7 
        is_configured = True      
        
    if measure_scope:
Example #18
def test_view_line():
    r = Retina()
    s = r.view_line("x = [2, 8, 7, 9, -5, 0, 2]", 2)
    assert_equal(s, "- (#, 8, 7, 9. -#")
Example #19
class RetinaNet:
    name = NAME

    @classmethod
    def mk_hash(cls, path):
        '''
        Create an MD5 hash from a model's weight file.
        Arguments:
            path : str - path to RetinaNet checkpoint
        '''
        dirs = path.split('/')
        if 'retina_net' in dirs:
            dirs = dirs[dirs.index('retina_net'):]
            path = '/'.join(dirs)
        else:
            path = os.path.join('retina_net', path)

        md5 = hashlib.md5()
        md5.update(path.encode('utf-8'))
        return md5.hexdigest()

    @classmethod
    def zip_weights(cls, path, base_dir='./'):
        if os.path.splitext(path)[1] != '.pth':
            raise ValueError('Invalid checkpoint')

        dirs = path.split('/')

        res = {
            'name' : 'RetinaNet',
            'instance' : '_'.join(dirs[-2:]),
            'id' : cls.mk_hash(path)
        }

        zipfile = os.path.join(base_dir, res['id'] + '.zip')

        if os.path.exists(zipfile):
            os.remove(zipfile)

        weight_dir = os.path.dirname(path)

        with ZipFile(zipfile, 'w') as z:
            z.write(path, os.path.join(res['id'], os.path.basename(path)))

        return zipfile

    def __init__(self, weights, classes=['building'], cuda=True):
        chkpnt = torch.load(weights)
        self.config = chkpnt['args']
        self.net = Retina(self.config).eval()
        self.net.load_state_dict(chkpnt['state_dict'])
        self.transform = transforms.Compose([
            transforms.Resize((self.config.model_input_size, self.config.model_input_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.net = self.net.cuda()
        self.net.anchors.anchors = self.net.anchors.anchors.cuda()
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        self.cuda = cuda

    def predict_image(self, image, eval_mode = False):
        """
        Infer buildings for a single image.
        Inputs:
            image :: n x m x 3 ndarray - Should be in RGB format
        """

        t0 = time.time()
        img = self.transform(image)
        if self.cuda:
            img = img.cuda()

        out = self.net(Variable(img.unsqueeze(0), requires_grad=False)).squeeze().data.cpu().numpy()
        total_time = time.time() - t0
        
        out = out[1] # ignore background class

        out[:, (1, 3)] = np.clip(out[:, (1, 3)] * image.width, a_min=0, a_max=image.width)
        out[:, (2, 4)] = np.clip(out[:, (2, 4)] * image.height, a_min=0, a_max=image.height)

        out = out[out[:, 0] > 0]

        return pandas.DataFrame(out, columns=['score', 'x1', 'y1', 'x2', 'y2'])

    def predict_all(self, test_boxes_file, batch_size=8, data_dir = None):
        if data_dir is None:
            data_dir = os.path.join(os.path.dirname(test_boxes_file))
        
        annos = json.load(open(test_boxes_file))

        total_time = 0.0

        for batch in range(0, len(annos), batch_size):
            images,  sizes = [], []
            for i in range(min(batch_size, len(annos) - batch)):
                img = Image.open(os.path.join(data_dir, annos[batch + i]['image_path']))
                images.append(self.transform(img))
                sizes.append(torch.Tensor([img.width, img.height]))

            images = torch.stack(images)
            sizes = torch.stack(sizes)

            if self.cuda:
                images = images.cuda()
                sizes = sizes.cuda()

            out = self.net(Variable(images, requires_grad=False)).data

            hws = torch.cat([sizes, sizes], dim=1).view(-1, 1, 1, 4).expand(-1, out.shape[1], out.shape[2], -1)

            out[:, :, :, 1:] *= hws
            out = out[:, 1, :, :].cpu().numpy()

            for i, detections in enumerate(out):
                anno = annos[batch + i]
                pred = cv2.imread('../data/' + anno['image_path'])

                detections = detections[detections[:, 0] > 0]
                df = pandas.DataFrame(detections, columns=['score', 'x1', 'y1', 'x2', 'y2'])
                df['image_id'] = anno['image_path']

                truth = pred.copy()

                for box in df[['x1', 'y1', 'x2', 'y2']].values.round().astype(int):
                    cv2.rectangle(pred, tuple(box[:2]), tuple(box[2:4]), (0,0,255))

                for r in anno['rects']:
                    box = list(map(lambda x: int(r[x]), ['x1', 'y1', 'x2', 'y2']))
                    cv2.rectangle(truth, tuple(box[:2]), tuple(box[2:]), (0, 0, 255))

                data = np.concatenate([pred, truth], axis=1)
                cv2.imwrite('samples/image_%d.jpg' % (batch + i), data)

                yield df
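
mk_hash derives a stable identifier from the checkpoint path alone: the path is normalized so that it is rooted at 'retina_net', then MD5-hashed. A self-contained sketch of the same idea (the checkpoint path is hypothetical):

import hashlib
import os

path = 'runs/building_v1/model_best.pth'       # hypothetical checkpoint path
normalized = os.path.join('retina_net', path)  # the normalization mk_hash applies when 'retina_net' is absent
digest = hashlib.md5(normalized.encode('utf-8')).hexdigest()
print(digest)  # zip_weights uses this id to name the archive: <digest>.zip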
Example #21
class SSD:
    name = NAME

    @classmethod
    def mk_hash(cls, path):
        '''
        Create an MD5 hash from a model's weight file.
        Arguments:
            path : str - path to TensorBox checkpoint
        '''
        dirs = path.split('/')
        if 'ssd.pytorch' in dirs:
            dirs = dirs[dirs.index('ssd.pytorch'):]
            path = '/'.join(dirs)
        else:
            path = os.path.join('ssd.pytorch', path)

        md5 = hashlib.md5()
        md5.update(path.encode('utf-8'))
        return md5.hexdigest()

    @classmethod
    def zip_weights(cls, path, base_dir='./'):
        if os.path.splitext(path)[1] != '.pth':
            raise ValueError('Invalid checkpoint')

        dirs = path.split('/')

        res = {
            'name': 'TensorBox',
            'instance': '_'.join(dirs[-2:]),
            'id': cls.mk_hash(path)
        }

        zipfile = os.path.join(base_dir, res['id'] + '.zip')

        if os.path.exists(zipfile):
            os.remove(zipfile)

        weight_dir = os.path.dirname(path)

        with ZipFile(zipfile, 'w') as z:
            z.write(path, os.path.join(res['id'], os.path.basename(path)))

        return zipfile

    def __init__(self, weights, classes=['building']):
        self.net = Retina(classes).eval().cuda()
        chkpnt = torch.load(weights)
        self.net.load_state_dict(chkpnt['state_dict'])
        self.transform = transforms.Compose([
            transforms.Resize((300, 300)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def predict_image(self, image, threshold, eval_mode=False):
        """
        Infer buildings for a single image.
        Inputs:
            image :: n x m x 3 ndarray - Should be in RGB format
        """

        t0 = time.time()
        img = self.transform(image)
        out = self.net(Variable(img.unsqueeze(0).cuda(),
                                volatile=True)).squeeze().data.cpu()
        total_time = time.time() - t0

        # out: class X top K X (score, minx, miny, maxx, maxy)
        scores = out[:, :, 0]

        max_scores, inds = scores.max(dim=0)

        linear = torch.arange(0, out.shape[1]).long()
        boxes = out[inds, linear].numpy()
        boxes[:, (1, 3)] = np.clip(boxes[:, (1, 3)] * image.width,
                                   a_min=0,
                                   a_max=image.width)
        boxes[:, (2, 4)] = np.clip(boxes[:, (2, 4)] * image.height,
                                   a_min=0,
                                   a_max=image.height)

        df = pandas.DataFrame(boxes, columns=['score', 'x1', 'y1', 'x2', 'y2'])

        if eval_mode:
            return df[df['score'] > threshold], df, total_time
        else:
            return df[df['score'] > threshold]

        pdb.set_trace()

    def predict_all(self, test_boxes_file, threshold, data_dir=None):
        test_boxes = json.load(open(test_boxes_file))
        true_annolist = al.parse(test_boxes_file)
        if data_dir is None:
            data_dir = os.path.join(os.path.dirname(test_boxes_file))

        total_time = 0.0

        for i in range(len(true_annolist)):
            true_anno = true_annolist[i]

            orig_img = imread('%s/%s' %
                              (data_dir, true_anno.imageName))[:, :, :3]

            pred, all_rects, time = self.predict_image(orig_img,
                                                       threshold,
                                                       eval_mode=True)

            pred['image_id'] = i
            all_rects['image_id'] = i

            yield pred, all_rects, test_boxes[i]
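
Both detector wrappers rescale the network's normalized box coordinates back to pixel space and clip them to the image bounds before building the result DataFrame. A self-contained numpy sketch of that step, with made-up detections and a hypothetical 640x480 image:

import numpy as np

# columns: score, x1, y1, x2, y2 (coordinates normalized to [0, 1]); values are illustrative
boxes = np.array([[0.9, 0.10, 0.20, 0.50, 0.70],
                  [0.4, -0.05, 0.30, 1.10, 0.80]])
width, height = 640, 480
boxes[:, (1, 3)] = np.clip(boxes[:, (1, 3)] * width, a_min=0, a_max=width)
boxes[:, (2, 4)] = np.clip(boxes[:, (2, 4)] * height, a_min=0, a_max=height)
print(boxes)  # pixel-space coordinates, ready for the DataFrame / score-threshold step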
Example #22
        color = label_color(labels[i])
        draw_box(img, box, color=color)

        caption = "{} {:.3f}".format(labels_to_names[labels[i]], scores[i])
        draw_caption(img, box, caption)

    return img


if __name__ == '__main__':
    labels_path = './models/coco.names'
    labels_to_names = open(labels_path).read().strip().split("\n")

    robot = Robot()
    yolo_detector = Yolo()
    retina_detector = Retina()

    robot.start()
    while True:
        frame = robot.getFrame()

        frame = draw(yolo_detector, retina_detector, frame)
        cv2.imshow('ai2thor', frame)

        key = chr(cv2.waitKey(0))
        if key == 'q':
            break
        robot.apply(key)

    robot.stop()
    cv2.destroyAllWindows()
Example #23
        net = Perceptrons(perceptron_pop, feature_pop)
        net.matrix_learning_pot[:] = 0
        net.upload_config()
        #test
        #syn = feature_pop.synapses['programmable'][::16]
        #stim = syn.spiketrains_poisson(10)
        #nsetup.stimulate(stim,send_reset_event=False)

        #set up filters and connect retina
        inputpop = pyNCS.Population('', '')
        inputpop.populate_by_id(nsetup, 'mn256r1', 'excitatory',
                                np.linspace(0, 255, 256))
        #reset multiplexer
        chip.configurator._set_multiplexer(0)
        ret = Retina(inputpop)
        ret._init_fpga_mapper()
        pre_teach, post_teach, pre_address, post_address = ret.map_retina_to_mn256r1_randomproj(
        )
        nsetup.chips['mn256r1'].load_parameters(
            'biases/biases_wijlearning_ret_perceptrons_1.biases')

        #two different biases for teacher and inputs
        #matrix_w = np.zeros([256,256])
        #matrix_w[:,0:128]  = 2
        #matrix_w[:,128:256]  = 1
        #nsetup.mapper._program_onchip_programmable_connections(matrix_w)

        #retina, pre_address, post_address  = ret.map_retina_to_mn256r1_macro_pixels(syntype='learning')
        #on off retina nsetup.mapper._program_detail_mapping(2**6) on -> 7
        is_configured = True
Example #24
                        help='Path to training data')
    parser.add_argument('--data_dir',
                        default=None,
                        help='Directory of training data')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    args.cuda = args.gpu is not None

    default_type = 'torch.cuda.FloatTensor' if args.cuda else 'torch.FloatTensor'
    torch.set_default_tensor_type(default_type)

    DS_Class = VOC if 'VOC' in args.train_data else SpaceNet

    net = Retina(args)
    dataset = DS_Class(args.train_data,
                       Transform(args, net.anchors),
                       args,
                       root_dir=args.data_dir)

    args.checkpoint_dir = os.path.join(args.save_folder,
                                       'ssd_%s' % datetime.now().isoformat())
    args.start_iter = 0

    if args.resume:
        args.checkpoint_dir = os.path.dirname(args.resume)

    os.makedirs(args.save_folder, exist_ok=True)

    if args.cuda: