Example #1
def extract_features(model, args):
    data_loader = load_data(args.datadir, args.crop_size, args.batch_size,
                            args.num_workers)
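    # Dropping the last top-level child is assumed to strip the classifier head,
    # so new_model outputs penultimate feature maps (1024 channels, matching X below).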
    new_model = nn.Sequential(*list(model.children())[:-1])

    dashboard = Dashboard()

    Y = np.empty((0))
    X = np.empty((0, 1024))
    #
    for i, (images, targets) in enumerate(data_loader['train']):
        images = Variable(images.float().cuda(), volatile=True)
        targets = Variable(targets.long().cuda(), volatile=True)
        outputs = new_model(images)
        outputs = F.avg_pool2d(outputs, kernel_size=7)
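        # Global average pooling over the assumed 7x7 feature map; after squeezing,
        # each image contributes one 1024-d feature vector to X, with its label in Y.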
        X = np.vstack((X, outputs.data.cpu().numpy().squeeze()))
        Y = np.hstack((Y, targets.data.cpu().numpy()))

    # X = np.load('train.npy')
    # Y = X[:,0]
    # X = X[:, 1:]
    # X = np.nan_to_num(X)

    #---compute tsne rep. and plot them---
    X, Y = tsne(X, Y)
    dashboard.plot_tsne(X, Y)
Example #2
def get_last_eval_vals(opt, vis_opt, eval_key, algo, env_name, num_trials,
                       trial_offset, bin_size):
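    # Returns one value per trial: the mean of the last `bin_size` entries of
    # `eval_key` from that trial's episode monitor, or None if any trial's logs
    # are missing or too short.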
    # For each trial
    eval_vals = []
    for trial in range(trial_offset, trial_offset + num_trials):
        # Get the logpath
        logpath = os.path.join(opt['logs']['log_base'], opt['model']['mode'],
                               opt['logs']['exp_name'], algo, env_name,
                               'trial%d' % trial)
        if not os.path.isdir(logpath):
            return None

        # Create the dashboard object
        opt['env']['env-name'] = env_name
        opt['alg'] = opt['alg_%s' % algo]
        opt['optim'] = opt['optim_%s' % algo]
        opt['alg']['algo'] = algo
        opt['trial'] = trial
        dash = Dashboard(opt, vis_opt, logpath, vis=False)

        # Get data
        try:
            dash.preload_data()
            raw_x, raw_y = dash.load_data('episode_monitor', 'scalar',
                                          eval_key)
        except Exception:
            return None

        # Get data from last bin
        if not (len(raw_y) > bin_size):
            return None
        raw_vals = raw_y[-bin_size:]
        assert (len(raw_vals) == bin_size)
        raw_vals = [float(v) for v in raw_vals]
        raw_val = np.mean(raw_vals)
        eval_vals.append(raw_val)

    # Return
    return eval_vals
Example #3
def compute_confusion_matrix(model, args):
    data_loader = load_data(args.datadir, args.crop_size, args.batch_size,
                            args.num_workers)
    new_model = nn.Sequential(*list(model.children())[:-1])

    dashboard = Dashboard()

    T = np.empty((0))
    Y = np.empty((0))
    #
    for i, (images, targets) in enumerate(data_loader['train']):
        images = Variable(images.float().cuda(), volatile=True)
        targets = Variable(targets.long().cuda(), volatile=True)
        outputs = model(images)
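        # Top-1 prediction per image: topk(1, dim=1) returns the arg-max class indices,
        # transposed below into a flat row for the confusion matrix.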
        _, pred = outputs.topk(1, 1, True, True)
        pred = pred.t()

        T = np.hstack((T, targets.data.cpu().numpy()))
        Y = np.hstack((Y, pred.data.cpu().numpy().squeeze()))

    conf_mat = confusion_matrix(T, Y)
    dashboard.plot_conf_matr(conf_mat)
Example #4
def train(args, model, enc=False):
    best_acc = 0

    #TODO: calculate weights by processing the dataset histogram (now it's being set by hand from the torch values)
    #create a loader to run all images and calculate a histogram of labels, then create the weight array using class balancing
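    # Hedged sketch of the TODO above, assuming labels are already remapped to
    # 0..num_classes-1: build a pixel-label histogram and derive ENet-style
    # class-balancing weights w_c = 1 / ln(c + p_c). Defined for reference only;
    # the hand-tuned values below are what training actually uses.
    def compute_class_weights(label_loader, num_classes, c=1.02):
        hist = torch.zeros(num_classes)
        for _, labels in label_loader:
            hist += torch.bincount(labels.view(-1).long(),
                                   minlength=num_classes).float()
        freq = hist / hist.sum()
        return 1.0 / torch.log(c + freq)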

    weight = torch.ones(NUM_CLASSES)
    if (enc):
        weight[0] = 2.3653597831726	
        weight[1] = 4.4237880706787	
        weight[2] = 2.9691488742828	
        weight[3] = 5.3442072868347	
        weight[4] = 5.2983593940735	
        weight[5] = 5.2275490760803	
        weight[6] = 5.4394111633301	
        weight[7] = 5.3659925460815	
        weight[8] = 3.4170460700989	
        weight[9] = 5.2414722442627	
        weight[10] = 4.7376127243042	
        weight[11] = 5.2286224365234	
        weight[12] = 5.455126285553	
        weight[13] = 4.3019247055054	
        weight[14] = 5.4264230728149	
        weight[15] = 5.4331531524658	
        weight[16] = 5.433765411377	
        weight[17] = 5.4631009101868	
        weight[18] = 5.3947434425354
    else:
        weight[0] = 2.8149201869965	
        weight[1] = 6.9850029945374	
        weight[2] = 3.7890393733978	
        weight[3] = 9.9428062438965	
        weight[4] = 9.7702074050903	
        weight[5] = 9.5110931396484	
        weight[6] = 10.311357498169	
        weight[7] = 10.026463508606	
        weight[8] = 4.6323022842407	
        weight[9] = 9.5608062744141	
        weight[10] = 7.8698215484619	
        weight[11] = 9.5168733596802	
        weight[12] = 10.373730659485	
        weight[13] = 6.6616044044495	
        weight[14] = 10.260489463806	
        weight[15] = 10.287888526917	
        weight[16] = 10.289801597595	
        weight[17] = 10.405355453491	
        weight[18] = 10.138095855713	

    weight[19] = 0

    assert os.path.exists(args.datadir), "Error: datadir (dataset directory) could not be loaded"

    co_transform = MyCoTransform(enc, augment=True, height=args.height)#1024)
    co_transform_val = MyCoTransform(enc, augment=False, height=args.height)#1024)
    dataset_train = cityscapes(args.datadir, co_transform, 'train',50)
    dataset_val = cityscapes(args.datadir, co_transform_val, 'val',100)
    print(len(dataset_train))
    loader = DataLoader(dataset_train, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)
    loader_val = DataLoader(dataset_val, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)
#     print(list(enumerate(loader)))
    if args.cuda:
        weight = weight.cuda()
    criterion = CrossEntropyLoss2d(weight)
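    # CrossEntropyLoss2d is assumed to be the usual per-pixel wrapper (log_softmax over
    # the class dimension followed by a weighted NLL loss); the zero weight on class 19
    # effectively ignores the void class.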

    savedir = f'../save/{args.savedir}'

    if (enc):
        automated_log_path = savedir + "/automated_log_encoder.txt"
        modeltxtpath = savedir + "/model_encoder.txt"
    else:
        automated_log_path = savedir + "/automated_log.txt"
        modeltxtpath = savedir + "/model.txt"    

    if (not os.path.exists(automated_log_path)):    #don't write the header line if the file already exists
        with open(automated_log_path, "a") as myfile:
            myfile.write("Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate")

    with open(modeltxtpath, "w") as myfile:
        myfile.write(str(model))


    #TODO: reduce memory in first gpu: https://discuss.pytorch.org/t/multi-gpu-training-memory-usage-in-balance/4163/4        #https://github.com/pytorch/pytorch/issues/1893

    #optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999),  eps=1e-08, weight_decay=2e-4)     ## scheduler 1
    optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999),  eps=1e-08, weight_decay=1e-4)      ## scheduler 2

    start_epoch = 1
    if args.resume:
        #Must load weights, optimizer, epoch and best value. 
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'

        assert os.path.exists(filenameCheckpoint), "Error: resume option was used but checkpoint was not found in folder"
        checkpoint = torch.load(filenameCheckpoint)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_acc = checkpoint['best_acc']
        print("=> Loaded checkpoint at epoch {})".format(checkpoint['epoch']))

    #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5) # set up scheduler     ## scheduler 1
    lambda1 = lambda epoch: pow((1-((epoch-1)/args.num_epochs)),0.9)  ## scheduler 2
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)                             ## scheduler 2
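    # Scheduler 2 is a "poly" decay: the base LR (5e-4) is multiplied by
    # (1 - (epoch-1)/num_epochs)^0.9, falling smoothly towards 0 at the final epoch.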

    if args.visualize and args.steps_plot > 0:
        board = Dashboard(args.port)

    for epoch in range(start_epoch, args.num_epochs+1):
        print("----- TRAINING - EPOCH", epoch, "-----")

        scheduler.step(epoch)    ## scheduler 2

        epoch_loss = []
        time_train = []
     
        doIouTrain = args.iouTrain   
        doIouVal =  args.iouVal      

        if (doIouTrain):
            iouEvalTrain = iouEval(NUM_CLASSES)

        usedLr = 0
        for param_group in optimizer.param_groups:
            print("LEARNING RATE: ", param_group['lr'])
            usedLr = float(param_group['lr'])

        model.train()
        #print("this is me!!!!!")
        #print(len(loader))
        for step, (images, labels) in enumerate(loader):

            start_time = time.time()
            #print("this is also m")
            #print (labels.size())
            #print (np.unique(labels.numpy()))
            #print("labels: ", np.unique(labels[0].numpy()))
            #labels = torch.ones(4, 1, 512, 1024).long()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            inputs = Variable(images)
            targets = Variable(labels)
            outputs = model(inputs, only_encode=enc)

            #print("targets", np.unique(targets[:, 0].cpu().data.numpy()))
            #print("This is me on traget")
            #print(np.min(targets.cpu().detach().numpy()))
            #print("This is me after target")
            optimizer.zero_grad()
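            # Labels are (N, 1, H, W); targets[:, 0] drops the channel dimension so the
            # 2d cross-entropy sees an (N, H, W) index map.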
            loss = criterion(outputs, targets[:, 0])
            #print("This is me on loss")
            #print(loss)
            #print("This is me after loss")
            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.cpu().detach().numpy().item())
            time_train.append(time.time() - start_time)

            if (doIouTrain):
                #start_time_iou = time.time()
                iouEvalTrain.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data)
                #print ("Time to add confusion matrix: ", time.time() - start_time_iou)      

            #print(outputs.size())
            if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
                start_time_plot = time.time()
                image = inputs[0].cpu().data
                #image[0] = image[0] * .229 + .485
                #image[1] = image[1] * .224 + .456
                #image[2] = image[2] * .225 + .406
                #print("output", np.unique(outputs[0].cpu().max(0)[1].data.numpy()))
                board.image(image, f'input (epoch: {epoch}, step: {step})')
                if isinstance(outputs, list):   #merge gpu tensors
                    board.image(color_transform(outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)),
                    f'output (epoch: {epoch}, step: {step})')
                else:
                    board.image(color_transform(outputs[0].cpu().max(0)[1].data.unsqueeze(0)),
                    f'output (epoch: {epoch}, step: {step})')
                board.image(color_transform(targets[0].cpu().data),
                    f'target (epoch: {epoch}, step: {step})')
                print ("Time to paint images: ", time.time() - start_time_plot)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss) / len(epoch_loss)
                print(f'loss: {average:0.4} (epoch: {epoch}, step: {step})', 
                        "// Avg time/img: %.4f s" % (sum(time_train) / len(time_train) / args.batch_size))

            
        average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss)
        
        iouTrain = 0
        if (doIouTrain):
            iouTrain, iou_classes = iouEvalTrain.getIoU()
            iouStr = getColorEntry(iouTrain)+'{:0.2f}'.format(iouTrain*100) + '\033[0m'
            print ("EPOCH IoU on TRAIN set: ", iouStr, "%")  

        #Validate on 500 val images after each epoch of training
        print("----- VALIDATING - EPOCH", epoch, "-----")
        model.eval()
        epoch_loss_val = []
        time_val = []

        if (doIouVal):
            iouEvalVal = iouEval(NUM_CLASSES)

        for step, (images, labels) in enumerate(loader_val):
            start_time = time.time()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

        inputs = Variable(images, volatile=True)    #volatile flag avoids storing intermediates for backward during eval
            targets = Variable(labels, volatile=True)
            outputs = model(inputs, only_encode=enc) 

            loss = criterion(outputs, targets[:, 0])
            epoch_loss_val.append(loss.cpu().detach().numpy().item())
            time_val.append(time.time() - start_time)


            #Add batch to calculate TP, FP and FN for iou estimation
            if (doIouVal):
                #start_time_iou = time.time()
                iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data)
                #print ("Time to add confusion matrix: ", time.time() - start_time_iou)

            if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
                start_time_plot = time.time()
                image = inputs[0].cpu().data
                board.image(image, f'VAL input (epoch: {epoch}, step: {step})')
                if isinstance(outputs, list):   #merge gpu tensors
                    board.image(color_transform(outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)),
                    f'VAL output (epoch: {epoch}, step: {step})')
                else:
                    board.image(color_transform(outputs[0].cpu().max(0)[1].data.unsqueeze(0)),
                    f'VAL output (epoch: {epoch}, step: {step})')
                board.image(color_transform(targets[0].cpu().data),
                    f'VAL target (epoch: {epoch}, step: {step})')
                print ("Time to paint images: ", time.time() - start_time_plot)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss_val) / len(epoch_loss_val)
                print(f'VAL loss: {average:0.4} (epoch: {epoch}, step: {step})', 
                        "// Avg time/img: %.4f s" % (sum(time_val) / len(time_val) / args.batch_size))
                       

        average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val)
        #scheduler.step(average_epoch_loss_val, epoch)  ## scheduler 1   # update lr if needed

        iouVal = 0
        if (doIouVal):
            iouVal, iou_classes = iouEvalVal.getIoU()
            iouStr = getColorEntry(iouVal)+'{:0.2f}'.format(iouVal*100) + '\033[0m'
            print ("EPOCH IoU on VAL set: ", iouStr, "%") 
           

        # remember best valIoU and save checkpoint
        if iouVal == 0:
            current_acc = -average_epoch_loss_val
        else:
            current_acc = iouVal 
        is_best = current_acc > best_acc
        best_acc = max(current_acc, best_acc)
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
            filenameBest = savedir + '/model_best_enc.pth.tar'    
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'
            filenameBest = savedir + '/model_best.pth.tar'
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': str(model),
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer' : optimizer.state_dict(),
        }, is_best, filenameCheckpoint, filenameBest)

        #SAVE MODEL AFTER EPOCH
        if (enc):
            filename = f'{savedir}/model_encoder-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_encoder_best.pth'
        else:
            filename = f'{savedir}/model-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_best.pth'
        if args.epochs_save > 0 and epoch > 0 and epoch % args.epochs_save == 0:
            torch.save(model.state_dict(), filename)
            print(f'save: {filename} (epoch: {epoch})')
        if (is_best):
            torch.save(model.state_dict(), filenamebest)
            print(f'save: {filenamebest} (epoch: {epoch})')
            if (not enc):
                with open(savedir + "/best.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal))   
            else:
                with open(savedir + "/best_encoder.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal))           

        #SAVE TO FILE A ROW WITH THE EPOCH RESULT (train loss, val loss, train IoU, val IoU)
        #Epoch		Train-loss		Test-loss	Train-IoU	Test-IoU		learningRate
        with open(automated_log_path, "a") as myfile:
            myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" % (epoch, average_epoch_loss_train, average_epoch_loss_val, iouTrain, iouVal, usedLr ))
    
    return model   #return model (convenience for encoder-decoder training)
Example #5
def main():
    global args
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.vis = not args.no_vis

    # Set options
    if args.path_opt is not None:
        with open(args.path_opt, 'r') as handle:
            options = yaml.load(handle)
    if args.vis_path_opt is not None:
        with open(args.vis_path_opt, 'r') as handle:
            vis_options = yaml.load(handle)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)

    # Copy alg_%s and optim_%s into 'alg' and 'optim' based on the algo chosen on the command line
    options['use_cuda'] = args.cuda
    options['trial'] = args.trial
    options['alg'] = options['alg_%s' % args.algo]
    options['optim'] = options['optim_%s' % args.algo]
    alg_opt = options['alg']
    alg_opt['algo'] = args.algo
    model_opt = options['model']
    env_opt = options['env']
    env_opt['env-name'] = args.env_name
    log_opt = options['logs']
    optim_opt = options['optim']
    model_opt['time_scale'] = env_opt['time_scale']
    if model_opt['mode'] in ['baselinewtheta', 'phasewtheta']:
        model_opt['theta_space_mode'] = env_opt['theta_space_mode']
        model_opt['theta_sz'] = env_opt['theta_sz']
    elif model_opt['mode'] in ['baseline_lowlevel', 'phase_lowlevel']:
        model_opt['theta_space_mode'] = env_opt['theta_space_mode']

    # Check asserts
    assert (model_opt['mode'] in [
        'baseline', 'baseline_reverse', 'phasesimple', 'phasewstate',
        'baselinewtheta', 'phasewtheta', 'baseline_lowlevel', 'phase_lowlevel',
        'interpolate', 'cyclic', 'maze_baseline', 'maze_baseline_wphase'
    ])
    assert (args.algo in ['a2c', 'ppo', 'acktr'])
    if model_opt['recurrent_policy']:
        assert args.algo in ['a2c', 'ppo'
                             ], 'Recurrent policy is not implemented for ACKTR'

    # Set seed - just make the seed the trial number
    seed = args.trial
    torch.manual_seed(seed)
    if args.cuda:
        torch.cuda.manual_seed(seed)

    # Initialization
    num_updates = int(optim_opt['num_frames']
                      ) // alg_opt['num_steps'] // alg_opt['num_processes']
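    # Each update consumes num_steps * num_processes frames, so this is the number of
    # policy updates needed to reach the num_frames budget.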
    torch.set_num_threads(1)

    # Print warning
    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    # Set logging / load previous checkpoint
    logpath = os.path.join(log_opt['log_base'], model_opt['mode'],
                           log_opt['exp_name'], args.algo, args.env_name,
                           'trial%d' % args.trial)
    if len(args.resume) > 0:
        assert (os.path.isfile(os.path.join(logpath, args.resume)))
        ckpt = torch.load(os.path.join(logpath, 'ckpt.pth.tar'))
        start_update = ckpt['update_count']
    else:
        # Make directory, check before overwriting
        if os.path.isdir(logpath):
            if click.confirm(
                    'Logs directory already exists in {}. Erase?'.format(logpath),
                    default=False):
                os.system('rm -rf ' + logpath)
            else:
                return
        os.system('mkdir -p ' + logpath)
        start_update = 0

        # Save options and args
        with open(os.path.join(logpath, os.path.basename(args.path_opt)),
                  'w') as f:
            yaml.dump(options, f, default_flow_style=False)
        with open(os.path.join(logpath, 'args.yaml'), 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

        # Save git info as well
        os.system('git status > %s' % os.path.join(logpath, 'git_status.txt'))
        os.system('git diff > %s' % os.path.join(logpath, 'git_diff.txt'))
        os.system('git show > %s' % os.path.join(logpath, 'git_show.txt'))

    # Set up plotting dashboard
    dashboard = Dashboard(options,
                          vis_options,
                          logpath,
                          vis=args.vis,
                          port=args.port)

    # If interpolate mode, choose states
    if options['model']['mode'] == 'phase_lowlevel' and options['env'][
            'theta_space_mode'] == 'pretrain_interp':
        all_states = torch.load(env_opt['saved_state_file'])
        s1 = random.choice(all_states)
        s2 = random.choice(all_states)
        fixed_states = [s1, s2]
    elif model_opt['mode'] == 'interpolate':
        all_states = torch.load(env_opt['saved_state_file'])
        s1 = all_states[env_opt['s1_ind']]
        s2 = all_states[env_opt['s2_ind']]
        fixed_states = [s1, s2]
    else:
        fixed_states = None

    # Create environments
    dummy_env = make_env(args.env_name, seed, 0, logpath, options,
                         args.verbose)
    dummy_env = dummy_env()
    envs = [
        make_env(args.env_name, seed, i, logpath, options, args.verbose,
                 fixed_states) for i in range(alg_opt['num_processes'])
    ]
    if alg_opt['num_processes'] > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    # Get theta_sz for models (if applicable)
    dummy_env.reset()
    if model_opt['mode'] == 'baseline_lowlevel':
        model_opt['theta_sz'] = dummy_env.env.theta_sz
    elif model_opt['mode'] == 'phase_lowlevel':
        model_opt['theta_sz'] = dummy_env.env.env.theta_sz
    if 'theta_sz' in model_opt:
        env_opt['theta_sz'] = model_opt['theta_sz']

    # Get observation shape
    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0] * env_opt['num_stack'], *obs_shape[1:])
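    # Frame stacking: the first observation dimension is repeated num_stack times;
    # update_current_obs below shifts the stack and writes the newest frame at the end.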

    # Do vec normalize, but mask out what we don't want altered
    if len(envs.observation_space.shape) == 1:
        ignore_mask = np.zeros(envs.observation_space.shape)
        if env_opt['add_timestep']:
            ignore_mask[-1] = 1
        if model_opt['mode'] in [
                'baselinewtheta', 'phasewtheta', 'baseline_lowlevel',
                'phase_lowlevel'
        ]:
            theta_sz = env_opt['theta_sz']
            if env_opt['add_timestep']:
                ignore_mask[-(theta_sz + 1):] = 1
            else:
                ignore_mask[-theta_sz:] = 1
        if args.finetune_baseline:
            ignore_mask = dummy_env.unwrapped._get_obs_mask()
            freeze_mask, _ = dummy_env.unwrapped._get_pro_ext_mask()
            if env_opt['add_timestep']:
                ignore_mask = np.concatenate([ignore_mask, [1]])
                freeze_mask = np.concatenate([freeze_mask, [0]])
            ignore_mask = (ignore_mask + freeze_mask > 0).astype(float)
            envs = ObservationFilter(envs,
                                     ret=alg_opt['norm_ret'],
                                     has_timestep=True,
                                     noclip=env_opt['step_plus_noclip'],
                                     ignore_mask=ignore_mask,
                                     freeze_mask=freeze_mask,
                                     time_scale=env_opt['time_scale'],
                                     gamma=env_opt['gamma'])
        else:
            envs = ObservationFilter(envs,
                                     ret=alg_opt['norm_ret'],
                                     has_timestep=env_opt['add_timestep'],
                                     noclip=env_opt['step_plus_noclip'],
                                     ignore_mask=ignore_mask,
                                     time_scale=env_opt['time_scale'],
                                     gamma=env_opt['gamma'])

    # Set up algo monitoring
    alg_filename = os.path.join(logpath, 'Alg.Monitor.csv')
    alg_f = open(alg_filename, "wt")
    alg_f.write('# Alg Logging %s\n' %
                json.dumps({
                    "t_start": time.time(),
                    'env_id': dummy_env.spec and dummy_env.spec.id,
                    'mode': options['model']['mode'],
                    'name': options['logs']['exp_name']
                }))
    alg_fields = ['value_loss', 'action_loss', 'dist_entropy']
    alg_logger = csv.DictWriter(alg_f, fieldnames=alg_fields)
    alg_logger.writeheader()
    alg_f.flush()

    # Create the policy network
    actor_critic = Policy(obs_shape, envs.action_space, model_opt)
    if args.cuda:
        actor_critic.cuda()

    # Create the agent
    if envs.action_space.__class__.__name__ == "Discrete":
        action_shape = 1
    else:
        action_shape = envs.action_space.shape[0]

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               alg_opt['value_loss_coef'],
                               alg_opt['entropy_coef'],
                               lr=optim_opt['lr'],
                               eps=optim_opt['eps'],
                               alpha=optim_opt['alpha'],
                               max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         alg_opt['clip_param'],
                         alg_opt['ppo_epoch'],
                         alg_opt['num_mini_batch'],
                         alg_opt['value_loss_coef'],
                         alg_opt['entropy_coef'],
                         lr=optim_opt['lr'],
                         eps=optim_opt['eps'],
                         max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               alg_opt['value_loss_coef'],
                               alg_opt['entropy_coef'],
                               acktr=True)
    rollouts = RolloutStorage(alg_opt['num_steps'], alg_opt['num_processes'],
                              obs_shape, envs.action_space,
                              actor_critic.state_size)
    current_obs = torch.zeros(alg_opt['num_processes'], *obs_shape)

    # Update agent with loaded checkpoint
    if len(args.resume) > 0:
        # This should update both the policy network and the optimizer
        agent.load_state_dict(ckpt['agent'])

        # Set ob_rms
        envs.ob_rms = ckpt['ob_rms']
    elif len(args.other_resume) > 0:
        ckpt = torch.load(args.other_resume)

        # This updates the policy network (but not the optimizer state)
        agent.actor_critic.load_state_dict(ckpt['agent']['model'])

        # Set ob_rms
        envs.ob_rms = ckpt['ob_rms']
    elif args.finetune_baseline:
        # Load the model based on the trial number
        ckpt_base = options['lowlevel']['ckpt']
        ckpt_file = ckpt_base + '/trial%d/ckpt.pth.tar' % args.trial
        ckpt = torch.load(ckpt_file)

        # Make "input mask" that tells the model which inputs were the same from before and should be copied
        oldinput_mask, _ = dummy_env.unwrapped._get_pro_ext_mask()

        # This updates the policy network (but not the optimizer state)
        agent.actor_critic.load_state_dict_special(ckpt['agent']['model'],
                                                   oldinput_mask)

        # Set ob_rms
        old_rms = ckpt['ob_rms']
        old_size = old_rms.mean.size
        if env_opt['add_timestep']:
            old_size -= 1

        # Only copy the pro state part of it
        envs.ob_rms.mean[:old_size] = old_rms.mean[:old_size]
        envs.ob_rms.var[:old_size] = old_rms.var[:old_size]

    # Inline define our helper function for updating obs
    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if env_opt['num_stack'] > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs

    # Reset our env and rollouts
    obs = envs.reset()
    update_current_obs(obs)
    rollouts.observations[0].copy_(current_obs)
    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([alg_opt['num_processes'], 1])
    final_rewards = torch.zeros([alg_opt['num_processes'], 1])

    # Update loop
    start = time.time()
    for j in range(start_update, num_updates):
        for step in range(alg_opt['num_steps']):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step], rollouts.states[step],
                    rollouts.masks[step])
            cpu_actions = action.squeeze(1).cpu().numpy()

            # Observe reward and next obs
            obs, reward, done, info = envs.step(cpu_actions)
            #pdb.set_trace()
            reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                     1)).float()
            episode_rewards += reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            final_rewards *= masks
            final_rewards += (1 - masks) * episode_rewards
            episode_rewards *= masks
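            # masks is 0 for processes whose episode just ended; episode_rewards keeps the
            # running return, and final_rewards holds the return of the last finished episode.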

            if args.cuda:
                masks = masks.cuda()

            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks

            update_current_obs(obs)
            rollouts.insert(current_obs, states, action, action_log_prob,
                            value, reward, masks)

        # Update model and rollouts
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1],
                                                rollouts.states[-1],
                                                rollouts.masks[-1]).detach()
        rollouts.compute_returns(next_value, alg_opt['use_gae'],
                                 env_opt['gamma'], alg_opt['gae_tau'])
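        # Bootstrap from the value of the final observation, then compute (optionally
        # GAE-smoothed) returns for the rollout before the policy/value update.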
        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        # Add algo updates here
        alg_info = {}
        alg_info['value_loss'] = value_loss
        alg_info['action_loss'] = action_loss
        alg_info['dist_entropy'] = dist_entropy
        alg_logger.writerow(alg_info)
        alg_f.flush()

        # Save checkpoints
        total_num_steps = (j +
                           1) * alg_opt['num_processes'] * alg_opt['num_steps']
        #save_interval = log_opt['save_interval'] * alg_opt['log_mult']
        save_interval = 100
        if j % save_interval == 0:
            # Save all of our important information
            save_checkpoint(logpath,
                            agent,
                            envs,
                            j,
                            total_num_steps,
                            args.save_every,
                            final=False)

        # Print log
        log_interval = log_opt['log_interval'] * alg_opt['log_mult']
        if j % log_interval == 0:
            end = time.time()
            print(
                "{}: Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}"
                .format(options['logs']['exp_name'], j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        final_rewards.mean(), final_rewards.median(),
                        final_rewards.min(), final_rewards.max(), dist_entropy,
                        value_loss, action_loss))

        # Do dashboard logging
        vis_interval = log_opt['vis_interval'] * alg_opt['log_mult']
        if args.vis and j % vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                dashboard.visdom_plot()
            except IOError:
                pass

    # Save final checkpoint
    save_checkpoint(logpath,
                    agent,
                    envs,
                    j,
                    total_num_steps,
                    args.save_every,
                    final=False)

    # Close logging file
    alg_f.close()
Example #6
def main():
    global args
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.vis = not args.no_vis

    # Set options
    if args.path_opt is not None:
        with open(args.path_opt, 'r') as handle:
            options = yaml.load(handle)
    if args.vis_path_opt is not None:
        with open(args.vis_path_opt, 'r') as handle:
            vis_options = yaml.load(handle)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)

    # Load the low-level options file
    lowlevel_optfile = options['lowlevel']['optfile']
    with open(lowlevel_optfile, 'r') as handle:
        ll_opt = yaml.load(handle)

    # Whether we should set ll policy to be deterministic or not
    ll_deterministic = options['lowlevel']['deterministic']

    # Copy alg_%s and optim_%s into 'alg' and 'optim' based on the algo chosen on the command line
    options['use_cuda'] = args.cuda
    options['trial'] = args.trial
    options['alg'] = options['alg_%s' % args.algo]
    options['optim'] = options['optim_%s' % args.algo]
    alg_opt = options['alg']
    alg_opt['algo'] = args.algo
    model_opt = options['model']
    env_opt = options['env']
    env_opt['env-name'] = args.env_name
    log_opt = options['logs']
    optim_opt = options['optim']
    options['lowlevel_opt'] = ll_opt  # Save low-level options in the option file (for logging purposes)

    # Pass necessary values in ll_opt
    assert (ll_opt['model']['mode'] in ['baseline_lowlevel', 'phase_lowlevel'])
    ll_opt['model']['theta_space_mode'] = ll_opt['env']['theta_space_mode']
    ll_opt['model']['time_scale'] = ll_opt['env']['time_scale']

    # If in many module mode, load the lowlevel policies we want
    if model_opt['mode'] == 'hierarchical_many':
        # Check asserts
        theta_obs_mode = ll_opt['env']['theta_obs_mode']
        theta_space_mode = ll_opt['env']['theta_space_mode']
        assert (theta_space_mode in [
            'pretrain_interp', 'pretrain_any', 'pretrain_any_far',
            'pretrain_any_fromstart'
        ])
        assert (theta_obs_mode == 'pretrain')

        # Get the theta size
        theta_sz = options['lowlevel']['num_load']
        ckpt_base = options['lowlevel']['ckpt']

        # Load checkpoints
        lowlevel_ckpts = []
        for ll_ind in range(theta_sz):
            if args.change_ll_offset:
                ll_offset = theta_sz * args.trial
            else:
                ll_offset = 0
            lowlevel_ckpt_file = ckpt_base + '/trial%d/ckpt.pth.tar' % (
                ll_ind + ll_offset)
            assert (os.path.isfile(lowlevel_ckpt_file))
            lowlevel_ckpts.append(torch.load(lowlevel_ckpt_file))

    # Otherwise there is a single low-level policy to load
    else:
        # Get theta_sz for low level model
        theta_obs_mode = ll_opt['env']['theta_obs_mode']
        theta_space_mode = ll_opt['env']['theta_space_mode']
        assert (theta_obs_mode in ['ind', 'vector'])
        if theta_obs_mode == 'ind':
            if theta_space_mode == 'forward':
                theta_sz = 1
            elif theta_space_mode == 'simple_four':
                theta_sz = 4
            elif theta_space_mode == 'simple_eight':
                theta_sz = 8
            elif theta_space_mode == 'k_theta':
                theta_sz = ll_opt['env']['num_theta']
            else:
                raise NotImplementedError
        elif theta_obs_mode == 'vector':
            theta_sz = 2
        else:
            raise NotImplementedError
        ll_opt['model']['theta_sz'] = theta_sz
        ll_opt['env']['theta_sz'] = theta_sz

        # Load the low level policy params
        lowlevel_ckpt = options['lowlevel']['ckpt']
        assert (os.path.isfile(lowlevel_ckpt))
        lowlevel_ckpt = torch.load(lowlevel_ckpt)
    hl_action_space = spaces.Discrete(theta_sz)
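    # The high-level policy acts in a discrete space of size theta_sz: each action selects
    # which low-level theta (skill) the low-level policy should execute next.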

    # Check asserts
    assert (args.algo in ['a2c', 'ppo', 'acktr', 'dqn'])
    assert (optim_opt['hierarchical_mode']
            in ['train_highlevel', 'train_both'])
    if model_opt['recurrent_policy']:
        assert args.algo in ['a2c', 'ppo'
                             ], 'Recurrent policy is not implemented for ACKTR'
    assert (model_opt['mode'] in ['hierarchical', 'hierarchical_many'])

    # Set seed - just make the seed the trial number
    seed = args.trial + 1000  # Make it different than lowlevel seed
    torch.manual_seed(seed)
    if args.cuda:
        torch.cuda.manual_seed(seed)

    # Initialization
    num_updates = int(optim_opt['num_frames']) // alg_opt[
        'num_steps'] // alg_opt['num_processes'] // optim_opt['num_ll_steps']
    torch.set_num_threads(1)

    # Print warning
    print("#######")
    print(
        "WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards"
    )
    print("#######")

    # Set logging / load previous checkpoint
    logpath = os.path.join(log_opt['log_base'], model_opt['mode'],
                           log_opt['exp_name'], args.algo, args.env_name,
                           'trial%d' % args.trial)
    if len(args.resume) > 0:
        assert (os.path.isfile(os.path.join(logpath, args.resume)))
        ckpt = torch.load(os.path.join(logpath, 'ckpt.pth.tar'))
        start_update = ckpt['update_count']
    else:
        # Make directory, check before overwriting
        if os.path.isdir(logpath):
            if click.confirm(
                    'Logs directory already exists in {}. Erase?'.format(logpath),
                    default=False):
                os.system('rm -rf ' + logpath)
            else:
                return
        os.system('mkdir -p ' + logpath)
        start_update = 0

        # Save options and args
        with open(os.path.join(logpath, os.path.basename(args.path_opt)),
                  'w') as f:
            yaml.dump(options, f, default_flow_style=False)
        with open(os.path.join(logpath, 'args.yaml'), 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

        # Save git info as well
        os.system('git status > %s' % os.path.join(logpath, 'git_status.txt'))
        os.system('git diff > %s' % os.path.join(logpath, 'git_diff.txt'))
        os.system('git show > %s' % os.path.join(logpath, 'git_show.txt'))

    # Set up plotting dashboard
    dashboard = Dashboard(options,
                          vis_options,
                          logpath,
                          vis=args.vis,
                          port=args.port)

    # Create environments
    envs = [
        make_env(args.env_name, seed, i, logpath, options, args.verbose)
        for i in range(alg_opt['num_processes'])
    ]
    if alg_opt['num_processes'] > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    # Check if we use timestep in low level
    if 'baseline' in ll_opt['model']['mode']:
        add_timestep = False
    elif 'phase' in ll_opt['model']['mode']:
        add_timestep = True
    else:
        raise NotImplementedError

    # Get shapes
    dummy_env = make_env(args.env_name, seed, 0, logpath, options,
                         args.verbose)
    dummy_env = dummy_env()
    s_pro_dummy = dummy_env.unwrapped._get_pro_obs()
    s_ext_dummy = dummy_env.unwrapped._get_ext_obs()
    if add_timestep:
        ll_obs_shape = (s_pro_dummy.shape[0] + theta_sz + 1, )
        ll_raw_obs_shape = (s_pro_dummy.shape[0] + 1, )
    else:
        ll_obs_shape = (s_pro_dummy.shape[0] + theta_sz, )
        ll_raw_obs_shape = (s_pro_dummy.shape[0], )
    ll_obs_shape = (ll_obs_shape[0] * env_opt['num_stack'], *ll_obs_shape[1:])
    hl_obs_shape = (s_ext_dummy.shape[0], )
    hl_obs_shape = (hl_obs_shape[0] * env_opt['num_stack'], *hl_obs_shape[1:])

    # Do vec normalize, but mask out what we don't want altered
    # Also freeze all of the low level obs
    ignore_mask = dummy_env.env._get_obs_mask()
    freeze_mask, _ = dummy_env.unwrapped._get_pro_ext_mask()
    freeze_mask = np.concatenate([freeze_mask, [0]])
    if ('normalize' in env_opt
            and not env_opt['normalize']) or args.algo == 'dqn':
        ignore_mask = 1 - freeze_mask
    if model_opt['mode'] == 'hierarchical_many':
        # Actually ignore both ignored values and the low level values
        # That filtering will happen later
        ignore_mask = (ignore_mask + freeze_mask > 0).astype(float)
        envs = ObservationFilter(envs,
                                 ret=alg_opt['norm_ret'],
                                 has_timestep=True,
                                 noclip=env_opt['step_plus_noclip'],
                                 ignore_mask=ignore_mask,
                                 freeze_mask=freeze_mask,
                                 time_scale=env_opt['time_scale'],
                                 gamma=env_opt['gamma'])
    else:
        envs = ObservationFilter(envs,
                                 ret=alg_opt['norm_ret'],
                                 has_timestep=True,
                                 noclip=env_opt['step_plus_noclip'],
                                 ignore_mask=ignore_mask,
                                 freeze_mask=freeze_mask,
                                 time_scale=env_opt['time_scale'],
                                 gamma=env_opt['gamma'])

    # Make our helper object for dealing with hierarchical observations
    hier_utils = HierarchyUtils(ll_obs_shape, hl_obs_shape, hl_action_space,
                                theta_sz, add_timestep)

    # Set up algo monitoring
    alg_filename = os.path.join(logpath, 'Alg.Monitor.csv')
    alg_f = open(alg_filename, "wt")
    alg_f.write('# Alg Logging %s\n' %
                json.dumps({
                    "t_start": time.time(),
                    'env_id': dummy_env.spec and dummy_env.spec.id,
                    'mode': options['model']['mode'],
                    'name': options['logs']['exp_name']
                }))
    alg_fields = ['value_loss', 'action_loss', 'dist_entropy']
    alg_logger = csv.DictWriter(alg_f, fieldnames=alg_fields)
    alg_logger.writeheader()
    alg_f.flush()
    ll_alg_filename = os.path.join(logpath, 'AlgLL.Monitor.csv')
    ll_alg_f = open(ll_alg_filename, "wt")
    ll_alg_f.write('# Alg Logging LL %s\n' %
                   json.dumps({
                       "t_start": time.time(),
                       'env_id': dummy_env.spec and dummy_env.spec.id,
                       'mode': options['model']['mode'],
                       'name': options['logs']['exp_name']
                   }))
    ll_alg_fields = ['value_loss', 'action_loss', 'dist_entropy']
    ll_alg_logger = csv.DictWriter(ll_alg_f, fieldnames=ll_alg_fields)
    ll_alg_logger.writeheader()
    ll_alg_f.flush()

    # Create the policy networks
    ll_action_space = envs.action_space
    if args.algo == 'dqn':
        model_opt['eps_start'] = optim_opt['eps_start']
        model_opt['eps_end'] = optim_opt['eps_end']
        model_opt['eps_decay'] = optim_opt['eps_decay']
        hl_policy = DQNPolicy(hl_obs_shape, hl_action_space, model_opt)
    else:
        hl_policy = Policy(hl_obs_shape, hl_action_space, model_opt)
    if model_opt['mode'] == 'hierarchical_many':
        ll_policy = ModularPolicy(ll_raw_obs_shape, ll_action_space, theta_sz,
                                  ll_opt)
    else:
        ll_policy = Policy(ll_obs_shape, ll_action_space, ll_opt['model'])
    # Load the previous ones here?
    if args.cuda:
        hl_policy.cuda()
        ll_policy.cuda()

    # Create the high level agent
    if args.algo == 'a2c':
        hl_agent = algo.A2C_ACKTR(hl_policy,
                                  alg_opt['value_loss_coef'],
                                  alg_opt['entropy_coef'],
                                  lr=optim_opt['lr'],
                                  eps=optim_opt['eps'],
                                  alpha=optim_opt['alpha'],
                                  max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'ppo':
        hl_agent = algo.PPO(hl_policy,
                            alg_opt['clip_param'],
                            alg_opt['ppo_epoch'],
                            alg_opt['num_mini_batch'],
                            alg_opt['value_loss_coef'],
                            alg_opt['entropy_coef'],
                            lr=optim_opt['lr'],
                            eps=optim_opt['eps'],
                            max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'acktr':
        hl_agent = algo.A2C_ACKTR(hl_policy,
                                  alg_opt['value_loss_coef'],
                                  alg_opt['entropy_coef'],
                                  acktr=True)
    elif args.algo == 'dqn':
        hl_agent = algo.DQN(hl_policy,
                            env_opt['gamma'],
                            batch_size=alg_opt['batch_size'],
                            target_update=alg_opt['target_update'],
                            mem_capacity=alg_opt['mem_capacity'],
                            lr=optim_opt['lr'],
                            eps=optim_opt['eps'],
                            max_grad_norm=optim_opt['max_grad_norm'])

    # Create the low level agent
    # If only training high level, make dummy agent (just does passthrough, doesn't change anything)
    if optim_opt['hierarchical_mode'] == 'train_highlevel':
        ll_agent = algo.Passthrough(ll_policy)
    elif optim_opt['hierarchical_mode'] == 'train_both':
        if args.algo == 'a2c':
            ll_agent = algo.A2C_ACKTR(ll_policy,
                                      alg_opt['value_loss_coef'],
                                      alg_opt['entropy_coef'],
                                      lr=optim_opt['ll_lr'],
                                      eps=optim_opt['eps'],
                                      alpha=optim_opt['alpha'],
                                      max_grad_norm=optim_opt['max_grad_norm'])
        elif args.algo == 'ppo':
            ll_agent = algo.PPO(ll_policy,
                                alg_opt['clip_param'],
                                alg_opt['ll_ppo_epoch'],
                                alg_opt['num_mini_batch'],
                                alg_opt['value_loss_coef'],
                                alg_opt['entropy_coef'],
                                lr=optim_opt['ll_lr'],
                                eps=optim_opt['eps'],
                                max_grad_norm=optim_opt['max_grad_norm'])
        elif args.algo == 'acktr':
            ll_agent = algo.A2C_ACKTR(ll_policy,
                                      alg_opt['value_loss_coef'],
                                      alg_opt['entropy_coef'],
                                      acktr=True)
    else:
        raise NotImplementedError

    # Make the rollout structures
    hl_rollouts = RolloutStorage(alg_opt['num_steps'],
                                 alg_opt['num_processes'], hl_obs_shape,
                                 hl_action_space, hl_policy.state_size)
    ll_rollouts = MaskingRolloutStorage(alg_opt['num_steps'],
                                        alg_opt['num_processes'], ll_obs_shape,
                                        ll_action_space, ll_policy.state_size)
    hl_current_obs = torch.zeros(alg_opt['num_processes'], *hl_obs_shape)
    ll_current_obs = torch.zeros(alg_opt['num_processes'], *ll_obs_shape)

    # Helper functions to update the current obs
    def update_hl_current_obs(obs):
        shape_dim0 = hl_obs_shape[0]
        obs = torch.from_numpy(obs).float()
        if env_opt['num_stack'] > 1:
            hl_current_obs[:, :-shape_dim0] = hl_current_obs[:, shape_dim0:]
        hl_current_obs[:, -shape_dim0:] = obs

    def update_ll_current_obs(obs):
        shape_dim0 = ll_obs_shape[0]
        obs = torch.from_numpy(obs).float()
        if env_opt['num_stack'] > 1:
            ll_current_obs[:, :-shape_dim0] = ll_current_obs[:, shape_dim0:]
        ll_current_obs[:, -shape_dim0:] = obs

    # Update agent with loaded checkpoint
    if len(args.resume) > 0:
        # This should update both the policy network and the optimizer
        ll_agent.load_state_dict(ckpt['ll_agent'])
        hl_agent.load_state_dict(ckpt['hl_agent'])

        # Set ob_rms
        envs.ob_rms = ckpt['ob_rms']
    else:
        if model_opt['mode'] == 'hierarchical_many':
            ll_agent.load_pretrained_policies(lowlevel_ckpts)
        else:
            # Load low level agent
            ll_agent.load_state_dict(lowlevel_ckpt['agent'])

            # Load ob_rms from low level (but need to reshape it)
            old_rms = lowlevel_ckpt['ob_rms']
            assert (old_rms.mean.shape[0] == ll_obs_shape[0])
            # Only copy the pro state part of it (not including thetas or count)
            envs.ob_rms.mean[:s_pro_dummy.shape[0]] = old_rms.mean[:s_pro_dummy.shape[0]]
            envs.ob_rms.var[:s_pro_dummy.shape[0]] = old_rms.var[:s_pro_dummy.shape[0]]

    # Reset our env and rollouts
    raw_obs = envs.reset()
    hl_obs, raw_ll_obs, step_counts = hier_utils.seperate_obs(raw_obs)
    ll_obs = hier_utils.placeholder_theta(raw_ll_obs, step_counts)
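    # The low-level observation starts with a placeholder theta; it is filled in each step
    # once the high-level policy has chosen an action (see update_theta in the update loop).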
    update_hl_current_obs(hl_obs)
    update_ll_current_obs(ll_obs)
    hl_rollouts.observations[0].copy_(hl_current_obs)
    ll_rollouts.observations[0].copy_(ll_current_obs)
    ll_rollouts.recent_obs.copy_(ll_current_obs)
    if args.cuda:
        hl_current_obs = hl_current_obs.cuda()
        ll_current_obs = ll_current_obs.cuda()
        hl_rollouts.cuda()
        ll_rollouts.cuda()

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros([alg_opt['num_processes'], 1])
    final_rewards = torch.zeros([alg_opt['num_processes'], 1])

    # Update loop
    start = time.time()
    for j in range(start_update, num_updates):
        for step in range(alg_opt['num_steps']):
            # Step through high level action
            start_time = time.time()
            with torch.no_grad():
                hl_value, hl_action, hl_action_log_prob, hl_states = hl_policy.act(
                    hl_rollouts.observations[step], hl_rollouts.states[step],
                    hl_rollouts.masks[step])
            hl_cpu_actions = hl_action.squeeze(1).cpu().numpy()
            if args.profile:
                print('hl act %f' % (time.time() - start_time))

            # Get values to use for Q learning
            hl_state_dqn = hl_rollouts.observations[step]
            hl_action_dqn = hl_action

            # Update last ll observation with new theta
            for proc in range(alg_opt['num_processes']):
                # Update last observations in memory
                last_obs = ll_rollouts.observations[ll_rollouts.steps[proc],
                                                    proc]
                if hier_utils.has_placeholder(last_obs):
                    new_last_obs = hier_utils.update_theta(
                        last_obs, hl_cpu_actions[proc])
                    ll_rollouts.observations[ll_rollouts.steps[proc],
                                             proc].copy_(new_last_obs)

                # Update most recent observations (not necessarily the same)
                assert (hier_utils.has_placeholder(
                    ll_rollouts.recent_obs[proc]))
                new_last_obs = hier_utils.update_theta(
                    ll_rollouts.recent_obs[proc], hl_cpu_actions[proc])
                ll_rollouts.recent_obs[proc].copy_(new_last_obs)
            assert (ll_rollouts.observations.max().item() < float('inf')
                    and ll_rollouts.recent_obs.max().item() < float('inf'))

            # Given high level action, step through the low level actions
            death_step_mask = np.ones([alg_opt['num_processes'],
                                       1])  # 1 means still alive, 0 means dead
            hl_reward = torch.zeros([alg_opt['num_processes'], 1])
            hl_obs = [None for i in range(alg_opt['num_processes'])]
            for ll_step in range(optim_opt['num_ll_steps']):
                # Sample actions
                start_time = time.time()
                with torch.no_grad():
                    ll_value, ll_action, ll_action_log_prob, ll_states = ll_policy.act(
                        ll_rollouts.recent_obs,
                        ll_rollouts.recent_s,
                        ll_rollouts.recent_masks,
                        deterministic=ll_deterministic)
                ll_cpu_actions = ll_action.squeeze(1).cpu().numpy()
                if args.profile:
                    print('ll act %f' % (time.time() - start_time))

                # Observe reward and next obs
                raw_obs, ll_reward, done, info = envs.step(
                    ll_cpu_actions, death_step_mask)
                raw_hl_obs, raw_ll_obs, step_counts = hier_utils.seperate_obs(
                    raw_obs)
                ll_obs = []
                for proc in range(alg_opt['num_processes']):
                    if (ll_step
                            == optim_opt['num_ll_steps'] - 1) or done[proc]:
                        ll_obs.append(
                            hier_utils.placeholder_theta(
                                np.array([raw_ll_obs[proc]]),
                                np.array([step_counts[proc]])))
                    else:
                        ll_obs.append(
                            hier_utils.append_theta(
                                np.array([raw_ll_obs[proc]]),
                                np.array([hl_cpu_actions[proc]]),
                                np.array([step_counts[proc]])))
                ll_obs = np.concatenate(ll_obs, 0)
                ll_reward = torch.from_numpy(
                    np.expand_dims(np.stack(ll_reward), 1)).float()
                episode_rewards += ll_reward
                hl_reward += ll_reward

                # Update values for Q learning and update replay memory
                start_time = time.time()
                hl_next_state_dqn = torch.from_numpy(raw_hl_obs)
                hl_reward_dqn = ll_reward
                hl_isdone_dqn = done
                if args.algo == 'dqn':
                    hl_agent.update_memory(hl_state_dqn, hl_action_dqn,
                                           hl_next_state_dqn, hl_reward_dqn,
                                           hl_isdone_dqn, death_step_mask)
                hl_state_dqn = hl_next_state_dqn
                if args.profile:
                    print('dqn memory %f' % (time.time() - start_time))

                # Update high level observations (only take the most recent obs if we haven't seen a done before now and thus the value is valid)
                for proc, raw_hl in enumerate(raw_hl_obs):
                    if death_step_mask[proc].item() > 0:
                        hl_obs[proc] = np.array([raw_hl])

                # If done then clean the history of observations
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done])
                final_rewards *= masks
                final_rewards += (
                    1 - masks
                ) * episode_rewards  # TODO - actually not sure if I broke this logic, but this value is not used anywhere
                episode_rewards *= masks

                # TODO - I commented this out, which possibly breaks things if num_stack > 1. Fix later if necessary
                #if args.cuda:
                #    masks = masks.cuda()
                #if current_obs.dim() == 4:
                #    current_obs *= masks.unsqueeze(2).unsqueeze(2)
                #else:
                #    current_obs *= masks

                # Update low level observations
                update_ll_current_obs(ll_obs)

                # Update low level rollouts
                ll_rollouts.insert(ll_current_obs, ll_states, ll_action,
                                   ll_action_log_prob, ll_value, ll_reward,
                                   masks, death_step_mask)

                # Update which ones have stepped to the end and shouldn't be updated next time in the loop
                death_step_mask *= masks

            # Update high level rollouts
            hl_obs = np.concatenate(hl_obs, 0)
            update_hl_current_obs(hl_obs)
            hl_rollouts.insert(hl_current_obs, hl_states, hl_action,
                               hl_action_log_prob, hl_value, hl_reward, masks)

            # Check if we want to update the low-level policy
            if ll_rollouts.isfull and all([
                    not hier_utils.has_placeholder(
                        ll_rollouts.observations[ll_rollouts.steps[proc],
                                                 proc])
                    for proc in range(alg_opt['num_processes'])
            ]):
                # Update low level policy
                assert (ll_rollouts.observations.max().item() < float('inf'))
                if optim_opt['hierarchical_mode'] == 'train_both':
                    with torch.no_grad():
                        ll_next_value = ll_policy.get_value(
                            ll_rollouts.observations[-1],
                            ll_rollouts.states[-1],
                            ll_rollouts.masks[-1]).detach()
                    ll_rollouts.compute_returns(ll_next_value,
                                                alg_opt['use_gae'],
                                                env_opt['gamma'],
                                                alg_opt['gae_tau'])
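                    # With use_gae enabled, compute_returns is, in outline,
                    # standard GAE (a sketch, not this storage class's exact code):
                    #   gae = 0
                    #   for t in reversed(range(num_steps)):
                    #       delta = rewards[t] + gamma * values[t + 1] * masks[t + 1] - values[t]
                    #       gae = delta + gamma * gae_tau * masks[t + 1] * gae
                    #       returns[t] = gae + values[t]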
                    ll_value_loss, ll_action_loss, ll_dist_entropy = ll_agent.update(
                        ll_rollouts)
                else:
                    ll_value_loss = 0
                    ll_action_loss = 0
                    ll_dist_entropy = 0
                ll_rollouts.after_update()

                # Update logger
                alg_info = {}
                alg_info['value_loss'] = ll_value_loss
                alg_info['action_loss'] = ll_action_loss
                alg_info['dist_entropy'] = ll_dist_entropy
                ll_alg_logger.writerow(alg_info)
                ll_alg_f.flush()

        # Update high level policy
        start_time = time.time()
        assert (hl_rollouts.observations.max().item() < float('inf'))
        if args.algo == 'dqn':
            hl_value_loss, hl_action_loss, hl_dist_entropy = hl_agent.update(
                alg_opt['updates_per_step']
            )  # TODO - maybe log this loss properly
        else:
            with torch.no_grad():
                hl_next_value = hl_policy.get_value(
                    hl_rollouts.observations[-1], hl_rollouts.states[-1],
                    hl_rollouts.masks[-1]).detach()
            hl_rollouts.compute_returns(hl_next_value, alg_opt['use_gae'],
                                        env_opt['gamma'], alg_opt['gae_tau'])
            hl_value_loss, hl_action_loss, hl_dist_entropy = hl_agent.update(
                hl_rollouts)
        hl_rollouts.after_update()
        if args.profile:
            print('hl update %f' % (time.time() - start_time))

        # Update alg monitor for high level
        alg_info = {}
        alg_info['value_loss'] = hl_value_loss
        alg_info['action_loss'] = hl_action_loss
        alg_info['dist_entropy'] = hl_dist_entropy
        alg_logger.writerow(alg_info)
        alg_f.flush()

        # Save checkpoints
        total_num_steps = (j + 1) * alg_opt['num_processes'] * alg_opt[
            'num_steps'] * optim_opt['num_ll_steps']
        if 'save_interval' in alg_opt:
            save_interval = alg_opt['save_interval']
        else:
            save_interval = 100
        if j % save_interval == 0:
            # Save all of our important information
            start_time = time.time()
            save_checkpoint(logpath, ll_agent, hl_agent, envs, j,
                            total_num_steps)
            if args.profile:
                print('save checkpoint %f' % (time.time() - start_time))

        # Print log
        log_interval = log_opt['log_interval'] * alg_opt['log_mult']
        if j % log_interval == 0:
            end = time.time()
            print(
                "{}: Updates {}, num timesteps {}, FPS {}, mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}, entropy {:.5f}, value loss {:.5f}, policy loss {:.5f}"
                .format(options['logs']['exp_name'], j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        final_rewards.mean(), final_rewards.median(),
                        final_rewards.min(), final_rewards.max(),
                        hl_dist_entropy, hl_value_loss, hl_action_loss))

        # Do dashboard logging
        vis_interval = log_opt['vis_interval'] * alg_opt['log_mult']
        if args.vis and j % vis_interval == 0:
            try:
                # Sometimes monitor doesn't properly flush the outputs
                dashboard.visdom_plot()
            except IOError:
                pass

    # Save final checkpoint
    save_checkpoint(logpath, ll_agent, hl_agent, envs, j, total_num_steps)

    # Close logging file
    alg_f.close()
    ll_alg_f.close()
Exemple #7
0
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

import medicalDataLoader
from criterion import CrossEntropyLoss2d
from enet import Enet
from utils import pred2segmentation, Colorize
from visualize import Dashboard
from utils import show_image_mask
from pretrain_network import pretrain

board_image = Dashboard(server='http://turing.livia.etsmtl.ca', env="ADMM_image")
board_loss = Dashboard(server='http://turing.livia.etsmtl.ca', env="ADMM_loss")
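# Dashboard here is presumably a thin wrapper around a visdom client exposing
# an image(tensor, title) helper; a minimal sketch under that assumption (the
# actual class lives in visualize.py and may differ):
#   import visdom
#   class Dashboard:
#       def __init__(self, server='http://localhost', env='main'):
#           self.vis = visdom.Visdom(server=server, env=env)
#       def image(self, tensor, title):
#           self.vis.image(tensor.cpu().numpy(), opts=dict(title=title))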

use_gpu = True
# device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"
device = torch.device('cuda')

batch_size = 1
batch_size_val = 1
num_workers = 0
lr = 0.001
max_epoch = 100
root_dir = '../ACDC-2D-All'
model_dir = 'model'
size_min = 5
size_max = 20
Exemple #8
0
def train(args, model, enc=False):
    best_acc = 0

    #TODO: calculate weights by processing the dataset histogram (now it's being set by hand from the torch values)
    #create a loader to run over all images and compute the label histogram, then build the weight array using class balancing
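    # A minimal sketch of the class-balancing weight computation described in
    # the TODO above (an assumption, not this repo's code; `loader` is the
    # training DataLoader defined below, weights follow the common
    # 1 / log(1.02 + class_frequency) scheme):
    #   hist = torch.zeros(NUM_CLASSES)
    #   for _, labels in loader:
    #       hist += torch.bincount(labels.view(-1), minlength=NUM_CLASSES).float()
    #   freq = hist / hist.sum()
    #   weight = 1.0 / torch.log(1.02 + freq)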

    weight = torch.ones(NUM_CLASSES)
    if (enc):        
        weight[0] = 4.38133159
        weight[1] = 1.29574148
    else:
        weight[0] = 4.40513628
        weight[1] = 1.293674
        
    if (enc):
        up = torch.nn.Upsample(scale_factor=16, mode='bilinear')
    else:
        up = torch.nn.Upsample(scale_factor=2, mode='bilinear')
        
    if args.cuda:
        up = up.cuda()

    assert os.path.exists(args.datadir), "Error: datadir (dataset directory) could not be loaded"

    co_transform = MyCoTransform(enc, augment=True, height=args.height)#1024)
    co_transform_val = MyCoTransform(enc, augment=False, height=args.height)#1024)
    dataset_train = cityscapes(args.datadir, co_transform, 'train')
    dataset_val = cityscapes(args.datadir, co_transform_val, 'val')

    loader = DataLoader(dataset_train, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)
    loader_val = DataLoader(dataset_val, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    if args.cuda:
        weight = weight.cuda()
  
    if args.weighted:
        criterion = CrossEntropyLoss2d(weight)
    else:            
        criterion = CrossEntropyLoss2d()
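    # CrossEntropyLoss2d is typically a thin wrapper around NLLLoss on
    # log-softmax outputs; a sketch of what the imported helper is assumed to
    # do (not a verified copy of it):
    #   class CrossEntropyLoss2d(torch.nn.Module):
    #       def __init__(self, weight=None):
    #           super().__init__()
    #           self.loss = torch.nn.NLLLoss(weight)
    #       def forward(self, outputs, targets):
    #           return self.loss(torch.nn.functional.log_softmax(outputs, dim=1), targets)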
        
    print(type(criterion))

    savedir = args.savedir

    if (enc):
        automated_log_path = savedir + "/automated_log_encoder.txt"
        modeltxtpath = savedir + "/model_encoder.txt"
    else:
        automated_log_path = savedir + "/automated_log.txt"
        modeltxtpath = savedir + "/model.txt"    

    if (not os.path.exists(automated_log_path)):    #don't add the header line if the file already exists
        with open(automated_log_path, "a") as myfile:
            myfile.write("Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate")

    with open(modeltxtpath, "w") as myfile:
        myfile.write(str(model))


    #TODO: reduce memory in first gpu: https://discuss.pytorch.org/t/multi-gpu-training-memory-usage-in-balance/4163/4        #https://github.com/pytorch/pytorch/issues/1893

    #optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999),  eps=1e-08, weight_decay=2e-4)     ## scheduler 1
    optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999),  eps=1e-08, weight_decay=1e-4)      ## scheduler 2

    start_epoch = 1
    if args.resume:
        #Must load weights, optimizer, epoch and best value. 
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'

        assert os.path.exists(filenameCheckpoint), "Error: resume option was used but checkpoint was not found in folder"
        checkpoint = torch.load(filenameCheckpoint)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_acc = checkpoint['best_acc']
        print("=> Loaded checkpoint at epoch {})".format(checkpoint['epoch']))

    #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5) # set up scheduler     ## scheduler 1
    lambda1 = lambda epoch: pow((1-((epoch-1)/args.num_epochs)),0.9)  ## scheduler 2
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)                             ## scheduler 2
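    # Worked example of the poly decay above: with args.num_epochs = 100 the
    # multiplier is 1.0 at epoch 1, (0.5)**0.9 ~= 0.54 at epoch 51 and
    # (0.01)**0.9 ~= 0.016 at epoch 100, so the lr decays from 5e-4 to ~8e-6.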

    if args.visualize and args.steps_plot > 0:
        board = Dashboard(args.port)

    for epoch in range(start_epoch, args.num_epochs+1):
        print("----- TRAINING - EPOCH", epoch, "-----")

        scheduler.step(epoch)    ## scheduler 2

        epoch_loss = []
        time_train = []
     
        doIouTrain = args.iouTrain   
        doIouVal =  args.iouVal      

        if (doIouTrain):
            iouEvalTrain = iouEval(NUM_CLASSES, args.ignoreindex)

        usedLr = 0
        for param_group in optimizer.param_groups:
            print("LEARNING RATE: ", param_group['lr'])
            usedLr = float(param_group['lr'])

        model.train()
        for step, (images, labels, images_orig, labels_orig) in enumerate(loader):

            start_time = time.time()
            #print (labels.size())
            #print (np.unique(labels.numpy()))
            #print("labels: ", np.unique(labels[0].numpy()))
            #labels = torch.ones(4, 1, 512, 1024).long()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            inputs = Variable(images)
            targets = Variable(labels)
            outputs = model(inputs, only_encode=enc)

            #print("targets", np.unique(targets[:, 0].cpu().data.numpy()))

            optimizer.zero_grad()
            loss = criterion(outputs, targets[:, 0])
            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.data[0])
            time_train.append(time.time() - start_time)

            if (doIouTrain):
                #start_time_iou = time.time()
                upsampledOutputs = up(outputs)
                iouEvalTrain.addBatch(upsampledOutputs.max(1)[1].unsqueeze(1).data, labels_orig)
                #print ("Time to add confusion matrix: ", time.time() - start_time_iou)      

            #print(outputs.size())
            if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
                start_time_plot = time.time()
                image = inputs[0].cpu().data
                #image[0] = image[0] * .229 + .485
                #image[1] = image[1] * .224 + .456
                #image[2] = image[2] * .225 + .406
                #print("output", np.unique(outputs[0].cpu().max(0)[1].data.numpy()))
                board.image(image, f'input (epoch: {epoch}, step: {step})')
                if isinstance(outputs, list):   #merge gpu tensors
                    board.image(color_transform(outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)),
                    f'output (epoch: {epoch}, step: {step})')
                else:
                    board.image(color_transform(outputs[0].cpu().max(0)[1].data.unsqueeze(0)),
                    f'output (epoch: {epoch}, step: {step})')
                board.image(color_transform(targets[0].cpu().data),
                    f'target (epoch: {epoch}, step: {step})')
                print ("Time to paint images: ", time.time() - start_time_plot)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss) / len(epoch_loss)
                print(f'loss: {average:0.4} (epoch: {epoch}, step: {step})', 
                        "// Avg time/img: %.4f s" % (sum(time_train) / len(time_train) / args.batch_size))

            
        average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss)
        
        iouTrain = 0
        if (doIouTrain):
            iouTrain, iou_classes = iouEvalTrain.getIoU()
            iouStr = getColorEntry(iouTrain)+'{:0.2f}'.format(iouTrain*100) + '\033[0m'
            print ("EPOCH IoU on TRAIN set: ", iouStr, "%", iou_classes)  

        #Validate on 500 val images after each epoch of training
        print("----- VALIDATING - EPOCH", epoch, "-----")
        model.eval()
        epoch_loss_val = []
        time_val = []

        if (doIouVal):
            iouEvalVal = iouEval(NUM_CLASSES, args.ignoreindex)

        for step, (images, labels, images_orig, labels_orig) in enumerate(loader_val):
            start_time = time.time()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            inputs = Variable(images, volatile=True)    #volatile flag avoids storing intermediates for backward during eval (deprecated; newer PyTorch uses torch.no_grad())
            targets = Variable(labels, volatile=True)
            outputs = model(inputs, only_encode=enc) 

            loss = criterion(outputs, targets[:, 0])
            epoch_loss_val.append(loss.data[0])
            time_val.append(time.time() - start_time)


            #Add batch to calculate TP, FP and FN for iou estimation
            if (doIouVal):
                #start_time_iou = time.time()
                upsampledOutputs = up(outputs)
                iouEvalVal.addBatch(upsampledOutputs.max(1)[1].unsqueeze(1).data, labels_orig)
                #print ("Time to add confusion matrix: ", time.time() - start_time_iou)

            if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
                start_time_plot = time.time()
                image = inputs[0].cpu().data
                board.image(image, f'VAL input (epoch: {epoch}, step: {step})')
                if isinstance(outputs, list):   #merge gpu tensors
                    board.image(color_transform(outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)),
                    f'VAL output (epoch: {epoch}, step: {step})')
                else:
                    board.image(color_transform(outputs[0].cpu().max(0)[1].data.unsqueeze(0)),
                    f'VAL output (epoch: {epoch}, step: {step})')
                board.image(color_transform(targets[0].cpu().data),
                    f'VAL target (epoch: {epoch}, step: {step})')
                print ("Time to paint images: ", time.time() - start_time_plot)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss_val) / len(epoch_loss_val)
                print(f'VAL loss: {average:0.4} (epoch: {epoch}, step: {step})', 
                        "// Avg time/img: %.4f s" % (sum(time_val) / len(time_val) / args.batch_size))
                       

        average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val)
        #scheduler.step(average_epoch_loss_val, epoch)  ## scheduler 1   # update lr if needed

        iouVal = 0
        if (doIouVal):
            iouVal, iou_classes = iouEvalVal.getIoU()
            iouStr = getColorEntry(iouVal)+'{:0.2f}'.format(iouVal*100) + '\033[0m'
            print ("EPOCH IoU on VAL set: ", iouStr, "%", iou_classes) 
           

        # remember best valIoU and save checkpoint
        if iouVal == 0:
            current_acc = -average_epoch_loss_val
        else:
            current_acc = iouVal 
        is_best = current_acc > best_acc
        best_acc = max(current_acc, best_acc)
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
            filenameBest = savedir + '/model_best_enc.pth.tar'    
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'
            filenameBest = savedir + '/model_best.pth.tar'
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': str(model),
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer' : optimizer.state_dict(),
        }, is_best, filenameCheckpoint, filenameBest)
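        # save_checkpoint is assumed to be the usual torch.save plus
        # copy-on-best helper, roughly (a sketch, not this repo's verified code):
        #   def save_checkpoint(state, is_best, filenameCheckpoint, filenameBest):
        #       torch.save(state, filenameCheckpoint)
        #       if is_best:
        #           shutil.copyfile(filenameCheckpoint, filenameBest)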

        #SAVE MODEL AFTER EPOCH
        if (enc):
            filename = f'{savedir}/model_encoder-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_encoder_best.pth'
        else:
            filename = f'{savedir}/model-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_best.pth'
        if args.epochs_save > 0 and epoch > 0 and epoch % args.epochs_save == 0:
            torch.save(model.state_dict(), filename)
            print(f'save: {filename} (epoch: {epoch})')
        if (is_best):
            torch.save(model.state_dict(), filenamebest)
            print(f'save: {filenamebest} (epoch: {epoch})')
            if (not enc):
                with open(savedir + "/best.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal))   
            else:
                with open(savedir + "/best_encoder.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal))           

        #SAVE TO FILE A ROW WITH THE EPOCH RESULT (train loss, val loss, train IoU, val IoU)
        #Epoch		Train-loss		Test-loss	Train-IoU	Test-IoU		learningRate
        with open(automated_log_path, "a") as myfile:
            myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" % (epoch, average_epoch_loss_train, average_epoch_loss_val, iouTrain, iouVal, usedLr ))
    
    return(model)   #return model (convenience for encoder-decoder training)
Exemple #9
0
def train(args, model_student, model_teacher, enc=False):
    global best_acc

    weight = torch.ones(1)
    
    assert os.path.exists(args.datadir), "Error: datadir (dataset directory) could not be loaded"

    # Set data loading variables
    co_transform = MyCoTransform(enc, augment=True, height=480)#1024)
    co_transform_val = MyCoTransform(enc, augment=False, height=480)#1024)
    dataset_train = self_supervised_power(args.datadir, co_transform, 'train', file_format="csv", label_name="class", subsample=args.subsample)
    # dataset_train = self_supervised_power(args.datadir, None, 'train')
    dataset_val = self_supervised_power(args.datadir, None, 'val', file_format="csv", label_name="class", subsample=args.subsample)

    if args.force_n_classes > 0:
        color_transform_classes_prob = ColorizeClassesProb(args.force_n_classes)  # Automatic color based on max class probability
        color_transform_classes = ColorizeClasses(args.force_n_classes)  # Automatic color based on max class probability

    loader = DataLoader(dataset_train, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)
    loader_val = DataLoader(dataset_val, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    if args.cuda:
        weight = weight.cuda()

    # Set Loss functions
    if args.force_n_classes > 0:
        criterion = L1LossClassProbMasked() # L1 loss weighted with class prob with averaging over mini-batch
    else:
        criterion = L1LossMasked()     

    criterion = CrossEntropyLoss2d()  # NOTE: this unconditionally overrides the masked L1 criterion selected above

    criterion_trav = L1LossTraversability()
    criterion_consistency = MSELossWeighted()
    criterion_val = CrossEntropyLoss2d()    
    criterion_acc = ClassificationAccuracy()
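    # MSELossWeighted is called below as criterion_consistency(student, teacher,
    # weight); a plausible sketch (an assumption, not the imported class):
    #   class MSELossWeighted(torch.nn.Module):
    #       def forward(self, student, teacher, weight):
    #           return weight * torch.nn.functional.mse_loss(student, teacher.detach())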
    print(type(criterion))

    savedir = f'../save/{args.savedir}'

    if (enc):
        automated_log_path = savedir + "/automated_log_encoder.txt"
        modeltxtpath = savedir + "/model_encoder.txt"
    else:
        automated_log_path = savedir + "/automated_log.txt"
        modeltxtpath = savedir + "/model.txt"    

    if (not os.path.exists(automated_log_path)):    #don't add the header line if the file already exists
        with open(automated_log_path, "a") as myfile:
            myfile.write("Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate")

    with open(modeltxtpath, "w") as myfile:
        myfile.write(str(model_student))


    #TODO: reduce memory in first gpu: https://discuss.pytorch.org/t/multi-gpu-training-memory-usage-in-balance/4163/4        #https://github.com/pytorch/pytorch/issues/1893

    #optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999),  eps=1e-08, weight_decay=2e-4)     ## scheduler 1
    optimizer = Adam(model_student.parameters(), LEARNING_RATE, BETAS,  eps=OPT_EPS, weight_decay=WEIGHT_DECAY)
    if args.alternate_optimization:
        params_prob = [param for name, param in model_student.named_parameters() if name != "module.class_power"]
        params_power = [param for name, param in model_student.named_parameters() if name == "module.class_power"]
        optimizer_prob = Adam(params_prob, LEARNING_RATE, BETAS,  eps=OPT_EPS, weight_decay=WEIGHT_DECAY)
        optimizer_power = Adam(params_power, LEARNING_RATE, BETAS,  eps=OPT_EPS, weight_decay=WEIGHT_DECAY)

    start_epoch = 1

    if args.resume:
        #Must load weights, optimizer, epoch and best value. 
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'

        assert os.path.exists(filenameCheckpoint), "Error: resume option was used but checkpoint was not found in folder"
        checkpoint = torch.load(filenameCheckpoint)
        start_epoch = checkpoint['epoch']
        model_student.load_state_dict(checkpoint['state_dict'])
        model_teacher.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_acc = checkpoint['best_acc']
        print("=> Loaded checkpoint at epoch {})".format(checkpoint['epoch']))

    # Initialize teacher with same weights as student. 
    copyWeightsToModelNoGrad(model_student, model_teacher)
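    # copyWeightsToModelNoGrad presumably copies the student parameters into the
    # teacher and detaches them from gradient updates; a sketch under that
    # assumption (not this repo's verified code):
    #   model_teacher.load_state_dict(model_student.state_dict())
    #   for p in model_teacher.parameters():
    #       p.requires_grad_(False)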

    #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5) # set up scheduler     ## scheduler 1
    lambda1 = lambda epoch: pow((1-((epoch-1)/args.num_epochs)),0.9)  ## scheduler 2
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)                             ## scheduler 2
    if args.alternate_optimization:
        scheduler_prob = lr_scheduler.LambdaLR(optimizer_prob, lr_lambda=lambda1)                             ## scheduler 2
        scheduler_power = lr_scheduler.LambdaLR(optimizer_power, lr_lambda=lambda1)                             ## scheduler 2

    if args.visualize:
        board = Dashboard(args.port)
        writer = SummaryWriter()
        log_base_dir = writer.file_writer.get_logdir() + "/"
        print("Saving tensorboard log to: " + log_base_dir)
        total_steps_train = 0
        total_steps_val = 0
        # Figure out histogram plot indices.
        steps_hist = int(len(loader_val)/NUM_HISTOGRAMS)
        steps_img_train = int(len(loader)/(NUM_IMG_PER_EPOCH-1))
        if steps_img_train == 0:
            steps_img_train = 1
        steps_img_val = int(len(loader_val)/(NUM_IMG_PER_EPOCH-1))
        if steps_img_val == 0:
            steps_img_val = 1
        hist_bins = np.arange(-0.5, args.force_n_classes+0.5, 1.0)

    for epoch in range(start_epoch, args.num_epochs+1):
        print("----- TRAINING - EPOCH", epoch, "-----")

        if epoch < MAX_CONSISTENCY_EPOCH:
            cur_consistency_weight = epoch / MAX_CONSISTENCY_EPOCH
        else:
            cur_consistency_weight = 1.0

        if args.no_mean_teacher:
            cur_consistency_weight = 0.0

        if args.alternate_optimization:
            if epoch % 2 == 0:
                scheduler_power.step(epoch)
            else:
                scheduler_prob.step(epoch)
        else:
            scheduler.step(epoch)    ## scheduler 2

        average_loss_student_val = 0
        average_loss_teacher_val = 0

        epoch_loss_student = []
        epoch_loss_teacher = []
        epoch_acc_student = []
        epoch_acc_teacher = []
        epoch_loss_trav_student = []
        epoch_loss_trav_teacher = []
        epoch_loss_consistency = []
        time_train = []
        time_load = []
        time_iter = [0.0]
     
        doIouTrain = args.iouTrain   
        doIouVal =  args.iouVal      


        usedLr = 0
        for param_group in optimizer.param_groups:
            print("LEARNING RATE: ", param_group['lr'])
            usedLr = float(param_group['lr'])

        model_student.train()
        model_teacher.train()

        start_time = time.time()

        for step, (images1, images2, labels) in enumerate(loader):

            time_load.append(time.time() - start_time)

            start_time = time.time()
            #print (labels.size())
            #print (np.unique(labels.numpy()))
            #print("labels: ", np.unique(labels[0].numpy()))
            #labels = torch.ones(4, 1, 512, 1024).long()
            if args.cuda:
                images1 = images1.cuda()
                images2 = images2.cuda()
                labels = labels.cuda()

            inputs1 = Variable(images1)
            inputs2 = Variable(images2)
            targets = Variable(labels)
            if (args.force_n_classes) > 0:
                # Forced into discrete classes. 
                output_student_prob, output_student_trav, output_student_power = model_student(inputs1, only_encode=enc)
                output_teacher_prob, output_teacher_trav, output_teacher_power = model_teacher(inputs2, only_encode=enc)
                if args.alternate_optimization:
                    if epoch % 2 == 0:
                        optimizer_power.zero_grad()
                    else:
                        optimizer_prob.zero_grad()
                else:
                    optimizer.zero_grad()
                loss_student_pred = criterion(output_student_prob, targets)
                loss_teacher_pred = criterion(output_teacher_prob, targets)
                loss_consistency = criterion_consistency(output_student_prob, output_teacher_prob, cur_consistency_weight)
                acc_student = criterion_acc(output_student_prob, targets)
                acc_teacher = criterion_acc(output_teacher_prob, targets)
            else:
                # Straight regression
                output_student, output_student_trav = model_student(inputs1, only_encode=enc)
                output_teacher, output_teacher_trav = model_teacher(inputs2, only_encode=enc)
                optimizer.zero_grad()
                loss_student_pred = criterion(output_student, targets)
                loss_teacher_pred = criterion(output_teacher, targets)
                loss_consistency = criterion_consistency(output_student, output_teacher, cur_consistency_weight)

            # Loss independent of how scalar value is determined
            loss_student_trav = criterion_trav(output_student_trav, targets)
            loss_teacher_trav = criterion_trav(output_teacher_trav, targets)


            #print("targets", np.unique(targets[:, 0].cpu().data.numpy()))

            # Do backward pass.
            loss_student_pred.backward(retain_graph=True)
            if epoch>0 and not args.no_mean_teacher:
                loss_student_trav.backward(retain_graph=True)
                loss_consistency.backward()
            else: 
                loss_student_trav.backward()


            if args.alternate_optimization:
                if epoch % 2 == 0:
                    optimizer_power.step()
                else:
                    optimizer_prob.step()
            else:
                optimizer.step()

            # Average over first 50 epochs.
            if epoch < DISCOUNT_RATE_START_EPOCH:
                cur_discount_rate = DISCOUNT_RATE_START
            else:
                cur_discount_rate = DISCOUNT_RATE
            copyWeightsToModelWithDiscount(model_student, model_teacher, cur_discount_rate)
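            # The mean-teacher update implied here is an exponential moving
            # average of the student weights; a sketch assuming `discount` is
            # the weight kept on the old teacher parameters:
            #   with torch.no_grad():
            #       for p_t, p_s in zip(model_teacher.parameters(), model_student.parameters()):
            #           p_t.mul_(discount).add_(p_s, alpha=1.0 - discount)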

            # copyWeightsToModelWithDiscount(model_student, model_teacher, DISCOUNT_RATE)

            epoch_loss_student.append(loss_student_pred.data.item())
            epoch_loss_teacher.append(loss_teacher_pred.data.item())
            epoch_loss_trav_student.append(loss_student_trav.data.item())
            epoch_loss_trav_teacher.append(loss_teacher_trav.data.item())
            epoch_loss_consistency.append(loss_consistency.data.item())
            if (args.force_n_classes) > 0:
                epoch_acc_student.append(acc_student.data.item())
                epoch_acc_teacher.append(acc_teacher.data.item())
            time_train.append(time.time() - start_time)

            # if (doIouTrain):
            #     #start_time_iou = time.time()
            #     iouEvalTrain.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data)
            #     #print ("Time to add confusion matrix: ", time.time() - start_time_iou)      

            #print(outputs.size())
            if args.visualize and step % steps_img_train == 0:
                step_vis_no = total_steps_train + len(epoch_loss_student)

                # Figure out and compute tensor to visualize. 
                if args.force_n_classes > 0:
                    # Compute weighted power consumption
                    sum_dim = output_student_prob.dim()-3
                    # weighted_sum_output = (output_student_prob * output_student_power).sum(dim=sum_dim, keepdim=True)
                    if (isinstance(output_student_prob, list)):
                        max_prob, vis_output = getMaxProbValue(output_student_prob[0][0].cpu().data, output_student_power[0][0].cpu().data)
                        max_prob_teacher, vis_output_teacher = getMaxProbValue(output_teacher_prob[0][0].cpu().data, output_teacher_power[0][0].cpu().data)
                        writer.add_image("train/2_classes", color_transform_classes_prob(output_student_prob[0][0].cpu().data), step_vis_no)
                        writer.add_image("train/3_max_class_probability", max_prob[0][0], step_vis_no)
                        # writer.add_image("train/4_weighted_output", color_transform_output(weighted_sum_output[0][0].cpu().data), step_vis_no)
                    else:
                        max_prob, vis_output = getMaxProbValue(output_student_prob[0].cpu().data, output_student_power[0].cpu().data)
                        max_prob_teacher, vis_output_teacher = getMaxProbValue(output_teacher_prob[0].cpu().data, output_teacher_power[0].cpu().data)
                        writer.add_image("train/2_classes", color_transform_classes_prob(output_student_prob[0].cpu().data), step_vis_no)
                        writer.add_image("train/3_max_class_probability", max_prob[0], step_vis_no)
                        # writer.add_image("train/4_weighted_output", color_transform_output(weighted_sum_output[0].cpu().data), step_vis_no)
                else:
                    if (isinstance(output_teacher, list)):
                        vis_output = output_student[0][0].cpu().data
                        vis_output_teacher = output_teacher[0][0].cpu().data
                    else:
                        vis_output = output_student[0].cpu().data
                        vis_output_teacher = output_teacher[0].cpu().data

                if (isinstance(output_teacher_trav, list)):
                    trav_output = output_student_trav[0][0].cpu().data
                    trav_output_teacher = output_teacher_trav[0][0].cpu().data
                else:
                    trav_output = output_student_trav[0].cpu().data
                    trav_output_teacher = output_teacher_trav[0].cpu().data

                start_time_plot = time.time()
                image1 = inputs1[0].cpu().data
                image2 = inputs2[0].cpu().data
                # board.image(image, f'input (epoch: {epoch}, step: {step})')
                writer.add_image("train/1_input_student", image1, step_vis_no)
                writer.add_image("train/1_input_teacher", image2, step_vis_no)
                # writer.add_image("train/5_output_student", color_transform_output(vis_output), step_vis_no)
                # writer.add_image("train/5_output_teacher", color_transform_output(vis_output_teacher), step_vis_no)
                writer.add_image("train/7_output_trav_student", trav_output, step_vis_no)
                writer.add_image("train/7_output_trav_teacher", trav_output_teacher, step_vis_no)
                # board.image(color_transform_target(targets[0].cpu().data),
                #     f'target (epoch: {epoch}, step: {step})')
                writer.add_image("train/6_target", color_transform_classes(targets.cpu().data), step_vis_no)

                # Visualize graph.
                writer.add_graph(model_teacher, inputs2)

                print ("Time for visualization: ", time.time() - start_time_plot)
                

        len_epoch_loss = len(epoch_loss_student)
        for ind, val in enumerate(epoch_loss_student):
            writer.add_scalar("train/instant_loss_student", val, total_steps_train + ind)
        for ind, val in enumerate(epoch_loss_teacher):
            writer.add_scalar("train/instant_loss_teacher", val, total_steps_train + ind)
        for ind, val in enumerate(epoch_loss_trav_student):
            writer.add_scalar("train/instant_loss_trav_student", val, total_steps_train + ind)
        for ind, val in enumerate(epoch_loss_trav_teacher):
            writer.add_scalar("train/instant_loss_trav_teacher", val, total_steps_train + ind)
        for ind, val in enumerate(epoch_loss_consistency):
            writer.add_scalar("train/instant_loss_consistency", val, total_steps_train + ind)
        if (args.force_n_classes) > 0:
            for ind, val in enumerate(epoch_acc_student):
                writer.add_scalar("train/instant_acc_student", val, total_steps_train + ind)
            for ind, val in enumerate(epoch_acc_teacher):
                writer.add_scalar("train/instant_acc_teacher", val, total_steps_train + ind)
        total_steps_train += len_epoch_loss
        avg_loss_teacher = sum(epoch_loss_teacher)/len(epoch_loss_teacher)
        writer.add_scalar("train/epoch_loss_student", sum(epoch_loss_student)/len(epoch_loss_student), total_steps_train)
        writer.add_scalar("train/epoch_loss_teacher", avg_loss_teacher, total_steps_train)
        writer.add_scalar("train/epoch_loss_trav_student", sum(epoch_loss_trav_student)/len(epoch_loss_trav_student), total_steps_train)
        writer.add_scalar("train/epoch_loss_trav_teacher", sum(epoch_loss_trav_teacher)/len(epoch_loss_trav_teacher), total_steps_train)
        writer.add_scalar("train/epoch_loss_consistency", sum(epoch_loss_consistency)/len(epoch_loss_consistency), total_steps_train)
        if (args.force_n_classes) > 0:
            writer.add_scalar("train/epoch_acc_student", sum(epoch_acc_student)/len(epoch_acc_student), total_steps_train)
            writer.add_scalar("train/epoch_acc_teacher", sum(epoch_acc_teacher)/len(epoch_acc_teacher), total_steps_train)
        # Clear loss for next loss print iteration.
        # Output class power costs
        power_dict = {}
        if args.force_n_classes > 0:
            for ind, val in enumerate(output_teacher_power.squeeze()):
                power_dict[str(ind)] = val
            writer.add_scalars("params/class_cost", power_dict, total_steps_train)
        epoch_loss_student = []
        epoch_loss_teacher = []
        epoch_loss_consistency = []
        epoch_loss_trav_student = []
        epoch_loss_trav_teacher = []
        epoch_acc_student = []
        epoch_acc_teacher = []
        # Print current loss. 
        print(f'loss: {avg_loss_teacher:0.4} (epoch: {epoch}, step: {step})', 
                "// Train: %.4f s" % (sum(time_train) / len(time_train) / args.batch_size), 
                "// Load: %.4f s" % (sum(time_load) / len(time_load) / args.batch_size),
                "// Iter: %.4f s" % (sum(time_iter) / len(time_iter) / args.batch_size))

        if step == 0:
            time_iter.clear()
        time_iter.append(time.time() - start_time)
        # Save time for image loading duration.
        start_time = time.time()

            
        average_epoch_loss_train = avg_loss_teacher   
        
        iouTrain = 0
        if (doIouTrain):
            iouTrain, iou_classes = iouEvalTrain.getIoU()
            iouStr = getColorEntry(iouTrain)+'{:0.2f}'.format(iouTrain*100) + '\033[0m'
            print ("EPOCH IoU on TRAIN set: ", iouStr, "%")  

        #Validate on 500 val images after each epoch of training
        print("----- VALIDATING - EPOCH", epoch, "-----")
        model_student.eval()
        model_teacher.eval()
        epoch_loss_student_val = []
        epoch_loss_teacher_val = []
        epoch_acc_student_val = []
        epoch_acc_teacher_val = []
        epoch_loss_trav_student_val = []
        epoch_loss_trav_teacher_val = []
        time_val = []


        for step, (images1, images2, labels) in enumerate(loader_val):
            start_time = time.time()
            if args.cuda:
                images1 = images1.cuda()
                images2 = images2.cuda()
                labels = labels.cuda()

            inputs1 = Variable(images1, volatile=True)    #volatile flag avoids storing intermediates for backward during eval (deprecated; newer PyTorch uses torch.no_grad())
            inputs2 = Variable(images2, volatile=True)
            targets = Variable(labels, volatile=True)

            if args.force_n_classes:
                output_student_prob, output_student_trav, output_student_power = model_student(inputs1, only_encode=enc) 
                output_teacher_prob, output_teacher_trav, output_teacher_power = model_teacher(inputs2, only_encode=enc) 
                max_prob, output_student = getMaxProbValue(output_student_prob, output_student_power)
                max_prob, output_teacher = getMaxProbValue(output_teacher_prob, output_teacher_power)
                # Compute weighted power consumption
                sum_dim = output_student_prob.dim()-3
                # weighted_sum_output = (output_student_prob * output_student_power).sum(dim=sum_dim, keepdim=True)
            else:
                output_student, output_student_trav = model_student(inputs1, only_encode=enc)
                output_teacher, output_teacher_trav = model_teacher(inputs2, only_encode=enc)

            loss_student = criterion_val(output_student_prob, targets)
            loss_teacher = criterion_val(output_teacher_prob, targets)
            loss_student_trav = criterion_trav(output_student_trav, targets)
            loss_teacher_trav = criterion_trav(output_teacher_trav, targets)
            epoch_loss_student_val.append(loss_student.data.item())
            epoch_loss_teacher_val.append(loss_teacher.data.item())
            epoch_loss_trav_student_val.append(loss_student_trav.data.item())
            epoch_loss_trav_teacher_val.append(loss_teacher_trav.data.item())
            if args.force_n_classes:
                acc_student = criterion_acc(output_student_prob, targets)
                acc_teacher = criterion_acc(output_teacher_prob, targets)
                epoch_acc_student_val.append(acc_student.data.item())
                epoch_acc_teacher_val.append(acc_teacher.data.item())
            time_val.append(time.time() - start_time)


            #Add batch to calculate TP, FP and FN for iou estimation
            # if (doIouVal):
            #     #start_time_iou = time.time()
            #     iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data)
            #     #print ("Time to add confusion matrix: ", time.time() - start_time_iou)

            # Plot images
            if args.visualize and step % steps_img_val == 0:
                if (isinstance(output_teacher_trav, list)):
                    trav_output = output_student_trav[0][0].cpu().data
                    trav_output_teacher = output_teacher_trav[0][0].cpu().data
                else:
                    trav_output = output_student_trav[0].cpu().data
                    trav_output_teacher = output_teacher_trav[0].cpu().data

                step_vis_no = total_steps_val + len(epoch_loss_student_val)
                start_time_plot = time.time()
                image1 = inputs1[0].cpu().data
                image2 = inputs2[0].cpu().data
                # board.image(image, f'VAL input (epoch: {epoch}, step: {step})')
                writer.add_image("val/1_input_student", image1, step_vis_no)
                writer.add_image("val/1_input_teacher", image2, step_vis_no)
                if isinstance(output_teacher, list):   #merge gpu tensors
                    # board.image(color_transform_output(outputs[0][0].cpu().data),
                    # f'VAL output (epoch: {epoch}, step: {step})')
                    # writer.add_image("val/5_output_teacher", color_transform_output(output_teacher[0][0].cpu().data), step_vis_no)
                    # writer.add_image("val/5_output_student", color_transform_output(output_student[0][0].cpu().data), step_vis_no)
                    if args.force_n_classes > 0:
                        writer.add_image("val/2_classes", color_transform_classes_prob(output_teacher_prob[0][0].cpu().data), step_vis_no)
                        writer.add_image("val/3_max_class_probability", max_prob[0][0], step_vis_no)
                        # writer.add_image("val/4_weighted_output", color_transform_output(weighted_sum_output[0][0].cpu().data), step_vis_no)

                else:
                    # board.image(color_transform_output(outputs[0].cpu().data),
                    # f'VAL output (epoch: {epoch}, step: {step})')
                    # writer.add_image("val/5_output_teacher", color_transform_output(output_teacher[0].cpu().data), step_vis_no)
                    # writer.add_image("val/5_output_student", color_transform_output(output_student[0].cpu().data), step_vis_no)
                    if args.force_n_classes > 0:
                        writer.add_image("val/2_classes", color_transform_classes_prob(output_teacher_prob[0].cpu().data), step_vis_no)
                        writer.add_image("val/3_max_class_probability", max_prob[0], step_vis_no)
                        # writer.add_image("val/4_weighted_output", color_transform_output(weighted_sum_output[0].cpu().data), step_vis_no)
                # board.image(color_transform_target(targets[0].cpu().data),
                #     f'VAL target (epoch: {epoch}, step: {step})')
                writer.add_image("val/7_output_trav_student", trav_output, step_vis_no)
                writer.add_image("val/7_output_trav_teacher", trav_output_teacher, step_vis_no)
                writer.add_image("val/6_target", color_transform_classes(targets.cpu().data), step_vis_no)
                print ("Time to paint images: ", time.time() - start_time_plot)
            # Plot histograms
            if args.force_n_classes > 0 and args.visualize and steps_hist > 0 and step % steps_hist == 0:
                image1 = inputs1[0].cpu().data+0.5  # +0.5 to remove zero-mean normalization
                image2 = inputs2[0].cpu().data+0.5
                hist_ind = int(step / steps_hist)
                if (isinstance(output_teacher_prob, list)):
                    _, hist_array = output_teacher_prob[0][0].cpu().data.max(dim=0, keepdim=True)
                else:
                    _, hist_array = output_teacher_prob[0].cpu().data.max(dim=0, keepdim=True)

                writer.add_histogram("val/hist_"+str(hist_ind), hist_array.numpy().flatten(), total_steps_train, hist_bins)  # Use train steps so we can compare with class power plot
                if isinstance(output_teacher, list):
                    writer.add_image("val/classes_"+str(hist_ind), color_transform_classes_prob(output_teacher_prob[0][0].cpu().data), total_steps_train)
                else:
                    writer.add_image("val/classes_"+str(hist_ind), color_transform_classes_prob(output_teacher_prob[0].cpu().data), total_steps_train)

                if epoch == start_epoch:
                    writer.add_image("val/hist/input_"+str(hist_ind), image2, total_steps_train)  # Visualize image used to compute histogram
                       
        total_steps_val += len(epoch_loss_student_val)
        avg_loss_teacher_val = sum(epoch_loss_teacher_val) / len(epoch_loss_teacher_val)
        print(f'VAL loss_teacher: {avg_loss_teacher_val:0.4} (epoch: {epoch}, step: {total_steps_val})', 
                "// Avg time/img: %.4f s" % (sum(time_val) / len(time_val) / args.batch_size))
        writer.add_scalar("val/epoch_loss_student", sum(epoch_loss_student_val) / len(epoch_loss_student_val), total_steps_val)
        writer.add_scalar("val/epoch_loss_teacher", avg_loss_teacher_val, total_steps_val)
        writer.add_scalar("val/epoch_loss_trav_student", sum(epoch_loss_trav_student_val) / len(epoch_loss_trav_student_val), total_steps_val)
        writer.add_scalar("val/epoch_loss_trav_teacher", sum(epoch_loss_trav_teacher_val) / len(epoch_loss_trav_teacher_val), total_steps_val)
        if args.force_n_classes:
            writer.add_scalar("val/epoch_acc_student", sum(epoch_acc_student_val) / len(epoch_acc_student_val), total_steps_val)
            writer.add_scalar("val/epoch_acc_teacher", sum(epoch_acc_teacher_val) / len(epoch_acc_teacher_val), total_steps_val)

        epoch_loss_student_val = []
        epoch_loss_teacher_val = []
        epoch_acc_student_val = []
        epoch_acc_teacher_val = []
        epoch_loss_trav_student_val = []
        epoch_loss_trav_teacher_val = []

        average_epoch_loss_val = avg_loss_teacher_val
        #scheduler.step(average_epoch_loss_val, epoch)  ## scheduler 1   # update lr if needed

        iouVal = 0
        if (doIouVal):
            iouVal, iou_classes = iouEvalVal.getIoU()
            iouStr = getColorEntry(iouVal)+'{:0.2f}'.format(iouVal*100) + '\033[0m'
            print ("EPOCH IoU on VAL set: ", iouStr, "%") 
           

        # remember best valIoU and save checkpoint
        if iouVal == 0:
            current_acc = average_epoch_loss_val
        else:
            current_acc = iouVal 
        is_best = current_acc > best_acc
        best_acc = max(current_acc, best_acc)
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
            filenameBest = savedir + '/model_best_enc.pth.tar'    
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'
            filenameBest = savedir + '/model_best.pth.tar'
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': str(model_teacher),
            'state_dict': model_teacher.state_dict(),
            'best_acc': best_acc,
            'optimizer' : optimizer.state_dict(),
        }, is_best, filenameCheckpoint, filenameBest)

        #SAVE MODEL AFTER EPOCH
        if (enc):
            filename = f'{savedir}/model_encoder-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_encoder_best.pth'
        else:
            filename = f'{savedir}/model-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_best.pth'
        if args.epochs_save > 0 and epoch > 0 and epoch % args.epochs_save == 0:
            torch.save(model_teacher.state_dict(), filename)
            print(f'save: {filename} (epoch: {epoch})')
        if (is_best):
            torch.save(model_teacher.state_dict(), filenamebest)
            print(f'save: {filenamebest} (epoch: {epoch})')
            if (not enc):
                with open(savedir + "/best.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal))   
            else:
                with open(savedir + "/best_encoder.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal))           

        #SAVE TO FILE A ROW WITH THE EPOCH RESULT (train loss, val loss, train IoU, val IoU)
        #Epoch		Train-loss		Test-loss	Train-IoU	Test-IoU		learningRate
        with open(automated_log_path, "a") as myfile:
            myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" % (epoch, average_epoch_loss_train, average_epoch_loss_val, iouTrain, iouVal, usedLr ))
    
    return(model_student, model_teacher)   #return model (convenience for encoder-decoder training)
Exemple #10
0
def train(args, model):
    model.train()

    weight = torch.ones(22)
    weight[0] = 0

    loader = DataLoader(VOC12(args.datadir, input_transform, target_transform),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=True)

    if args.cuda:
        criterion = CrossEntropyLoss2d(weight.cuda())
    else:
        criterion = CrossEntropyLoss2d(weight)

    optimizer = Adam(model.parameters())
    if args.model.startswith('FCN'):
        optimizer = SGD(model.parameters(), 1e-4, .9, 2e-5)
    if args.model.startswith('PSP'):
        optimizer = SGD(model.parameters(), 1e-2, .9, 1e-4)
    if args.model.startswith('Seg'):
        optimizer = SGD(model.parameters(), 1e-3, .9)

    if args.steps_plot > 0:
        board = Dashboard(args.port)

    for epoch in range(1, args.num_epochs + 1):
        epoch_loss = []

        for step, (images, labels) in enumerate(loader):
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()

            inputs = Variable(images)
            targets = Variable(labels)
            outputs = model(inputs)

            optimizer.zero_grad()
            loss = criterion(outputs, targets[:, 0])
            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.data[0])
            if args.steps_plot > 0 and step % args.steps_plot == 0:
                image = inputs[0].cpu().data
                image[0] = image[0] * .229 + .485
                image[1] = image[1] * .224 + .456
                image[2] = image[2] * .225 + .406
                board.image(image, f'input (epoch: {epoch}, step: {step})')
                board.image(color_transform(outputs[0].cpu().max(0)[1].data),
                            f'output (epoch: {epoch}, step: {step})')
                board.image(color_transform(targets[0].cpu().data),
                            f'target (epoch: {epoch}, step: {step})')
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss) / len(epoch_loss)
                print(f'loss: {average} (epoch: {epoch}, step: {step})')
            if args.steps_save > 0 and step % args.steps_save == 0:
                filename = f'{args.model}-{epoch:03}-{step:04}.pth'
                torch.save(model.state_dict(), filename)
                print(f'save: {filename} (epoch: {epoch}, step: {step})')
Exemple #11
0
def train(savedir,
          model,
          dataloader_train,
          dataloader_eval,
          criterion,
          optimizer,
          args,
          enc=False):
    min_loss = float('inf')

    # use tensorboard
    writer = SummaryWriter(log_dir=savedir)
    if (enc):
        automated_log_path = savedir + "/automated_log_encoder.txt"
        modeltxtpath = savedir + "/model_encoder.txt"
    else:
        automated_log_path = savedir + "/automated_log.txt"
        modeltxtpath = savedir + "/model.txt"

    if (not os.path.exists(automated_log_path)
        ):  #don't add the header line if the file already exists
        with open(automated_log_path, "a") as myfile:
            myfile.write(
                "Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate"
            )

    with open(modeltxtpath, "w") as myfile:
        myfile.write(str(model))

    start_epoch = 1
    if args.resume:
        #Must load weights, optimizer, epoch and best value.
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'

        assert os.path.exists(
            filenameCheckpoint
        ), "Error: resume option was used but checkpoint was not found in folder"
        checkpoint = torch.load(filenameCheckpoint)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_acc = checkpoint['best_acc']
        print("=> Loaded checkpoint at epoch {})".format(checkpoint['epoch']))

    #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5) # set up scheduler     ## scheduler 1
    lambda1 = lambda epoch: pow(
        (1 - ((epoch - 1) / args.num_epochs)), 0.9)  ## scheduler 2
    scheduler = lr_scheduler.LambdaLR(optimizer,
                                      lr_lambda=lambda1)  ## scheduler 2

    if args.visualize and args.steps_plot > 0:
        board = Dashboard(args.port)

    for epoch in range(start_epoch, args.num_epochs + 1):
        print("----- TRAINING - EPOCH", epoch, "-----")

        scheduler.step(epoch)

        epoch_loss = []
        time_train = []

        doIouTrain = args.iouTrain
        doIouVal = args.iouVal

        if (doIouTrain):
            iouEvalTrain = iouEval(mean_and_var)

        usedLr = 0
        for param_group in optimizer.param_groups:
            print("LEARNING RATE: ", param_group['lr'])
            usedLr = float(param_group['lr'])

        model.train()
        for step, (images, labels, _) in enumerate(dataloader_train):

            start_time = time.time()
            #print (labels.size())
            #print (np.unique(labels.numpy()))
            #print("labels: ", np.unique(labels[0].numpy()))
            #labels = torch.ones(4, 1, 512, 1024).long()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            #print("image: ", images.size())
            #print("labels: ", labels.size())
            inputs = Variable(images)
            targets = Variable(labels)
            outputs = model(inputs, only_encode=enc)

            # print("output: ", outputs.size()) #TODO
            # print("targets", np.unique(targets[:, 0].cpu().data.numpy()))

            optimizer.zero_grad()
            loss = criterion(outputs, targets[:, 0])

            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())  # store a float so the autograd graph isn't kept alive for the whole epoch
            time_train.append(time.time() - start_time)

            if (doIouTrain):
                #start_time_iou = time.time()
                iouEvalTrain.addBatch(
                    outputs.max(1)[1].unsqueeze(1).data, targets.data)
                #print ("Time to add confusion matrix: ", time.time() - start_time_iou)

            #print(outputs.size())
            if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0:
                start_time_plot = time.time()
                image = inputs[0].cpu().data
                #image[0] = image[0] * .229 + .485
                #image[1] = image[1] * .224 + .456
                #image[2] = image[2] * .225 + .406
                #print("output", np.unique(outputs[0].cpu().max(0)[1].data.numpy()))
                board.image(image, f'input (epoch: {epoch}, step: {step})')
                if isinstance(outputs, list):  #merge gpu tensors
                    board.image(
                        color_transform(
                            outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)),
                        f'output (epoch: {epoch}, step: {step})')
                else:
                    board.image(
                        color_transform(
                            outputs[0].cpu().max(0)[1].data.unsqueeze(0)),
                        f'output (epoch: {epoch}, step: {step})')
                board.image(color_transform(targets[0].cpu().data),
                            f'target (epoch: {epoch}, step: {step})')
                print("Time to paint images: ", time.time() - start_time_plot)
            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss) / len(epoch_loss)
                print(
                    f'loss: {average:0.4} (epoch: {epoch}, step: {step})',
                    "// Avg time/img: %.4f s" %
                    (sum(time_train) / len(time_train) / args.batch_size))

        average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss)
        writer.add_scalar('train_loss', average_epoch_loss_train, epoch)

        iouTrain = 0
        if (doIouTrain):
            iouTrain, iou_classes = iouEvalTrain.getIoU()
            iouStr = getColorEntry(iouTrain) + '{:0.2f}'.format(
                iouTrain * 100) + '\033[0m'
            print("EPOCH IoU on TRAIN set: ", iouStr, "%")

        #Validate on 500 val images after each epoch of training
        print("----- VALIDATING - EPOCH", epoch, "-----")
        model.eval()
        epoch_loss_val = []
        time_val = []

        if (doIouVal):
            iouEvalVal = iouEval(mean_and_var)

        for step, (images, labels, _) in enumerate(dataloader_eval):
            start_time = time.time()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()
            optimizer.zero_grad()
            inputs = Variable(images)
            targets = Variable(labels)
            with torch.no_grad():
                outputs = model(inputs, only_encode=enc)

                loss = criterion(outputs, targets[:, 0])
            epoch_loss_val.append(loss.item())
            time_val.append(time.time() - start_time)

            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss_val) / len(epoch_loss_val)
                print(
                    f'VAL loss: {average:0.4} (epoch: {epoch}, step: {step})',
                    "// Avg time/img: %.4f s" %
                    (sum(time_val) / len(time_val) / args.batch_size))

        average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val)
        #scheduler.step(average_epoch_loss_val, epoch)  ## scheduler 1   # update lr if needed
        writer.add_scalar('eval_loss', average_epoch_loss_val, epoch)

        iouVal = 0
        if (doIouVal):
            iouVal, iou_classes = iouEvalVal.getIoU()
            iouStr = getColorEntry(iouVal) + '{:0.2f}'.format(
                iouVal * 100) + '\033[0m'
            print("EPOCH IoU on VAL set: ", iouStr, "%")

        is_best = average_epoch_loss_val < min_loss
        min_loss = min(min_loss, average_epoch_loss_val)
        if enc:
            filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar'
            filenameBest = savedir + '/model_best_enc.pth.tar'
        else:
            filenameCheckpoint = savedir + '/checkpoint.pth.tar'
            filenameBest = savedir + '/model_best.pth.tar'
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': str(model),
                'state_dict': model.state_dict(),
                'best_acc': min_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best, filenameCheckpoint, filenameBest)

        #SAVE MODEL AFTER EPOCH
        if (enc):
            filename = f'{savedir}/model_encoder-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_encoder_best.pth'
        else:
            filename = f'{savedir}/model-{epoch:03}.pth'
            filenamebest = f'{savedir}/model_best.pth'
        if args.epochs_save > 0 and epoch % args.epochs_save == 0:
            torch.save(model.state_dict(), filename)
            print(f'save: {filename} (epoch: {epoch})')
        if (is_best):
            torch.save(model.state_dict(), filenamebest)
            print(f'save: {filenamebest} (epoch: {epoch})')
            if (not enc):
                with open(savedir + "/best.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" %
                                 (epoch, iouVal))
            else:
                with open(savedir + "/best_encoder.txt", "w") as myfile:
                    myfile.write("Best epoch is %d, with Val-IoU= %.4f" %
                                 (epoch, iouVal))

        #SAVE TO FILE A ROW WITH THE EPOCH RESULT (train loss, val loss, train IoU, val IoU)
        #Epoch		Train-loss		Test-loss	Train-IoU	Test-IoU		learningRate
        with open(automated_log_path, "a") as myfile:
            myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" %
                         (epoch, average_epoch_loss_train,
                          average_epoch_loss_val, iouTrain, iouVal, usedLr))
    writer.close()
    torch.save(model.state_dict(), f'{savedir}/weight_final.pth')
    return (model)  #return model (convenience for encoder-decoder training)
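
A hedged usage sketch (not part of the original source): train() returns the model so that encoder pre-training can be chained into full encoder-decoder training. The network class name and setup below are placeholders for illustration only.

model = Net(NUM_CLASSES)                # hypothetical network class
if args.cuda:
    model = model.cuda()
model = train(args, model, enc=True)    # stage 1: pre-train the encoder only
model = train(args, model, enc=False)   # stage 2: train the full encoder-decoder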
Exemple #12
0
def main():
    global args
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.vis = not args.no_vis

    # Set options
    if args.path_opt is not None:
        with open(args.path_opt, 'r') as handle:
            options = yaml.safe_load(handle)
    if args.vis_path_opt is not None:
        with open(args.vis_path_opt, 'r') as handle:
            vis_options = yaml.safe_load(handle)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)

    # Put alg_%s and optim_%s to alg and optim depending on commandline
    options['use_cuda'] = args.cuda
    options['alg'] = options['alg_%s' % args.algo]
    options['optim'] = options['optim_%s' % args.algo]
    options['trial'] = 0  # Hard coded / doesn't matter
    alg_opt = options['alg']
    alg_opt['algo'] = args.algo
    model_opt = options['model']
    env_opt = options['env']
    env_opt['env-name'] = args.env_name
    log_opt = options['logs']
    optim_opt = options['optim']
    model_opt['time_scale'] = env_opt['time_scale']
    if model_opt['mode'] in ['baselinewtheta', 'phasewtheta']:
        model_opt['theta_space_mode'] = env_opt['theta_space_mode']
        model_opt['theta_sz'] = env_opt['theta_sz']
    elif model_opt['mode'] in ['baseline_lowlevel', 'phase_lowlevel']:
        model_opt['theta_space_mode'] = env_opt['theta_space_mode']

    # Check asserts
    assert (model_opt['mode'] in [
        'baseline', 'phasesimple', 'phasewstate', 'baselinewtheta',
        'phasewtheta', 'baseline_lowlevel', 'phase_lowlevel', 'interpolate',
        'cyclic'
    ])
    assert (args.algo in ['a2c', 'ppo', 'acktr'])
    if model_opt['recurrent_policy']:
        assert args.algo in ['a2c', 'ppo'], \
            'Recurrent policy is not implemented for ACKTR'

    # Set seed
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    torch.set_num_threads(1)

    # Set logging / load previous checkpoint
    logpath = args.logdir

    # Make directory, check before overwriting
    assert not os.path.isdir(
        logpath), "Give a new directory to save so we don't overwrite anything"
    os.system('mkdir -p ' + logpath)

    # Load checkpoint
    assert (os.path.isfile(args.ckpt))
    if args.cuda:
        ckpt = torch.load(args.ckpt)
    else:
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

    # Save options and args
    with open(os.path.join(logpath, os.path.basename(args.path_opt)),
              'w') as f:
        yaml.dump(options, f, default_flow_style=False)
    with open(os.path.join(logpath, 'args.yaml'), 'w') as f:
        yaml.dump(vars(args), f, default_flow_style=False)

    # Save git info as well
    os.system('git status > %s' % os.path.join(logpath, 'git_status.txt'))
    os.system('git diff > %s' % os.path.join(logpath, 'git_diff.txt'))
    os.system('git show > %s' % os.path.join(logpath, 'git_show.txt'))

    # Set up plotting dashboard
    dashboard = Dashboard(options,
                          vis_options,
                          logpath,
                          vis=args.vis,
                          port=args.port)

    # Create environment
    verbose = not args.no_verbose
    fixed_states = [np.zeros(20), np.zeros(20)]
    env = make_env(args.env_name, args.seed, 0, logpath, options, verbose,
                   fixed_states)
    env = DummyVecEnv([env])
    if len(env.observation_space.shape) == 1:
        ignore_mask = np.zeros(env.observation_space.shape)
        if env_opt['add_timestep']:
            ignore_mask[-1] = 1
        if model_opt['mode'] in ['baselinewtheta', 'phasewtheta']:
            theta_sz = env_opt['theta_sz']
            if env_opt['add_timestep']:
                ignore_mask[-(theta_sz + 1):] = 1
            else:
                ignore_mask[-theta_sz:] = 1
        env = ObservationFilter(env,
                                ret=False,
                                has_timestep=env_opt['add_timestep'],
                                noclip=env_opt['step_plus_noclip'],
                                ignore_mask=ignore_mask,
                                time_scale=env_opt['time_scale'],
                                gamma=env_opt['gamma'],
                                train=False)
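        # (Hedged note) ignore_mask is assumed to tell ObservationFilter which
        # trailing entries (the timestep and, for the *wtheta modes, the theta
        # block) to leave un-normalized; e.g. with theta_sz = 4 and add_timestep,
        # the last 5 entries are masked out.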
        env.ob_rms = ckpt['ob_rms']
        raw_env = env.venv.envs[0]
    else:
        raw_env = env.envs[0]

    # Get theta_sz for models (if applicable)
    if model_opt['mode'] == 'baseline_lowlevel':
        model_opt['theta_sz'] = env.venv.envs[0].env.theta_sz
    elif model_opt['mode'] == 'phase_lowlevel':
        model_opt['theta_sz'] = env.venv.envs[0].env.env.theta_sz
    if 'theta_sz' in model_opt:
        env_opt['theta_sz'] = model_opt['theta_sz']

    # Init obs/state structures
    obs_shape = env.observation_space.shape
    obs_shape = (obs_shape[0] * env_opt['num_stack'], *obs_shape[1:])

    # Create the policy network
    actor_critic = Policy(obs_shape, env.action_space, model_opt)
    if args.cuda:
        actor_critic.cuda()

    # Load the checkpoint
    actor_critic.load_state_dict(ckpt['agent']['model'])
    if not args.cuda:
        actor_critic.base.cuda = False

    # Inline define our helper function for updating obs
    def update_current_obs(obs):
        shape_dim0 = env.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        if env_opt['num_stack'] > 1:
            current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
        current_obs[:, -shape_dim0:] = obs
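        # (Hedged note) With env_opt['num_stack'] > 1 this keeps a rolling window
        # over the channel dimension: older frames shift left and the newest obs
        # is written into the last shape_dim0 channels, e.g. for num_stack = 4
        # the buffer holds [o_{t-3}, o_{t-2}, o_{t-1}, o_t].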

    # Loop through episodes
    obs = env.reset()
    assert (args.num_vid <= args.num_ep)
    episode_rewards = []
    tabbed = False
    raw_data = []
    for ep in range(args.num_ep):
        if ep < args.num_vid:
            record = True
        else:
            record = False

        # Reset env
        current_obs = torch.zeros(1, *obs_shape)
        states = torch.zeros(1, actor_critic.state_size)
        masks = torch.zeros(1, 1)
        update_current_obs(obs)

        # Complete episode
        done = False
        frames = []
        ep_total_reward = 0
        while not done:
            # Capture screenshot
            if record:
                raw_env.render()
                if not tabbed:
                    # GLFW TAB and RELEASE are hardcoded here
                    raw_env.unwrapped.viewer.key_callback(
                        None, 258, None, 0, None)
                    tabbed = True
                frames.append(
                    raw_env.unwrapped.viewer._read_pixels_as_in_window())

            # Determine action
            with torch.no_grad():
                value, action, _, states = actor_critic.act(current_obs,
                                                            states,
                                                            masks,
                                                            deterministic=True)
            cpu_actions = action.squeeze(1).cpu().numpy()

            # Add to dataset (if applicable)
            if args.dump_obsdata:
                action_cp = np.array(cpu_actions)
                raw_obs_cp = np.array(env.raw_obs)
                raw_data.append([raw_obs_cp, action_cp])

            # Observe reward and next obs
            obs, reward, done, info = env.step(cpu_actions)
            ep_total_reward += reward

            # Update obs
            masks.fill_(0.0 if done else 1.0)
            if current_obs.dim() == 4:
                current_obs *= masks.unsqueeze(2).unsqueeze(2)
            else:
                current_obs *= masks
            update_current_obs(obs)

        # Save video
        if record:
            for fr_ind, fr in enumerate(frames):
                scipy.misc.imsave(
                    os.path.join(logpath, 'tmp_fr_%d.jpg' % fr_ind), fr)
            os.system("ffmpeg -r 20 -i %s/" % logpath + "tmp_fr_%01d.jpg -y " +
                      "%s/results_ep%d.mp4" % (logpath, ep))
            os.system("rm %s/tmp_fr*.jpg" % logpath)

        # Do dashboard logging for each episode
        try:
            dashboard.visdom_plot()
        except IOError:
            pass

        # Print / dump reward for episode
        # DEBUG for thetas
        #print("Theta %d" % env.venv.envs[0].env.env.theta)
        print("Total reward for episode %d: %f" % (ep, ep_total_reward))
        episode_rewards.append(ep_total_reward)

    # Dump episode data to file
    if args.dump_obsdata:
        torch.save(raw_data, logpath + '/raw_episode_data.tar')

    # Print average and variance of rewards
    avg_r = np.mean(episode_rewards)
    std_r = np.std(episode_rewards)
    print("Reward over episodes was %f+-%f" % (avg_r, std_r))

    # Do dashboard logging
    try:
        dashboard.visdom_plot()
    except IOError:
        pass
Exemple #13
0
def get_eval_vals(opt,
                  vis_opt,
                  eval_key,
                  algo,
                  env_name,
                  num_trials,
                  trial_offset,
                  bin_size,
                  smooth,
                  mode='minmax'):
    # For each trial
    x_curves = []
    y_curves = []
    for trial in range(trial_offset, trial_offset + num_trials):
        # Get the logpath
        logpath = os.path.join(opt['logs']['log_base'], opt['model']['mode'],
                               opt['logs']['exp_name'], algo, env_name,
                               'trial%d' % trial)
        print(logpath)
        assert (os.path.isdir(logpath))

        # Create the dashboard object
        opt['env']['env-name'] = env_name
        opt['alg'] = opt['alg_%s' % algo]
        opt['optim'] = opt['optim_%s' % algo]
        opt['alg']['algo'] = algo
        opt['trial'] = trial
        dash = Dashboard(opt, vis_opt, logpath, vis=True)

        # Get data
        dash.preload_data()
        x, y = dash.load_data('episode_monitor', 'scalar', eval_key)
        x = [float(i) for i in x]
        y = [float(i.replace('\x00', '')) for i in y]

        # Smooth and bin
        if smooth == 1:
            x, y = dash.smooth_curve(x, y)
        elif smooth == 2:
            y = medfilt(y, kernel_size=9)
        x, y = dash.fix_point(x, y, bin_size)

        # Append
        x_curves.append(x)
        y_curves.append(y)

    # Interpolate the curves
    # Get the combined list of all x values
    union = set([])
    for x_curve in x_curves:
        union = union | set(x_curve)
    all_x = sorted(list(union))

    # Get interpolated y values of each list
    interp_y_curves = []
    for x_curve, y_curve in zip(x_curves, y_curves):
        interp_y = np.interp(all_x, x_curve, y_curve)
        interp_y_curves.append(interp_y)

    # Get mean and variance curves
    mean = np.mean(interp_y_curves, axis=0)
    y_middle = mean
    if mode == 'all':
        y_top = interp_y_curves
        y_bottom = None
    elif mode == 'minmax':
        y_bottom = np.min(interp_y_curves, axis=0)
        y_top = np.max(interp_y_curves, axis=0)
    elif mode == 'variance':
        var = np.var(interp_y_curves, axis=0)
        y_bottom = mean - var
        y_top = mean + var
    elif mode == 'std':
        std = np.std(interp_y_curves, axis=0)
        y_bottom = mean - std
        y_top = mean + std
    else:
        raise NotImplementedError

    # Return
    return np.array(all_x), y_middle, y_top, y_bottom
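
A hedged usage sketch (not part of the original source): the curves returned by get_eval_vals() can be drawn as a shaded learning curve; the option dicts, eval key, algorithm, and environment name below are placeholders.

import matplotlib.pyplot as plt

x, y_mid, y_top, y_bot = get_eval_vals(opt, vis_opt, 'reward', 'ppo', 'Hopper-v2',
                                       num_trials=3, trial_offset=0, bin_size=100,
                                       smooth=1, mode='minmax')
plt.plot(x, y_mid, label='ppo (mean over 3 trials)')
plt.fill_between(x, y_bot, y_top, alpha=0.3)  # min/max band across trials
plt.xlabel('steps')
plt.ylabel('reward')
plt.legend()
plt.savefig('reward_curve.png')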
Exemple #14
0
def main():
    global args
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.vis = not args.no_vis

    # Set options
    if args.path_opt is not None:
        with open(args.path_opt, 'r') as handle:
            options = yaml.safe_load(handle)
    if args.vis_path_opt is not None:
        with open(args.vis_path_opt, 'r') as handle:
            vis_options = yaml.safe_load(handle)
    print('## args')
    pprint(vars(args))
    print('## options')
    pprint(options)

    # Load the lowlevel opt
    lowlevel_optfile = options['lowlevel']['optfile']
    with open(lowlevel_optfile, 'r') as handle:
        ll_opt = yaml.safe_load(handle)

    # Whether we should set ll policy to be deterministic or not
    ll_deterministic = options['lowlevel']['deterministic']

    # Put alg_%s and optim_%s to alg and optim depending on commandline
    options['use_cuda'] = args.cuda
    options['trial'] = 0
    options['alg'] = options['alg_%s' % args.algo]
    options['optim'] = options['optim_%s' % args.algo]
    alg_opt = options['alg']
    alg_opt['algo'] = args.algo
    model_opt = options['model']
    env_opt = options['env']
    env_opt['env-name'] = args.env_name
    log_opt = options['logs']
    optim_opt = options['optim'] 
    options['lowlevel_opt'] = ll_opt    # Save low level options in option file (for logging purposes)

    # Pass necessary values in ll_opt
    assert(ll_opt['model']['mode'] in ['baseline_lowlevel', 'phase_lowlevel'])
    ll_opt['model']['theta_space_mode'] = ll_opt['env']['theta_space_mode']
    ll_opt['model']['time_scale'] = ll_opt['env']['time_scale']

    # If in many module mode, load the lowlevel policies we want
    if model_opt['mode'] == 'hierarchical_many':
        # Check asserts
        theta_obs_mode = ll_opt['env']['theta_obs_mode']
        theta_space_mode = ll_opt['env']['theta_space_mode']
        assert(theta_space_mode in ['pretrain_interp', 'pretrain_any', 'pretrain_any_far', 'pretrain_any_fromstart'])
        assert(theta_obs_mode == 'pretrain')

        # Get the theta size
        theta_sz = options['lowlevel']['num_load']
        ckpt_base = options['lowlevel']['ckpt']

        # Load checkpoints
        #lowlevel_ckpts = []
        #for ll_ind in range(theta_sz):
        #    lowlevel_ckpt_file = ckpt_base + '/trial%d/ckpt.pth.tar' % ll_ind
        #    assert(os.path.isfile(lowlevel_ckpt_file))
        #    lowlevel_ckpts.append(torch.load(lowlevel_ckpt_file))

    # Otherwise it's one ll policy to load
    else:
        # Get theta_sz for low level model
        theta_obs_mode = ll_opt['env']['theta_obs_mode']
        theta_space_mode = ll_opt['env']['theta_space_mode']
        assert(theta_obs_mode in ['ind', 'vector'])
        if theta_obs_mode == 'ind':
            if theta_space_mode == 'forward':
                theta_sz = 1
            elif theta_space_mode == 'simple_four':
                theta_sz = 4
            elif theta_space_mode == 'simple_eight':
                theta_sz = 8
            elif theta_space_mode == 'k_theta':
                theta_sz = ll_opt['env']['num_theta']
            else:
                raise NotImplementedError
        elif theta_obs_mode == 'vector':
            theta_sz = 2
        else:
            raise NotImplementedError
        ll_opt['model']['theta_sz'] = theta_sz
        ll_opt['env']['theta_sz'] = theta_sz

        # Load the low level policy params
        #lowlevel_ckpt = options['lowlevel']['ckpt']            
        #assert(os.path.isfile(lowlevel_ckpt))
        #lowlevel_ckpt = torch.load(lowlevel_ckpt)
    hl_action_space = spaces.Discrete(theta_sz)

    # Check asserts
    assert(args.algo in ['a2c', 'ppo', 'acktr', 'dqn'])
    assert(optim_opt['hierarchical_mode'] in ['train_highlevel', 'train_both'])
    if model_opt['recurrent_policy']:
        assert args.algo in ['a2c', 'ppo'], 'Recurrent policy is not implemented for ACKTR'
    assert(model_opt['mode'] in ['hierarchical', 'hierarchical_many'])

    # Set seed (offset from the lowlevel seed)
    seed = args.seed + 1000    # Make it different from the lowlevel seed
    torch.manual_seed(seed)
    if args.cuda:
        torch.cuda.manual_seed(seed)

    # Initialization
    torch.set_num_threads(1)

    # Print warning
    print("#######")
    print("WARNING: All rewards are clipped or normalized so you need to use a monitor (see envs.py) or visdom plot to get true rewards")
    print("#######")

    # Set logging / load previous checkpoint
    logpath = args.logdir

    # Make directory, check before overwriting
    assert not os.path.isdir(logpath), "Give a new directory to save so we don't overwrite anything"
    os.system('mkdir -p ' + logpath)

    # Load checkpoint
    assert(os.path.isfile(args.ckpt))
    if args.cuda:
        ckpt = torch.load(args.ckpt)
    else:
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

    # Save options and args
    with open(os.path.join(logpath, os.path.basename(args.path_opt)), 'w') as f:
        yaml.dump(options, f, default_flow_style=False)
    with open(os.path.join(logpath, 'args.yaml'), 'w') as f:
        yaml.dump(vars(args), f, default_flow_style=False)

    # Save git info as well
    os.system('git status > %s' % os.path.join(logpath, 'git_status.txt'))
    os.system('git diff > %s' % os.path.join(logpath, 'git_diff.txt'))
    os.system('git show > %s' % os.path.join(logpath, 'git_show.txt'))
    
    # Set up plotting dashboard
    dashboard = Dashboard(options, vis_options, logpath, vis=args.vis, port=args.port)

    # Create environments
    envs = [make_env(args.env_name, seed, i, logpath, options, not args.no_verbose) for i in range(1)]
    if alg_opt['num_processes'] > 1:
        envs = SubprocVecEnv(envs)
    else:
        envs = DummyVecEnv(envs)

    # Check if we use timestep in low level
    if 'baseline' in ll_opt['model']['mode']:
        add_timestep = False
    elif 'phase' in ll_opt['model']['mode']:
        add_timestep = True
    else:
        raise NotImplementedError

    # Get shapes
    dummy_env = make_env(args.env_name, seed, 0, logpath, options, not args.no_verbose) 
    dummy_env = dummy_env()
    s_pro_dummy = dummy_env.unwrapped._get_pro_obs()
    s_ext_dummy = dummy_env.unwrapped._get_ext_obs()
    if add_timestep:
        ll_obs_shape = (s_pro_dummy.shape[0] + theta_sz + 1,)
        ll_raw_obs_shape = (s_pro_dummy.shape[0] + 1,)
    else:
        ll_obs_shape = (s_pro_dummy.shape[0] + theta_sz,)
        ll_raw_obs_shape = (s_pro_dummy.shape[0],)
    ll_obs_shape = (ll_obs_shape[0] * env_opt['num_stack'], *ll_obs_shape[1:])
    hl_obs_shape = (s_ext_dummy.shape[0],)
    hl_obs_shape = (hl_obs_shape[0] * env_opt['num_stack'], *hl_obs_shape[1:])
   
    # Do vec normalize, but mask out what we don't want altered
    # Also freeze all of the low level obs
    ignore_mask = dummy_env.env._get_obs_mask()
    freeze_mask, _ = dummy_env.unwrapped._get_pro_ext_mask()
    freeze_mask = np.concatenate([freeze_mask, [0]])
    if ('normalize' in env_opt and not env_opt['normalize']) or args.algo == 'dqn':
        ignore_mask = 1 - freeze_mask
    if model_opt['mode'] == 'hierarchical_many':
        # Actually ignore both ignored values and the low level values
        # That filtering will happen later
        ignore_mask = (ignore_mask + freeze_mask > 0).astype(float)
        envs = ObservationFilter(envs, ret=alg_opt['norm_ret'], has_timestep=True, noclip=env_opt['step_plus_noclip'], ignore_mask=ignore_mask, freeze_mask=freeze_mask, time_scale=env_opt['time_scale'], gamma=env_opt['gamma'], train=False)
    else:
        envs = ObservationFilter(envs, ret=alg_opt['norm_ret'], has_timestep=True, noclip=env_opt['step_plus_noclip'], ignore_mask=ignore_mask, freeze_mask=freeze_mask, time_scale=env_opt['time_scale'], gamma=env_opt['gamma'], train=False)
    raw_env = envs.venv.envs[0]

    # Make our helper object for dealing with hierarchical observations
    hier_utils = HierarchyUtils(ll_obs_shape, hl_obs_shape, hl_action_space, theta_sz, add_timestep)

    # Set up algo monitoring
    alg_filename = os.path.join(logpath, 'Alg.Monitor.csv')
    alg_f = open(alg_filename, "wt")
    alg_f.write('# Alg Logging %s\n'%json.dumps({"t_start": time.time(), 'env_id' : dummy_env.spec and dummy_env.spec.id, 'mode': options['model']['mode'], 'name': options['logs']['exp_name']}))
    alg_fields = ['value_loss', 'action_loss', 'dist_entropy']
    alg_logger = csv.DictWriter(alg_f, fieldnames=alg_fields)
    alg_logger.writeheader()
    alg_f.flush()
    ll_alg_filename = os.path.join(logpath, 'AlgLL.Monitor.csv')
    ll_alg_f = open(ll_alg_filename, "wt")
    ll_alg_f.write('# Alg Logging LL %s\n'%json.dumps({"t_start": time.time(), 'env_id' : dummy_env.spec and dummy_env.spec.id, 'mode': options['model']['mode'], 'name': options['logs']['exp_name']}))
    ll_alg_fields = ['value_loss', 'action_loss', 'dist_entropy']
    ll_alg_logger = csv.DictWriter(ll_alg_f, fieldnames=ll_alg_fields)
    ll_alg_logger.writeheader()
    ll_alg_f.flush()

    # Create the policy networks
    ll_action_space = envs.action_space 
    if args.algo == 'dqn':
        model_opt['eps_start'] = optim_opt['eps_start']
        model_opt['eps_end'] = optim_opt['eps_end']
        model_opt['eps_decay'] = optim_opt['eps_decay']
        hl_policy = DQNPolicy(hl_obs_shape, hl_action_space, model_opt)
    else:
        hl_policy = Policy(hl_obs_shape, hl_action_space, model_opt)
    if model_opt['mode'] == 'hierarchical_many':
        ll_policy = ModularPolicy(ll_raw_obs_shape, ll_action_space, theta_sz, ll_opt)
    else:
        ll_policy = Policy(ll_obs_shape, ll_action_space, ll_opt['model'])
    # Load the previous ones here?
    if args.cuda:
        hl_policy.cuda()
        ll_policy.cuda()

    # Create the high level agent
    if args.algo == 'a2c':
        hl_agent = algo.A2C_ACKTR(hl_policy, alg_opt['value_loss_coef'],
                               alg_opt['entropy_coef'], lr=optim_opt['lr'],
                               eps=optim_opt['eps'], alpha=optim_opt['alpha'],
                               max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'ppo':
        hl_agent = algo.PPO(hl_policy, alg_opt['clip_param'], alg_opt['ppo_epoch'], alg_opt['num_mini_batch'],
                         alg_opt['value_loss_coef'], alg_opt['entropy_coef'], lr=optim_opt['lr'],
                         eps=optim_opt['eps'], max_grad_norm=optim_opt['max_grad_norm'])
    elif args.algo == 'acktr':
        hl_agent = algo.A2C_ACKTR(hl_policy, alg_opt['value_loss_coef'],
                               alg_opt['entropy_coef'], acktr=True)
    elif args.algo == 'dqn':
        hl_agent = algo.DQN(hl_policy, env_opt['gamma'], batch_size=alg_opt['batch_size'], target_update=alg_opt['target_update'], 
                        mem_capacity=alg_opt['mem_capacity'], lr=optim_opt['lr'], eps=optim_opt['eps'], max_grad_norm=optim_opt['max_grad_norm'])

    # Create the low level agent
    # If only training high level, make dummy agent (just does passthrough, doesn't change anything)
    if optim_opt['hierarchical_mode'] == 'train_highlevel':
        ll_agent = algo.Passthrough(ll_policy)
    elif optim_opt['hierarchical_mode'] == 'train_both':
        if args.algo == 'a2c':
            ll_agent = algo.A2C_ACKTR(ll_policy, alg_opt['value_loss_coef'],
                                      alg_opt['entropy_coef'], lr=optim_opt['ll_lr'],
                                      eps=optim_opt['eps'], alpha=optim_opt['alpha'],
                                      max_grad_norm=optim_opt['max_grad_norm'])
        elif args.algo == 'ppo':
            ll_agent = algo.PPO(ll_policy, alg_opt['clip_param'], alg_opt['ppo_epoch'], alg_opt['num_mini_batch'],
                                alg_opt['value_loss_coef'], alg_opt['entropy_coef'], lr=optim_opt['ll_lr'],
                                eps=optim_opt['eps'], max_grad_norm=optim_opt['max_grad_norm'])
        elif args.algo == 'acktr':
            ll_agent = algo.A2C_ACKTR(ll_policy, alg_opt['value_loss_coef'],
                                      alg_opt['entropy_coef'], acktr=True)
    else:
        raise NotImplementedError

    # Make the rollout structures 
    # Kind of dumb hack to avoid having to deal with rollouts
    hl_rollouts = RolloutStorage(10000*args.num_ep, 1, hl_obs_shape, hl_action_space, hl_policy.state_size)
    ll_rollouts = MaskingRolloutStorage(alg_opt['num_steps'], 1, ll_obs_shape, ll_action_space, ll_policy.state_size)
    hl_current_obs = torch.zeros(1, *hl_obs_shape)
    ll_current_obs = torch.zeros(1, *ll_obs_shape)

    # Helper functions to update the current obs
    def update_hl_current_obs(obs):
        shape_dim0 = hl_obs_shape[0]
        obs = torch.from_numpy(obs).float()
        if env_opt['num_stack'] > 1:
            hl_current_obs[:, :-shape_dim0] = hl_current_obs[:, shape_dim0:]
        hl_current_obs[:, -shape_dim0:] = obs
    def update_ll_current_obs(obs):
        shape_dim0 = ll_obs_shape[0]
        obs = torch.from_numpy(obs).float()
        if env_opt['num_stack'] > 1:
            ll_current_obs[:, :-shape_dim0] = ll_current_obs[:, shape_dim0:]
        ll_current_obs[:, -shape_dim0:] = obs

    # Update agent with loaded checkpoint
    # This should update both the policy network and the optimizer
    ll_agent.load_state_dict(ckpt['ll_agent'])
    hl_agent.load_state_dict(ckpt['hl_agent'])

    # Set ob_rms
    envs.ob_rms = ckpt['ob_rms']
    
    # Reset our env and rollouts
    raw_obs = envs.reset()
    hl_obs, raw_ll_obs, step_counts = hier_utils.seperate_obs(raw_obs)
    ll_obs = hier_utils.placeholder_theta(raw_ll_obs, step_counts)
    update_hl_current_obs(hl_obs)
    update_ll_current_obs(ll_obs)
    hl_rollouts.observations[0].copy_(hl_current_obs)
    ll_rollouts.observations[0].copy_(ll_current_obs)
    ll_rollouts.recent_obs.copy_(ll_current_obs)
    if args.cuda:
        hl_current_obs = hl_current_obs.cuda()
        ll_current_obs = ll_current_obs.cuda()
        hl_rollouts.cuda()
        ll_rollouts.cuda()

    # These variables are used to compute average rewards for all processes.
    episode_rewards = []
    tabbed = False
    raw_data = []

    # Loop through episodes
    step = 0
    for ep in range(args.num_ep):
        if ep < args.num_vid:
            record = True
        else:
            record = False

        # Complete episode
        done = False
        frames = []
        ep_total_reward = 0
        num_steps = 0
        while not done:
            # Step through high level action
            start_time = time.time()
            with torch.no_grad():
                hl_value, hl_action, hl_action_log_prob, hl_states = hl_policy.act(hl_rollouts.observations[step], hl_rollouts.states[step], hl_rollouts.masks[step], deterministic=True)
            step += 1
            hl_cpu_actions = hl_action.squeeze(1).cpu().numpy()
 
            # Get values to use for Q learning
            hl_state_dqn = hl_rollouts.observations[step]
            hl_action_dqn = hl_action

            # Update last ll observation with new theta
            for proc in range(1):
                # Update last observations in memory
                last_obs = ll_rollouts.observations[ll_rollouts.steps[proc], proc]
                if hier_utils.has_placeholder(last_obs):         
                    new_last_obs = hier_utils.update_theta(last_obs, hl_cpu_actions[proc])
                    ll_rollouts.observations[ll_rollouts.steps[proc], proc].copy_(new_last_obs)

                # Update most recent observations (not necessarily the same)
                assert(hier_utils.has_placeholder(ll_rollouts.recent_obs[proc]))
                new_last_obs = hier_utils.update_theta(ll_rollouts.recent_obs[proc], hl_cpu_actions[proc])
                ll_rollouts.recent_obs[proc].copy_(new_last_obs)
            assert(ll_rollouts.observations.max().item() < float('inf') and ll_rollouts.recent_obs.max().item() < float('inf'))
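            # (Hedged note) placeholder_theta is assumed to fill the theta slots of
            # an observation with a sentinel (e.g. inf); update_theta overwrites them
            # once the high level has chosen an action, which is why the assert above
            # checks that no infinite values remain in the rollout buffers.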

            # Given high level action, step through the low level actions
            death_step_mask = np.ones([1, 1])    # 1 means still alive, 0 means dead  
            hl_reward = torch.zeros([1, 1])
            hl_obs = [None for i in range(1)]
            for ll_step in range(optim_opt['num_ll_steps']):
                num_steps += 1
                # Capture screenshot
                if record:
                    raw_env.render()
                    if not tabbed:
                        # GLFW TAB and RELEASE are hardcoded here
                        raw_env.unwrapped.viewer.cam.distance += 5
                        raw_env.unwrapped.viewer.cam.lookat[0] += 2.5
                        #raw_env.unwrapped.viewer.cam.lookat[1] += 2.5
                        raw_env.render()
                        tabbed = True
                    frames.append(raw_env.unwrapped.viewer._read_pixels_as_in_window())

                # Sample actions
                with torch.no_grad():
                    ll_value, ll_action, ll_action_log_prob, ll_states = ll_policy.act(ll_rollouts.recent_obs, ll_rollouts.recent_s, ll_rollouts.recent_masks, deterministic=True)
                ll_cpu_actions = ll_action.squeeze(1).cpu().numpy()

                # Observe reward and next obs
                raw_obs, ll_reward, done, info = envs.step(ll_cpu_actions, death_step_mask)
                raw_hl_obs, raw_ll_obs, step_counts = hier_utils.seperate_obs(raw_obs)
                ll_obs = []
                for proc in range(alg_opt['num_processes']):
                    if (ll_step == optim_opt['num_ll_steps'] - 1) or done[proc]:
                        ll_obs.append(hier_utils.placeholder_theta(np.array([raw_ll_obs[proc]]), np.array([step_counts[proc]])))
                    else:
                        ll_obs.append(hier_utils.append_theta(np.array([raw_ll_obs[proc]]), np.array([hl_cpu_actions[proc]]), np.array([step_counts[proc]])))
                ll_obs = np.concatenate(ll_obs, 0) 
                ll_reward = torch.from_numpy(np.expand_dims(np.stack(ll_reward), 1)).float()
                hl_reward += ll_reward
                ep_total_reward += ll_reward.item()


                # Update high level observations (only take most recent obs if we haven't see a done before now and thus the value is valid)
                for proc, raw_hl in enumerate(raw_hl_obs):
                    if death_step_mask[proc].item() > 0:
                        hl_obs[proc] = np.array([raw_hl])

                # If done then clean the history of observations
                masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
                #final_rewards *= masks
                #final_rewards += (1 - masks) * episode_rewards  # TODO - actually not sure if I broke this logic, but this value is not used anywhere
                #episode_rewards *= masks

                # Done is actually a bool for eval since it's just one process
                done = done.item()
                if not done:
                    last_hl_obs = np.array(hl_obs)

                # TODO - I commented this out, which possibly breaks things if num_stack > 1. Fix later if necessary
                #if args.cuda:
                #    masks = masks.cuda()
                #if current_obs.dim() == 4:
                #    current_obs *= masks.unsqueeze(2).unsqueeze(2)
                #else:
                #    current_obs *= masks

                # Update low level observations
                update_ll_current_obs(ll_obs)

                # Update low level rollouts
                ll_rollouts.insert(ll_current_obs, ll_states, ll_action, ll_action_log_prob, ll_value, ll_reward, masks, death_step_mask) 

                # Update which ones have stepped to the end and shouldn't be updated next time in the loop
                death_step_mask *= masks    

            # Update high level rollouts
            hl_obs = np.concatenate(hl_obs, 0) 
            update_hl_current_obs(hl_obs) 
            hl_rollouts.insert(hl_current_obs, hl_states, hl_action, hl_action_log_prob, hl_value, hl_reward, masks)

            # Check if we want to update lowlevel policy
            if ll_rollouts.isfull and all([not hier_utils.has_placeholder(ll_rollouts.observations[ll_rollouts.steps[proc], proc]) for proc in range(alg_opt['num_processes'])]): 
                # Update low level policy
                assert(ll_rollouts.observations.max().item() < float('inf')) 
                ll_value_loss = 0
                ll_action_loss = 0
                ll_dist_entropy = 0
                ll_rollouts.after_update()

                # Update logger
                #alg_info = {}
                #alg_info['value_loss'] = ll_value_loss
                #alg_info['action_loss'] = ll_action_loss
                #alg_info['dist_entropy'] = ll_dist_entropy
                #ll_alg_logger.writerow(alg_info)
                #ll_alg_f.flush() 

        # Update alg monitor for high level
        #alg_info = {}
        #alg_info['value_loss'] = hl_value_loss
        #alg_info['action_loss'] = hl_action_loss
        #alg_info['dist_entropy'] = hl_dist_entropy
        #alg_logger.writerow(alg_info)
        #alg_f.flush() 

        # Save video
        if record:
            for fr_ind, fr in enumerate(frames):
                scipy.misc.imsave(os.path.join(logpath, 'tmp_fr_%d.jpg' % fr_ind), fr)
            os.system("ffmpeg -r 20 -i %s/" % logpath + "tmp_fr_%01d.jpg -y " + "%s/results_ep%d.mp4" % (logpath, ep))
            os.system("rm %s/tmp_fr*.jpg" % logpath)
            
        # Do dashboard logging for each episode
        try:
            dashboard.visdom_plot()
        except IOError:
            pass

        # Print / dump reward for episode
        # DEBUG for thetas
        #print("Theta %d" % env.venv.envs[0].env.env.theta)
        print("Total reward for episode %d: %f" % (ep, ep_total_reward))
        print("Episode length: %d" % num_steps)
        print(last_hl_obs)
        print("----------")
        episode_rewards.append(ep_total_reward)

    # Close logging file
    alg_f.close()
    ll_alg_f.close()
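
    # (Hedged sketch, not in the original source) Mirroring Exemple #12, a summary
    # over episodes could be printed here as well:
    #   avg_r, std_r = np.mean(episode_rewards), np.std(episode_rewards)
    #   print("Reward over episodes was %f+-%f" % (avg_r, std_r))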