Example #1
    def test_times_two_function(self):
        # Hyperparameters
        num_levels = 2
        num_nodes_at_level = {0: 2, 1: 2}
        num_ops_at_level = {0: LEN_SIMPLE_OPS, 1: 1}
        num_epochs = 100

        # Initialize tensorboard writer
        dt_string = datetime.now().strftime("%d-%m-%Y--%H-%M-%S")
        writer = SummaryWriter('test/test_double_func/' + str(dt_string) + "/")

        # Define model
        model = ModelController(num_levels=num_levels,
                                num_nodes_at_level=num_nodes_at_level,
                                num_ops_at_level=num_ops_at_level,
                                primitives=SIMPLE_OPS,
                                channels_in=1,
                                channels_start=1,
                                stem_multiplier=1,
                                num_classes=1,
                                loss_criterion=nn.L1Loss(),
                                writer=writer,
                                test_mode=True)

        # Input
        x = tensor([[
            # feature 1
            [[1.]]
        ]])

        # Expected output
        y = tensor([[
            # feature 1
            [[2.]]
        ]])

        # Alpha Optimizer - one for each level
        alpha_optim = []
        for level in range(0, num_levels):
            alpha_optim.append(
                torch.optim.Adam(params=model.get_alpha_level(level), lr=0.1))

        for _ in range(0, num_epochs):
            # Alpha Gradient Steps for each level
            for level in range(0, num_levels):
                alpha_optim[level].zero_grad()
                loss = model.loss_criterion(model(x), y)
                print(loss)
                loss.backward()
                alpha_optim[level].step()
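
A convergence assertion after the loop would make the test self-checking. A minimal sketch (the 0.1 tolerance is an arbitrary choice, not part of the original test):

        # Hypothetical final check: the trained controller should map 1.0 to ~2.0
        with torch.no_grad():
            prediction = model(x)
        self.assertAlmostEqual(prediction.item(), 2.0, delta=0.1)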
Example #2
    def update(cls, page, task_id):
        if not page.user:
            page.redirect("/")
            return  # redirect alone does not end the handler

        name = page.request.get('name')
        start_date_raw = page.request.get('start_date')
        end_date_raw = page.request.get('end_date')

        freq_unit = page.request.get('freq_unit')
        freq_value = int(page.request.get('freq_value'))
        weekly_list = None
        note = page.request.get('note')

        if freq_unit == task.FreqUnit.Weekly:
            weekly_list = task.WeeklyList.cvtListToStr(page.request.get('weekly_list'))

        if name and start_date_raw and end_date_raw:
            start_date = datetime.strptime(start_date_raw, "%m/%d/%Y").date()
            end_date = datetime.strptime(end_date_raw, "%m/%d/%Y").date()

            t = ModelController.object_by_id(task.Task, task_id)
            t.name = name
            t.start_date = start_date
            t.end_date = end_date
            t.freq_unit = freq_unit
            t.freq_value = freq_value
            t.weekly_list = weekly_list
            t.note = note
            cls.add_object(t)
            page.redirect("/task/%s" % task_id)
        else:
            error = "Please specify the task name, start date, and end date!"
            page.render("task_create3.html", name=name, note=note, error_task=error)
Example #3
    def create(cls, page):
        if not page.user:
            page.redirect('/')
            return  # redirect alone does not end the handler

        uid = page.read_secure_cookie('user_id')
        exist_user = cls.object_by_id(models.User, uid)
        subject = page.request.get('subject')
        start_date_raw = page.request.get('start_date')
        end_date_raw = page.request.get('end_date')
        content = page.request.get('content')
        flag = False

        if subject and start_date_raw and end_date_raw and content:
            start_date = datetime.strptime(start_date_raw, "%m/%d/%Y").date()
            end_date = datetime.strptime(end_date_raw, "%m/%d/%Y").date()
            g = goal.Goal(parent=ModelController.object_parent(goal.Goal),
                          subject=subject,
                          start_date=start_date,
                          end_date=end_date,
                          content=content,
                          author=exist_user,
                          flag=flag)
            gid = cls.add_object(g)
            cls.update_mc_list(page.user, g, '-create_date')  # refresh the memcache list of posts
            page.redirect('/goal/%s' % gid)
        else:
            error = "Please specify the subject, start date, end date, and description!"
            page.render("goal_create.html", subject=subject, content=content, error=error)
Example #4
    def create(cls, page):
        if not page.user:
            page.redirect("/")
            return  # redirect alone does not end the handler

        gid = page.request.get('gid')
        page.goal = ModelController.object_by_id(goal.Goal, gid)
        name = page.request.get('name')
        start_date_raw = page.request.get('start_date')
        end_date_raw = page.request.get('end_date')

        freq_unit = page.request.get('freq_unit')
        freq_value_raw = page.request.get('freq_value')
        weekly_list = None
        note = page.request.get('note')

        if freq_unit == task.FreqUnit.Weekly:
            weekly_list = task.WeeklyList.cvtListToStr(page.request.get('weekly_list'))

        if name and start_date_raw and end_date_raw and freq_value_raw:
            start_date = datetime.strptime(start_date_raw, "%m/%d/%Y").date()
            end_date = datetime.strptime(end_date_raw, "%m/%d/%Y").date()
            freq_value = int(freq_value_raw)
            if start_date >= page.goal.start_date and end_date <= page.goal.end_date:
                new_task = task.Task(parent=ModelController.object_parent(task.Task),
                                     name=name,
                                     goal=page.goal,  # assumes the current goal has been saved
                                     start_date=start_date,
                                     end_date=end_date,
                                     freq_unit=freq_unit,
                                     freq_value=freq_value,
                                     weekly_list=weekly_list,
                                     note=note)
                tid = ModelController.add_object(new_task)
                ModelController.update_mc_list(page.goal, new_task, 'create_date')
                page.redirect("/goal/%s" % gid)
            else:
                error = "Task dates cannot exceed the span of this goal!"
                page.render("task_create3.html",
                            name=name,
                            note=note,
                            error_task=error)
        else:
            error = "Please specify the task name, start date, end date, and frequency!"
            page.render("task_create3.html",
                        name=name,
                        note=note,
                        error_task=error)
Example #5
def start_training(trace_history, room_dst, person_loc, approach_person,
                   sound_wave_model):
    rc = RobotController()
    mc = ModelController()

    for i, trace in enumerate(trace_history):
        print("-" * 60)
        print("Current Time: {}, Sample: {}".format(rospy.get_rostime().secs,
                                                    i))

        target_room = trace["target"]
        # return the robot to the origin room
        explore_room(rc, room_dst, trace["origin"])
        # move a person to the target room and emit a sound cue
        px, py, pidx = set_person_pose(mc, person_loc, target_room)
        mc.spawn_model("sound_wave", sound_wave_model, px, py, 2)
        rospy.sleep(3)
        mc.delete_model("sound_wave")

        # the robot actively explores rooms according to the ranking result
        for next_room in trace["trace"]:
            if next_room == target_room:
                print("Sample {} found target room: {}".format(i, next_room))
                app_pos = approach_person[str(target_room)][pidx]
                rc.goto(app_pos["x"], app_pos["y"], app_pos["yaw"])
                rospy.sleep(1)
            else:
                print("Sample {} explore room: {}".format(i, next_room))
                explore_room(rc, room_dst, next_room)
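
Each element of `trace_history` is read as a dict with an origin room, a target room, and an ordered list of rooms to explore. A single hypothetical sample might look like this (the room identifiers are placeholders; their real format comes from the project's map configuration):

sample_trace = {
    "origin": "room_0",             # where the robot starts
    "target": "room_2",             # room containing the person
    "trace": ["room_1", "room_2"],  # ranked rooms, explored in order
}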
Example #6
def start_game():
    # Instantiate the ModelController
    model_controller = ModelController(GAME_FILE_DIRECTORY)

    # Instantiate the NavigationController
    navigation_controller = NavigationController()

    # Instantiate the root SceneController - This is the
    # first scene in the hierarchy.
    root_scene_controller = SceneController()
    root_scene_controller.delegate = navigation_controller
    root_scene_controller.model_controller = model_controller
    root_scene_controller.scene = \
        model_controller.game_data[ROOT_SCENE_CONTROLLER_ID]

    # Present the first scene
    root_scene_controller.delegate.push_scene_controller(root_scene_controller)
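
`push_scene_controller` is the only delegate method exercised here. A minimal sketch of a stack-based `NavigationController` consistent with this usage (an illustration under that assumption, not the project's actual implementation):

class NavigationController:
    """Hypothetical stack-based scene navigator."""

    def __init__(self):
        self._stack = []

    def push_scene_controller(self, scene_controller):
        # Present the new scene on top of the current one
        self._stack.append(scene_controller)
        scene_controller.present()  # assumed presentation hook

    def pop_scene_controller(self):
        # Return to the previous scene, if any
        if len(self._stack) > 1:
            self._stack.pop()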
Example #7
    def update(cls, page, goal_id):
        if not page.user:
            page.redirect('/')
            return  # redirect alone does not end the handler

        subject = page.request.get('subject')
        start_date_raw = page.request.get('start_date')
        end_date_raw = page.request.get('end_date')
        content = page.request.get('content')

        if subject and start_date_raw and end_date_raw and content:
            start_date = datetime.strptime(start_date_raw, "%m/%d/%Y").date()
            end_date = datetime.strptime(end_date_raw, "%m/%d/%Y").date()
            g = ModelController.object_by_id(goal.Goal, goal_id)
            g.subject = subject
            g.start_date = start_date
            g.end_date = end_date
            gid = cls.add_object(g)
            cls.update_mc_list(page.user, g, '-create_date')  # refresh the memcache list of posts
            page.redirect('/goal/%s' % gid)
        else:
            error = "Please specify the subject, start date, end date, and description!"
            page.render("goal_create.html", subject=subject, content=content, error=error)
Example #8
class MNAS:
    def __init__(self):
        if config.LOAD_FROM_CHECKPOINT is None:
            self.dt_string = datetime.now().strftime("%d-%m-%Y--%H-%M-%S")
        else:
            self.dt_string = config.LOAD_FROM_CHECKPOINT + "-" + datetime.now(
            ).strftime("%d-%m-%Y--%H-%M-%S")
        self.writer = SummaryWriter(config.LOGDIR + "/" + config.DATASET +
                                    "/" + str(self.dt_string) + "/")
        self.num_levels = config.NUM_LEVELS

        # Set gpu device if cuda is available
        if torch.cuda.is_available():
            torch.cuda.set_device(config.gpus[0])

        # Collect the config as hparams (stringify dicts/lists for logging)
        hparams = {}
        for key in config.__dict__:
            value = config.__dict__[key]
            if isinstance(value, (dict, list)):
                hparams[key] = str(value)
            else:
                hparams[key] = value

        # Print config to logs
        pprint.pprint(hparams)

        # Seed for reproducibility
        torch.manual_seed(config.SEED)
        random.seed(config.SEED)

    def run(self):
        # Get Data & MetaData
        input_size, input_channels, num_classes, train_data = get_data(
            dataset_name=config.DATASET,
            data_path=config.DATAPATH,
            cutout_length=16,
            test=False)

        # Set Loss Criterion
        loss_criterion = nn.CrossEntropyLoss()
        if torch.cuda.is_available():
            loss_criterion = loss_criterion.cuda()

        # Ensure num of ops at level 0 = num primitives
        config.NUM_OPS_AT_LEVEL[0] = LEN_OPS

        # Train / Validation Split
        n_train = (len(train_data) // 100) * config.PERCENTAGE_OF_DATA
        split = n_train // 2
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            indices[:split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:])
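        # e.g., with CIFAR-10's 50,000 training images and
        # PERCENTAGE_OF_DATA = 100: n_train = 50000 and split = 25000, so
        # the first half trains the weights and the second half trains alpha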
        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=config.BATCH_SIZE,
            sampler=train_sampler,
            num_workers=config.NUM_DOWNLOAD_WORKERS,
            pin_memory=True)
        valid_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=config.BATCH_SIZE,
            sampler=valid_sampler,
            num_workers=config.NUM_DOWNLOAD_WORKERS,
            pin_memory=True)

        # Register Signal Handler for interrupts & kills
        signal.signal(signal.SIGINT, self.terminate)
        # Search - Weight Training and Alpha Training

        # Initialize weight training model i.e. only 1 op at 2nd highest level
        self.model = ModelController(
            num_levels=config.NUM_LEVELS,
            num_nodes_at_level=config.NUM_NODES_AT_LEVEL,
            num_ops_at_level=config.NUM_OPS_AT_LEVEL,
            primitives=OPS,
            channels_in=input_channels,
            channels_start=config.CHANNELS_START,
            stem_multiplier=config.STEM_MULTIPLIER,
            num_classes=num_classes,
            num_cells=config.NUM_CELLS,
            loss_criterion=loss_criterion,
            writer=self.writer)

        # Transfer model to GPU
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            # Optimize if possible
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.enabled = True

        # Weights Optimizer
        w_optim = torch.optim.SGD(params=self.model.get_weights(),
                                  lr=config.WEIGHTS_LR,
                                  momentum=config.WEIGHTS_MOMENTUM,
                                  weight_decay=config.WEIGHTS_WEIGHT_DECAY)
        w_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            w_optim, config.EPOCHS, eta_min=config.WEIGHTS_LR_MIN)

        # gumbel softmax temperature scheduler
        temp_scheduler = None
        if config.USE_GUMBEL_SOFTMAX:
            temp_scheduler = LinearTempScheduler(
                n_epochs=config.EPOCHS,
                starting_temp=config.ALPHA_STARTING_TEMP,
                final_temp=config.ALPHA_MIN_TEMP)

        # Alpha Optimizer - one for each level
        alpha_optim = []
        # If trying to simulate DARTS don't bother with alpha optim for higher level
        if config.NUM_NODES_AT_LEVEL[0] == 2:
            num_levels = 1
        else:
            num_levels = config.NUM_LEVELS
        for level in range(0, num_levels):
            alpha_optim.append(
                torch.optim.Adam(params=self.model.get_alpha_level(level),
                                 lr=config.ALPHA_LR[level],
                                 weight_decay=config.ALPHA_WEIGHT_DECAY,
                                 betas=config.ALPHA_MOMENTUM))

        # If loading from checkpoint, replace modelController's model with checkpoint model
        start_epoch = 0

        if config.LOAD_FROM_CHECKPOINT is not None:
            self.model, w_optim, w_lr_scheduler, alpha_optim, start_epoch = load_checkpoint(
                self.model, w_optim, w_lr_scheduler, alpha_optim,
                os.path.join(config.CHECKPOINT_PATH,
                             config.LOAD_FROM_CHECKPOINT))
            print("Loaded Checkpoint:", config.LOAD_FROM_CHECKPOINT)

        # Training Loop
        best_top1 = 0.
        for epoch in range(start_epoch, config.EPOCHS):
            lr = w_lr_scheduler.get_lr()[0]
            print("W Learning Rate:", lr)

            # Attempt to get the temperature and step the temp scheduler
            temp = None
            if temp_scheduler is not None:
                temp = temp_scheduler.step()

            # Put into weight training mode - turn off gradient for alpha
            self.model.weight_training_mode()

            # Weight Training
            self.train_weights(train_loader=train_loader,
                               model=self.model,
                               w_optim=w_optim,
                               epoch=epoch,
                               lr=lr,
                               temp=temp)

            # GPU Memory Allocated for Model in Weight Sharing Phase
            if epoch == 0:
                try:
                    print(
                        "Weight Training Phase: Max GPU Memory Used",
                        torch.cuda.max_memory_allocated() /
                        (1024 * 1024 * 1024), "GB")
                except Exception:
                    print("Unable to retrieve memory data")

            # Turn off gradient for weight params
            self.model.alpha_training_mode()

            # Alpha Training / Validation
            top1 = self.train_alpha(valid_loader=valid_loader,
                                    model=self.model,
                                    alpha_optim=alpha_optim,
                                    epoch=epoch,
                                    lr=lr,
                                    temp=temp)

            # Save Checkpoint
            if best_top1 < top1:
                best_top1 = top1
                is_best = True
            else:
                is_best = False
            print("Saving checkpoint")
            save_checkpoint(
                self.model, epoch, w_optim, w_lr_scheduler, alpha_optim,
                os.path.join(config.CHECKPOINT_PATH, self.dt_string), is_best)

            # Weight Learning Rate Step
            w_lr_scheduler.step()

            # GPU Memory Allocated for Model
            if epoch == 0:
                try:
                    print(
                        "Alpha Training: Max GPU Memory Used",
                        torch.cuda.max_memory_allocated() /
                        (1024 * 1024 * 1024), "GB")
                except Exception:
                    print("Unable to print memory data")

        # Log Best Accuracy so far
        print("Final best Prec@1 = {:.4%}".format(best_top1))

        # Terminate
        self.terminate()

    def train_weights(self,
                      train_loader,
                      model: ModelController,
                      w_optim,
                      epoch,
                      lr,
                      temp=None):
        top1 = AverageMeter()
        top5 = AverageMeter()
        losses = AverageMeter()

        cur_step = epoch * len(train_loader)

        # Log LR
        self.writer.add_scalar('train/lr', lr, epoch)

        # Prepares the model for training - 'training mode'
        model.train()

        for step, (trn_X, trn_y) in enumerate(train_loader):
            N = trn_X.size(0)
            if torch.cuda.is_available():
                trn_X = trn_X.cuda(non_blocking=True)
                trn_y = trn_y.cuda(non_blocking=True)

            # Weights Step
            w_optim.zero_grad()
            logits = model(trn_X, temp=temp)
            loss = model.loss_criterion(logits, trn_y)
            loss.backward()

            # gradient clipping
            nn.utils.clip_grad_norm_(model.get_weights(),
                                     config.WEIGHTS_GRADIENT_CLIP)
            w_optim.step()

            prec1, prec5 = accuracy(logits, trn_y, topk=(1, 5))
            losses.update(loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)

            if step % config.PRINT_STEP_FREQUENCY == 0 or step == len(
                    train_loader) - 1:
                print(
                    datetime.now(),
                    "Weight Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1,
                        config.EPOCHS,
                        step,
                        len(train_loader) - 1,
                        losses=losses,
                        top1=top1,
                        top5=top5))

            self.writer.add_scalar('train/loss', loss.item(), cur_step)
            self.writer.add_scalar('train/top1', prec1.item(), cur_step)
            self.writer.add_scalar('train/top5', prec5.item(), cur_step)
            cur_step += 1

        print("Weight Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
            epoch + 1, config.EPOCHS, top1.avg))

    def train_alpha(self,
                    valid_loader,
                    model: ModelController,
                    alpha_optim,
                    epoch,
                    lr,
                    temp=None):
        top1 = AverageMeter()
        top5 = AverageMeter()
        losses = AverageMeter()

        cur_step = epoch * len(valid_loader)

        # Log LR
        self.writer.add_scalar('train/lr', lr, epoch)

        # Prepares the model for training - 'training mode'
        model.train()

        for step, (val_X, val_y) in enumerate(valid_loader):
            N = val_X.size(0)
            if torch.cuda.is_available():
                val_X = val_X.cuda(non_blocking=True)
                val_y = val_y.cuda(non_blocking=True)

            # Alpha Gradient Steps for each level
            for level in range(len(alpha_optim)):
                alpha_optim[level].zero_grad()
            logits = model(val_X, temp=temp)
            loss = model.loss_criterion(logits, val_y)
            loss.backward()
            for level in range(len(alpha_optim)):
                alpha_optim[level].step()

            prec1, prec5 = accuracy(logits, val_y, topk=(1, 5))
            losses.update(loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)

            if step % config.PRINT_STEP_FREQUENCY == 0 or step == len(
                    valid_loader) - 1:
                print(
                    datetime.now(),
                    "Alpha Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1,
                        config.EPOCHS,
                        step,
                        len(valid_loader) - 1,
                        losses=losses,
                        top1=top1,
                        top5=top5))

            self.writer.add_scalar('val/loss', losses.avg, cur_step)
            self.writer.add_scalar('val/top1', top1.avg, cur_step)
            self.writer.add_scalar('val/top5', top5.avg, cur_step)
            cur_step += 1

        print(
            "Alpha Train (Uses Validation Loss): [{:2d}/{}] Final Prec@1 {:.4%}"
            .format(epoch + 1, config.EPOCHS, top1.avg))
        return top1.avg

    def validate(self, valid_loader, model, epoch, cur_step):
        top1 = AverageMeter()
        top5 = AverageMeter()
        losses = AverageMeter()

        model.eval()

        with torch.no_grad():
            for step, (X, y) in enumerate(valid_loader):
                N = X.size(0)

                # Move the batch to the GPU before the forward pass
                if torch.cuda.is_available():
                    X = X.cuda(non_blocking=True)
                    y = y.cuda(non_blocking=True)

                logits = model(X)
                loss = model.loss_criterion(logits, y)

                prec1, prec5 = accuracy(logits, y, topk=(1, 5))
                losses.update(loss.item(), N)
                top1.update(prec1.item(), N)
                top5.update(prec5.item(), N)

                if step % config.PRINT_STEP_FREQUENCY == 0 or step == len(
                        valid_loader) - 1:
                    print(
                        datetime.now(),
                        "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1,
                            config.EPOCHS,
                            step,
                            len(valid_loader) - 1,
                            losses=losses,
                            top1=top1,
                            top5=top5))

        self.writer.add_scalar('val/loss', losses.avg, cur_step)
        self.writer.add_scalar('val/top1', top1.avg, cur_step)
        self.writer.add_scalar('val/top5', top5.avg, cur_step)

        print("Valid: [{:2d}/{}] Final Prec@1 {:.4%}".format(
            epoch + 1, config.EPOCHS, top1.avg))

        return top1.avg

    def terminate(self, sig=None, frame=None):
        # Print alpha
        print("Alpha Normal")
        print_alpha(self.model.alpha_normal)
        print("Alpha Reduce")
        print_alpha(self.model.alpha_reduce)

        # Pass exit signal on
        sys.exit(0)
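
A typical entry point for this module might look like the following sketch (the original file may wire this up differently):

if __name__ == "__main__":
    nas = MNAS()
    nas.run()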