Beispiel #1
0
def visulize():
	"""Build a small GConv model and log its graph to TensorBoard.

	Side effects only: creates a SummaryWriter (default log dir), writes
	the model graph for a (1, 12, 32, 32) input, and closes the writer.
	"""
	from utils import visualize_graph
	from tensorboardX import SummaryWriter

	net = GConv(3, 5, 3)
	tb_writer = SummaryWriter()
	visualize_graph(net, tb_writer, input_size=(1, 12, 32, 32))
	tb_writer.close()
Beispiel #2
0
def main(show_age_clusterization, show_contacts_tree,
         show_total_statistics, show_all_contacts,
         use_profile_spreading, verbose):
    """Run the contact-spreading simulation and report the requested outputs.

    Args:
        show_age_clusterization: plot a histogram of sampled ages.
        show_contacts_tree: draw the generated contacts tree.
        show_total_statistics: echo average request count and per-answer
            vote percentages.
        show_all_contacts: dump the whole contacts tree as JSON.
        use_profile_spreading: forwarded into the simulation settings.
        verbose: raise the module logger level to INFO.
    """
    # Load simulation settings (depth, age_params, ...) from config.json.
    with open('config.json', 'r') as config_json:
        settings = json.load(config_json)

    if verbose:
        logger.setLevel(logging.INFO)

    tree = ContactsTree(settings['depth'], settings['age_params'])
    sender = tree.generate_tree()
    settings['use_profile_spreading'] = use_profile_spreading

    simulator = SimulationManager(sender=sender, settings=settings)
    simulator.start_simulation()
    nodes = sender.traverse()

    if show_total_statistics:
        stats = simulator.statistics()
        click.echo('\nAverage request number: {}'.format(simulator.average_request_number()))
        click.echo('\nAggregated data...')
        # Fixed: dict.iteritems() was removed in Python 3 — use items().
        for answer, votes in stats.items():
            click.echo(
                'Answer "{answer}" got {votes}% of votes'.format(
                    answer=answer,
                    votes=votes
                )
            )

    if show_contacts_tree:
        figure(1)
        visualize_graph(nodes)

    if show_age_clusterization:
        figure(2)
        hist(
            generate_age_ranges(1000, settings['age_params']['avg_age'], settings['age_params']['age_dev']),
            bins=50
        )
        pylab.show()

    if show_all_contacts:
        json_nodes = json.dumps(nodes, default=serializer, indent=4)
        click.echo('\nContacts tree:\n{}'.format(json_nodes))
Beispiel #3
0
    				download=True, transform=train_transform)
# NOTE(review): fragment — the matching train_dataset assignment is truncated
# above this chunk; train_transform/test_transform, args, writer and use_cuda
# are defined elsewhere.
test_dataset = datasets.MNIST(root=args.dataset_dir, train=False, 
                    download=True,transform=test_transform)

# DataLoaders for train/test. NOTE(review): shuffle=True on the *test*
# loader is unusual — harmless for aggregate metrics, but confirm intended.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                num_workers=args.workers, pin_memory=True, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                num_workers=args.workers, pin_memory=True, shuffle=True)

# Load model (factory selected by name)
model = get_network_fn(args.model)
print(model)

# Try to log the model graph to TensorBoard; failure is deliberately
# non-fatal. NOTE(review): the bare `except:` also swallows
# KeyboardInterrupt/SystemExit — consider `except Exception:`.
try:
	visualize_graph(model, writer, input_size=(1, 1, 28, 28))
except:
	print('\nNetwork Visualization Failed! But the training procedure continue.')

# Alternative optimizers kept for reference:
# optimizer = optim.Adadelta(model.parameters(), lr=args.lr, rho=0.9, eps=1e-06, weight_decay=3e-05)
# optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=3e-05)
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=3e-05)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)  # halve LR every 10 epochs
criterion = nn.CrossEntropyLoss()

device = torch.device("cuda" if use_cuda else "cpu")
model = model.to(device)
criterion = criterion.to(device)

# Report the total parameter count of the model (in millions)
print('Model size: {:0.2f} million float parameters'.format(get_parameters_size(model)/1e6))
                                           num_workers=args.workers,
                                           pin_memory=True,
                                           shuffle=True)
# NOTE(review): fragment — the train_loader statement preceding this chunk is
# truncated; args, device, stages and writer are defined elsewhere.
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=args.batch_size,
                                          num_workers=args.workers,
                                          pin_memory=True,
                                          shuffle=True)

# Load model onto the target device
model = resnet18(M=args.M, method=args.method, stages=stages).to(device)
print(model)

# Try to log the model graph; failure is deliberately non-fatal.
# NOTE(review): bare `except:` also catches KeyboardInterrupt — consider
# narrowing to `except Exception:`.
try:
    visualize_graph(model, writer, input_size=(1, 3, 32, 32))
except:
    print(
        '\nNetwork Visualization Failed! But the training procedure continue.')

# Report the total parameter count of the model (in millions)
print('Model size: {:0.2f} million float parameters'.format(
    get_parameters_size(model) / 1e6))

# Parameter group: every parameter whose name ends in 'MFilters'
# (presumably given its own optimizer settings later — confirm).
MFilter_params = [
    param for name, param in model.named_parameters()
    if name[-8:] == 'MFilters'
]
Other_params = [
    param for name, param in model.named_parameters()
    if name[-8:] != 'MFilters'
Beispiel #5
0
def create_policy_and_execute(strategy, g, base_availability_models, base_model_variances, true_schedules, node_requests, mu, params, visualize, visualize_path):
    """Plan with MCTS and execute the resulting policy action-by-action.

    Alternates between (re)building an MCTS policy over graph ``g`` and
    following its best actions ('move', 'maintenance', 'observe',
    'deliver') until the time budget ``params['budget']`` is exhausted.
    Rewards for deliveries and maintenance are accumulated and turned into
    competitive ratios against an idealized profit upper bound.

    Returns:
        tuple: (total_profit, competitive_ratio,
                maintenance_competitive_ratio, path_history, ave_plan_time)
    """
    ## plan, execution loop
    num_requests = len(node_requests)
    requests_left_to_deliver = copy.deepcopy(node_requests)
    availability_observations = {}  # pose_id -> [availability flag, observation time]
    total_profit = 0.0
    total_maintenance_profit = 0.0
    delivery_history = []
    nodes_delivered = []
    img_history = []
    path_history = [params['start_node_id']]
    action_history = []
    curr_time = params['start_time']
    curr_node = params['start_node_id']
    path_visits = 0
    maintenance_reward_collected_current_plan = 0.0
    failure_penalty_current_plan = 0.0
    plan_times = []

    replan = True
    end_reached = False

    while not end_reached:

        if replan:
            # Out of budget: stop before paying for another plan.
            if curr_time >= (params['start_time'] + params['budget']):
                end_reached = True
                break

            runtime_start = timer()
            mcts = MCTS(g, base_availability_models, availability_observations, requests_left_to_deliver, curr_node, curr_time, params['budget']-curr_time, params['max_iterations'], params['planning_horizon'], params['maintenance_reward'], params['deliver_reward'], params['mu'], params['discovery_factor'], params['distribution_node'], params['maintenance_node'], params['ensemble_method'])
            mcts.create_policy()
            plan_time = timer() - runtime_start
            plan_times.append(plan_time)
            print ("MCTS plan time: " + str(plan_time))

            path_visits = 1
            curr_state = mcts.root_node_id
            maintenance_reward_collected_current_plan = 0.0
            failure_penalty_current_plan = 0.0
            replan = False

        else:
            # Follow the current policy one action at a time.
            next_step = mcts.choose_best_action(curr_state, params['min_expansions'], maintenance_reward_collected_current_plan, failure_penalty_current_plan)

            # Policy node unexplored or under-visited -> replan from here.
            # Fixed: identity comparison with None (was `== None`).
            if next_step is None:
                replan = True
                print ('Replanning')
                continue
            else:
                action = next_step[0]
                visits = next_step[2]
                distances = next_step[3]

                if action == 'move':
                    curr_state = visits[0]
                    visit = mcts.nodes[curr_state].pose_id
                    dist = distances[0]
                    action_history.append('move_' + visit + '_' + str(curr_time))
                    curr_time += dist
                    curr_node = visit
                    path_visits += 1
                    path_history.append(visit)

                elif action == 'maintenance':
                    curr_state = visits[0]
                    visit = mcts.nodes[curr_state].pose_id
                    dist = distances[0]
                    action_history.append('maintenance_' + visit + '_' + str(curr_time))
                    curr_time += dist
                    curr_node = visit
                    path_visits += 1
                    path_history.append(visit)
                    total_maintenance_profit += params['maintenance_reward']
                    maintenance_reward_collected_current_plan += params['maintenance_reward']

                elif action == 'observe':
                    # Peek at the true schedule, record the observation,
                    # then take the branch matching availability.
                    visit = mcts.nodes[visits[0]].pose_id
                    curr_time_index = int(curr_time/params['time_interval'])
                    available = true_schedules[visit][curr_time_index]
                    availability_observations[visit] = [available, curr_time]
                    if bool(available):
                        curr_state = visits[0]
                        dist = distances[0]
                    else:
                        curr_state = visits[1]
                        dist = distances[1]
                        visit = mcts.nodes[visits[1]].pose_id

                    action_history.append('observe_' + visit + '_' + str(curr_time))
                    curr_time += dist
                    curr_node = visit
                    path_visits += 1
                    path_history.append(visit)

                elif action == 'deliver':
                    visit = mcts.nodes[visits[0]].pose_id
                    deliver_time = curr_time + distances[0]
                    # Delivery would overrun the budget: stop here.
                    if deliver_time > (params['start_time'] + params['budget']):
                        end_reached = True
                        break

                    deliver_time_index = int(deliver_time/params['time_interval'])
                    available = true_schedules[visit][deliver_time_index]
                    availability_observations[visit] = [available, deliver_time]
                    if bool(available):
                        curr_state = visits[0]
                        dist = distances[0]
                        requests_left_to_deliver.remove(visit)
                        total_profit += params['deliver_reward']
                        delivery_history.append([visit, deliver_time])
                        nodes_delivered.append(visit)
                    else:
                        # Failed delivery: take the fallback branch and
                        # penalize the current plan.
                        curr_state = visits[1]
                        dist = distances[1]
                        visit = mcts.nodes[visits[1]].pose_id
                        failure_penalty_current_plan -= params['deliver_reward']

                    action_history.append('deliver_' + visit + '_' + str(curr_time))
                    curr_time += dist
                    curr_node = visit
                    path_visits += 1
                    path_history.append(visit)

            if curr_time >= (params['start_time'] + params['budget']):
                end_reached = True
                break

            if visualize:
                curr_time_index = int(curr_time/params['time_interval'])
                img = visualize_graph(g, base_availability_models, true_schedules, availability_observations, curr_time_index, curr_node, node_requests, nodes_delivered, curr_time, params['mu'], strategy, params['use_gp'])
                img_history.append(img)

    # Competitive ratios against an idealized profit upper bound.
    ratio_divisor = num_requests*params['deliver_reward'] + ((params['budget']-params['start_time']-num_requests)/params['time_interval'])*params['maintenance_reward']
    competitive_ratio = (float(total_profit) + total_maintenance_profit)/ratio_divisor
    maintenance_ratio_divisor = ((params['budget']-params['start_time'])/params['time_interval'])*params['maintenance_reward']
    maintenance_competitive_ratio = total_maintenance_profit/maintenance_ratio_divisor

    print ()
    print(strategy + " cr: " + str(competitive_ratio))

    # Fixed: guard against an empty plan list (budget already exhausted at
    # entry), which previously raised ZeroDivisionError.
    ave_plan_time = sum(plan_times)/len(plan_times) if plan_times else 0.0

    if visualize:
        imageio.mimsave(visualize_path, img_history, duration=2)
        print (action_history)

    return total_profit, competitive_ratio, maintenance_competitive_ratio, path_history, ave_plan_time
Beispiel #6
0
def plan_and_execute(strategy, g, base_availability_models, base_model_variances, true_schedules, node_requests, mu, params, visualize, visualize_path):
    """Plan a delivery path and execute it, replanning as observations arrive.

    Repeatedly calls ``plan_path`` and walks the returned path, collecting
    delivery/maintenance rewards and recording availability observations.
    Depending on ``strategy``, the loop replans after observations and/or
    after successful multi-visit deliveries. Terminates when a plan yields
    no useful path or the simulated horizon is exceeded.

    Returns:
        tuple: (total_profit, competitive_ratio,
                maintenance_competitive_ratio, path_history, ave_plan_time)
    """
    ## plan, execution loop
    num_requests = len(node_requests)
    requests_left_to_deliver = copy.deepcopy(node_requests)
    availability_observations = {}  # pose_id -> [availability flag, observation time]
    total_profit = 0.0
    total_maintenance_profit = 0.0
    delivery_history = []
    nodes_delivered = []
    img_history = []
    path_history = [params['start_node_id']]
    action_history = []
    curr_time = params['start_time']
    curr_node = params['start_node_id']
    path_length = 1
    path_visits = 0
    plan_times = []

    # Replan after observations unless the strategy forbids it.
    if (strategy == 'no_temp') or (strategy == 'no_replan'):
        replan = False
    else:
        replan = True

    # Multiple visits planned.
    # Fixed: the original tested `or ('observe_mult_visits')` — a non-empty
    # string literal that is always truthy — so multiple_visits was True
    # for every strategy. Compare against the strategy name instead.
    if (strategy == 'observe_sampling_mult_visits') or (strategy == 'observe_mult_visits'):
        multiple_visits = True
    else:
        multiple_visits = False


    while (path_visits < path_length):

        runtime_start = timer()
        path = plan_path(strategy, g, base_availability_models, base_model_variances, availability_observations, requests_left_to_deliver, curr_time, curr_node, mu, params)
        path_length = len(path)
        plan_time = timer() - runtime_start
        plan_times.append(plan_time)
        print ("Plan time: " + str(plan_time))

        ## Execute
        if path_length > 1:

            path_visits = 1
            for next_step in path:
                visit = next_step[0]
                action = next_step[1]
                dist = next_step[2]
                path_history.append(visit)
                action_history.append(action + '_' + visit + '_' + str(curr_time))
                curr_time += dist
                curr_node = visit
                path_visits += 1

                if (visit == params['maintenance_node']) and (action == 'maintenance'):
                    total_maintenance_profit += params['maintenance_reward']

                breakout = False
                curr_time_index = int(curr_time/params['time_interval'])
                # Past the simulated horizon: clamp the index and abort
                # both this path and the outer planning loop.
                if curr_time_index > (params['num_intervals'] - 1):
                    print("Curr time index exceeds num intervals: " + str(curr_time_index))
                    curr_time_index = params['num_intervals']-1
                    path_visits = 0
                    path_length = 0
                    break

                if visit in requests_left_to_deliver:
                    if action == 'deliver':
                        available = true_schedules[visit][curr_time_index]
                        if bool(available):
                            requests_left_to_deliver.remove(visit)
                            total_profit += params['deliver_reward']
                            delivery_history.append([visit, curr_time])
                            nodes_delivered.append(visit)
                            availability_observations[visit] = [1, curr_time]

                            if multiple_visits:
                                if replan:
                                    path_visits = 0
                                    path_length = 1
                                    breakout = True
                        else:
                            # Failed delivery doubles as an observation.
                            availability_observations[visit] = [0, curr_time]

                            # Return the package to the start node, capped
                            # at the end of the budget.
                            if (curr_time + dist/2) <= (params['start_time'] + params['budget']):
                                curr_time += dist/2
                            else:
                                curr_time = params['start_time'] + params['budget']
                            curr_node = params['start_node_id']
                            path_visits += 1
                            path_history.append(curr_node)

                            if replan:
                                path_visits = 0
                                path_length = 1
                                breakout = True

                    # Break out after every observation.
                    if action == 'observe':
                        available = true_schedules[visit][curr_time_index]
                        availability_observations[visit] = [available, curr_time]
                        if replan:
                            path_visits = 0
                            path_length = 1
                            breakout = True

                if visualize:
                    img = visualize_graph(g, base_availability_models, true_schedules, availability_observations, curr_time_index, curr_node, node_requests, nodes_delivered, curr_time, mu, strategy, params['use_gp'])
                    img_history.append(img)

                if breakout:
                    break
        else:
            # Degenerate plan: nothing to execute, terminate the loop.
            path_visits = 0
            path_length = 0


    # Competitive ratios against an idealized profit upper bound.
    ratio_divisor = num_requests*params['deliver_reward'] + ((params['budget']-params['start_time']-num_requests)/params['time_interval'])*params['maintenance_reward']
    competitive_ratio = (float(total_profit) + total_maintenance_profit)/ratio_divisor
    maintenance_ratio_divisor = ((params['budget']-params['start_time'])/params['time_interval'])*params['maintenance_reward']
    maintenance_competitive_ratio = total_maintenance_profit/maintenance_ratio_divisor

    print ()
    print(strategy + " cr: " + str(competitive_ratio))

    # Defensive guard (the loop always plans at least once, but keep this
    # consistent with create_policy_and_execute).
    ave_plan_time = sum(plan_times)/len(plan_times) if plan_times else 0.0

    if visualize:
        imageio.mimsave(visualize_path, img_history, duration=2)
        print (action_history)

    return total_profit, competitive_ratio, maintenance_competitive_ratio, path_history, ave_plan_time
Beispiel #7
0
    def _train(self, train_loader, test_loader, output_dir, train_output_dir, test_output_dir, model_output_dir):
        """Three-phase training driver.

        Phase 1: train ``disc`` as a classifier for ``args.epochs_cls``
        epochs. Phase 2: train ``disc`` and ``gen`` (jointly or
        alternating per ``args.joint``) for ``args.epochs`` more epochs.
        Phase 3 (optional, ``args.phase_2``): train ``disc2`` via
        ``_train_2``. Every ``args.visualize_freq`` epochs: evaluate,
        checkpoint models under ``model_output_dir``, and dump the loss
        dicts as JSON under ``output_dir``.
        """
        args = self.args
        disc = self.disc
        disc2 = self.disc2
        gen = self.gen

        train_losses = {}  # epoch -> train loss
        test_losses = {}   # epoch -> test loss (only on visualize epochs)

        # NOTE(review): if args.optimizer is neither 'sgd' nor 'adam', the
        # optimizers are never bound and the loops below raise NameError —
        # presumably argparse restricts the choices; confirm.
        if args.optimizer == 'sgd':
            optimizer_g = optim.SGD(gen.parameters(), lr=args.lr/args.gen_lr_factor, momentum=0.9)
            optimizer_d = optim.SGD(disc.parameters(), lr=args.lr, momentum=0.9)
            optimizer_d2 = optim.SGD(disc2.parameters(), lr=args.lr, momentum=0.9)
        elif args.optimizer == 'adam':
            optimizer_g = optim.Adam(gen.parameters(), lr=args.lr/args.gen_lr_factor, betas=(0.5, 0.999))
            optimizer_d = optim.Adam(disc.parameters(), lr=args.lr, betas=(0.5, 0.999))
            optimizer_d2 = optim.Adam(disc2.parameters(), lr=args.lr, betas=(0.5, 0.999))

        # Phase 1: classifier pre-training of the discriminator.
        for epoch in range(args.epochs_cls):
            self.disc.train()
            train_loss = self._train_cls(self.args, epoch+1, disc, train_loader, optimizer_d)
            train_losses[epoch] = train_loss
            # Frequency to output and visualize results
            if (epoch+1) % args.visualize_freq == 0 or epoch == 0:
                test_loss = self._test_2(args, epoch+1, disc, test_loader, test_output_dir)
                test_losses[epoch] = test_loss

                torch.save(disc.state_dict(), model_output_dir + 'disc_cls.pth')
                print("Saved model")

                # Loss dicts are rewritten in full on every checkpoint.
                with open(output_dir +'/train_losses.txt', 'w') as f:
                    json.dump(train_losses, f)
                with open(output_dir +'/test_losses.txt', 'w') as f:
                    json.dump(test_losses, f)

#         self._build_statistics(args, train_loader, disc)

        # Phase 2: GAN-style training of disc + gen.
        for epoch in range(args.epochs_cls, args.epochs_cls+args.epochs):
            self.disc.train()
            self.gen.train()

            if args.joint:
                # Training discriminator and generator with same batch
                train_loss = self._train_joint(self.args, epoch+1, disc, gen, train_loader, optimizer_d, optimizer_g, train_output_dir)
            else:
                # Training discriminator for one epoch, then generator for one epoch
                train_loss = self._train_alt(self.args, epoch+1, disc, gen, train_loader, optimizer_d, optimizer_g, train_output_dir)
            train_losses[epoch] = train_loss

            disc.eval()
            gen.eval()

            # Frequency to output and visualize results
            if (epoch+1) % args.visualize_freq == 0 or epoch == args.epochs_cls:
                test_loss = self._test(args, epoch+1, disc, gen, test_loader, test_output_dir)
                test_losses[epoch] = test_loss

                torch.save(disc.state_dict(), model_output_dir + 'disc.pth')
                torch.save(gen.state_dict(), model_output_dir + 'gen.pth')
                print("Saved model")

                with open(output_dir +'/train_losses.txt', 'w') as f:
                    json.dump(train_losses, f)
                with open(output_dir +'/test_losses.txt', 'w') as f:
                    json.dump(test_losses, f)

                visualize_graph(train_losses, epoch+1, output_dir)

        # Validate how well training on generated feature maps work on real images
        if args.phase_2:
            disc.eval()
            gen.eval()
            for epoch in range(args.epochs_cls+args.epochs, args.epochs_cls+args.epochs+args.epochs_2):
                disc2.train()
                train_loss = self._train_2(args, epoch+1, disc, disc2, gen, train_loader, optimizer_d2, train_output_dir)
                train_losses[epoch] = train_loss
                disc2.eval()
                if (epoch+1) % args.visualize_freq == 0 or epoch == args.epochs_cls+args.epochs:
                    test_loss = self._test_2(args, epoch+1, disc2, test_loader, test_output_dir)
                    test_losses[epoch] = test_loss

                    torch.save(disc2.state_dict(), model_output_dir + 'disc2.pth')
                    print("Saved model")

                    with open(output_dir +'/train_losses.txt', 'w') as f:
                        json.dump(train_losses, f)
                    with open(output_dir +'/test_losses.txt', 'w') as f:
                        json.dump(test_losses, f)
0
    def _train(self, train_loader, test_loader, output_dir, train_output_dir,
               test_output_dir, model_output_dir):
        """Two-phase training driver (Adam only).

        Phase 1: train ``disc`` and ``gen`` (jointly or alternating per
        ``args.joint``) for ``args.epochs`` epochs. Phase 2 (optional,
        ``args.phase_2``): train via ``_train_2`` for another
        ``args.epochs`` epochs. Every ``args.visualize_freq`` epochs:
        evaluate, checkpoint models, and dump loss dicts as JSON.
        """
        args = self.args
        disc = self.disc
        disc2 = self.disc2
        gen = self.gen

        train_losses = {}  # epoch -> train loss
        test_losses = {}   # epoch -> test loss (only on visualize epochs)

        optimizer_g = optim.Adam(gen.parameters(),
                                 lr=args.lr,
                                 betas=(0.5, 0.999))
        optimizer_d = optim.Adam(disc.parameters(),
                                 lr=args.lr,
                                 betas=(0.5, 0.999))
        optimizer_d2 = optim.Adam(disc2.parameters(),
                                  lr=args.lr,
                                  betas=(0.5, 0.999))

        # Phase 1: GAN-style training of disc + gen.
        for epoch in range(args.epochs):
            self.disc.train()
            self.gen.train()
            if args.joint:
                # Discriminator and generator updated on the same batch.
                train_loss = self._train_joint(self.args, epoch + 1, disc, gen,
                                               train_loader, optimizer_d,
                                               optimizer_g, train_output_dir)
            else:
                # Discriminator and generator trained in alternation.
                train_loss = self._train_alt(self.args, epoch + 1, disc, gen,
                                             train_loader, optimizer_d,
                                             optimizer_g, train_output_dir)
            train_losses[epoch] = train_loss
            disc.eval()
            gen.eval()
            if (epoch + 1) % args.visualize_freq == 0 or epoch == 0:
                test_loss = self._test(args, epoch + 1, disc, gen, test_loader,
                                       test_output_dir)
                test_losses[epoch] = test_loss

                torch.save(disc.state_dict(), model_output_dir + 'disc.pth')
                torch.save(gen.state_dict(), model_output_dir + 'gen.pth')
                print("Saved model")

                # Loss dicts are rewritten in full on every checkpoint.
                with open(output_dir + '/train_losses.txt', 'w') as f:
                    json.dump(train_losses, f)
                with open(output_dir + '/test_losses.txt', 'w') as f:
                    json.dump(test_losses, f)

                visualize_graph(train_losses, epoch + 1, output_dir)

        # Phase 2 (optional): second discriminator on generator output.
        # NOTE(review): disc/gen are put back into train() mode here,
        # unlike the sibling _train which freezes them with eval() —
        # confirm this difference is intended.
        if args.phase_2:
            for epoch in range(args.epochs, args.epochs * 2):
                disc.train()
                gen.train()
                train_loss = self._train_2(args, epoch + 1, disc, disc2, gen,
                                           train_loader, optimizer_d2,
                                           train_output_dir)
                train_losses[epoch] = train_loss
                disc.eval()
                gen.eval()
                if (epoch +
                        1) % args.visualize_freq == 0 or epoch == args.epochs:
                    test_loss = self._test_2(args, epoch + 1, disc2,
                                             test_loader, test_output_dir)
                    test_losses[epoch] = test_loss

                    torch.save(disc2.state_dict(),
                               model_output_dir + 'disc2.pth')
                    print("Saved model")

                    with open(output_dir + '/train_losses.txt', 'w') as f:
                        json.dump(train_losses, f)
                    with open(output_dir + '/test_losses.txt', 'w') as f:
                        json.dump(test_losses, f)
    def _train(self, train_loader, test_loader, output_dir, train_output_dir,
               test_output_dir, model_output_dir):
        """Three-phase training driver with stored-feature and resume support.

        Phase 1: classifier pre-training of ``disc`` (skipped when feature
        statistics are loaded from pickles). Phase 2: disc/gen training
        (skipped when ``args.resume`` loads checkpoints). Phase 3
        (optional, ``args.phase_2``): train ``disc2`` on generator output.
        Checkpoints and JSON loss dumps happen every
        ``args.visualize_freq`` epochs.
        """
        args = self.args
        disc = self.disc
        disc2 = self.disc2
        gen = self.gen

        train_losses = {}  # epoch -> train loss
        test_losses = {}   # epoch -> test loss (only on visualize epochs)

        # NOTE(review): neither 'sgd' nor 'adam' leaves the optimizers
        # unbound (NameError later) — presumably argparse restricts choices.
        if args.optimizer == 'sgd':
            optimizer_g = optim.SGD(gen.parameters(),
                                    lr=args.lr / args.gen_lr_factor,
                                    momentum=0.9)
            optimizer_d = optim.SGD(disc.parameters(),
                                    lr=args.lr,
                                    momentum=0.9)
            optimizer_d2 = optim.SGD(disc2.parameters(),
                                     lr=args.lr,
                                     momentum=0.9)
        elif args.optimizer == 'adam':
            optimizer_g = optim.Adam(gen.parameters(),
                                     lr=args.lr / args.gen_lr_factor,
                                     betas=(0.5, 0.999))
            optimizer_d = optim.Adam(disc.parameters(),
                                     lr=args.lr,
                                     betas=(0.5, 0.999))
            optimizer_d2 = optim.Adam(disc2.parameters(),
                                      lr=args.lr,
                                      betas=(0.5, 0.999))

        # Feature statistics: load from pickle or compute via phase-1
        # classifier training. NOTE(review): pickle.load on these files
        # executes arbitrary code if the files are untrusted.
        if args.stored_features:
            # Features from custom discriminator model
            with open('data/resnet_features.pkl', 'rb') as f:
                self._features = pickle.load(f)
            with open('data/resnet_targets.pkl', 'rb') as f:
                self._targets = pickle.load(f)
            print("Loaded stored feature statistics")
        elif args.stored_features_pretrained:
            # Features from pretrained resnet model
            with open('data/pretrained_resnet_features.pkl', 'rb') as f:
                self._features = pickle.load(f)
            with open('data/pretrained_resnet_targets.pkl', 'rb') as f:
                self._targets = pickle.load(f)
            print("Loaded stored pretrained feature statistics")
        else:
            # Phase 1: classifier pre-training of the discriminator.
            for epoch in range(args.epochs_cls):
                disc.train()
                train_loss = self._train_cls(self.args, epoch + 1, disc,
                                             train_loader, optimizer_d)
                train_losses[epoch] = train_loss
                # Frequency to output and visualize results
                if (epoch + 1) % args.visualize_freq == 0 or epoch == 0:
                    test_loss = self._test_2(args, epoch + 1, disc,
                                             test_loader, test_output_dir)
                    test_losses[epoch] = test_loss

                    torch.save(disc.state_dict(),
                               model_output_dir + 'disc_cls.pth')
                    print("Saved model")

                    with open(output_dir + '/train_losses.txt', 'w') as f:
                        json.dump(train_losses, f)
                    with open(output_dir + '/test_losses.txt', 'w') as f:
                        json.dump(test_losses, f)

            self._build_statistics_index(args, train_loader, disc)

        # Phase 2: disc/gen training, or resume from checkpoints.
        if args.resume:
            disc.load_state_dict(torch.load(model_output_dir + 'disc.pth'))
            gen.load_state_dict(torch.load(model_output_dir + 'gen.pth'))
            print("Loaded previous disc and gen")
        else:
            for epoch in range(args.epochs_cls, args.epochs_cls + args.epochs):
                disc.train()
                gen.train()

                if args.joint:
                    # Training discriminator and generator with same batch
                    train_loss = self._train_joint(self.args, epoch + 1, disc,
                                                   gen, train_loader,
                                                   optimizer_d, optimizer_g,
                                                   train_output_dir)
                else:
                    # Training discriminator for one epoch, then generator for one epoch
                    train_loss = self._train_alt(self.args, epoch + 1, disc,
                                                 gen, train_loader,
                                                 optimizer_d, optimizer_g,
                                                 train_output_dir)
                train_losses[epoch] = train_loss

                disc.eval()
                gen.eval()

                # Frequency to output and visualize results
                if (epoch + 1
                    ) % args.visualize_freq == 0 or epoch == args.epochs_cls:
                    test_loss = self._test(args, epoch + 1, disc, gen,
                                           test_loader, test_output_dir)
                    test_losses[epoch] = test_loss

                    torch.save(disc.state_dict(),
                               model_output_dir + 'disc.pth')
                    torch.save(gen.state_dict(), model_output_dir + 'gen.pth')
                    print("Saved model")

                    with open(output_dir + '/train_losses.txt', 'w') as f:
                        json.dump(train_losses, f)
                    with open(output_dir + '/test_losses.txt', 'w') as f:
                        json.dump(test_losses, f)

                    visualize_graph(train_losses,
                                    epoch + 1,
                                    output_dir,
                                    recon=args.recon,
                                    mse=args.mse)

        # Validate how well training on generated feature maps work on real images
        if args.phase_2:
            disc.eval()
            gen.eval()
            for epoch in range(args.epochs_cls + args.epochs,
                               args.epochs_cls + args.epochs + args.epochs_2):
                disc2.train()
                train_loss = self._train_2(args, epoch + 1, disc2, gen,
                                           train_loader, optimizer_d2,
                                           train_output_dir)
                train_losses[epoch] = train_loss
                disc2.eval()
                if (
                        epoch + 1
                ) % args.visualize_freq == 0 or epoch == args.epochs_cls + args.epochs or (
                        epoch +
                        1) == args.epochs_cls + args.epochs + args.epochs_2:
                    test_loss = self._test_2(args, epoch + 1, disc2,
                                             test_loader, test_output_dir)
                    test_losses[epoch] = test_loss

                    torch.save(disc2.state_dict(),
                               model_output_dir + 'disc2.pth')
                    print("Saved model")

                    # NOTE(review): 'a+' appends another full JSON dump on
                    # each checkpoint (unlike the 'w' mode used everywhere
                    # else), which produces a file of concatenated JSON
                    # objects that json.load cannot parse — likely a bug;
                    # confirm before relying on these files.
                    with open(output_dir + '/train_losses.txt', 'a+') as f:
                        json.dump(train_losses, f)
                    with open(output_dir + '/test_losses.txt', 'a+') as f:
                        json.dump(test_losses, f)