[CODE] [DOCUMENTATION] [PAPER] [BLOG POST] [GOOGLE GROUP] [COLAB]

DeepArchitect: Architecture search so easy you’ll think it’s magic!

Check the Colab if you wish to run the examples in this post and more.

Why DeepArchitect

Designing deep learning architectures requires making many difficult low-level decisions (e.g., what operations to use and in what order to apply them). The expert is often ill-equipped to make these decisions upfront, having only high-level inductive biases (e.g., what information should be encoded).

Research in architecture search and hyperparameter optimization aims to lighten the burden of the expert by automatically finding configurations that achieve high performance (e.g., validation accuracy). Unfortunately, hyperparameter optimization systems have not focused on supporting architecture search use-cases (e.g., complex conditional search spaces over architecture structures), and architecture search systems have not focused on programmability and modularity (e.g., existing implementations may be tied to a search space, search algorithm, or task, without being easily extensible to other search spaces, search algorithms, or tasks).

DeepArchitect is an architecture search framework that aims to improve the programmability, modularity, and reusability of architecture search methods, making them widely available to researchers and practitioners. A programmable framework is of extreme importance for architecture search to be widely impactful: without one, experts are limited to the use-cases for which code is already available. The framework must make it easy to encode new search spaces, write new search algorithms, and apply architecture search to new use-cases. It must also be modular, enabling experts to combine different components (e.g., different search spaces and search algorithms) without having to write each combination from scratch.

Search space constructs

Search spaces are composed of modules and hyperparameters. Modules are composed of inputs, outputs, and properties, and are connected among themselves through their inputs and outputs. Hyperparameters are associated with module properties. More specifically, we have the following constructs:

  • Basic modules implement concrete computation once all their properties have values, e.g., a basic module for a convolution with properties num_filters and kernel_size. Basic modules are similar to fixed operations in a computational graph, but with certain decisions about their local structure delayed until their hyperparameters are assigned values.

  • Substitution modules implement search space transformations that depend on the values of their properties. For example, a substitution module that connects two single-input single-output subgraphs in series, with the serial order determined by the value of its property.

  • Independent hyperparameters are associated with module properties and encode a set of reasonable values for a property. A single hyperparameter can be associated with multiple module properties.

  • Dependent hyperparameters have their values computed as a function of other hyperparameter values. These hyperparameters can be used to express complex value dependencies.

A search space encodes a set of possible models. Each model is reached through a transition system that progressively assigns values to independent hyperparameters until no unassigned hyperparameters are left. The resulting model can then be mapped automatically to its implementation, e.g., in a deep learning framework.
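
The sketch below illustrates all four constructs in a tiny search space, using the same helpers that appear in the MobileNet example later in this post (the simplified conv2d here is an illustrative stand-in for the one defined below): a basic module wrapping a convolution, a substitution module (siso_optional) that decides whether a second convolution exists, an independent hyperparameter shared by the kernel-size properties of both convolutions, and a dependent hyperparameter that doubles the filter count.

import deep_architect.core as co
import deep_architect.modules as mo
from deep_architect.hyperparameters import Discrete as D
import deep_architect.helpers.keras_support as hke
from keras.layers import Conv2D


def conv2d(h_filters, h_kernel_size):
    # basic module: computes concretely once filters and kernel_size get values
    fn = lambda filters, kernel_size: Conv2D(
        filters, kernel_size, padding='same')
    return hke.siso_keras_module_from_keras_layer_fn(fn, {
        "filters": h_filters,
        "kernel_size": h_kernel_size
    })


h_filters = D([32, 64])  # independent hyperparameter
h_kernel_size = D([3, 5])  # shared by the properties of both convolutions
# dependent hyperparameter: twice the filters of the first convolution
h_double_filters = co.DependentHyperparameter(
    lambda dh: 2 * dh["filters"], {"filters": h_filters})

# siso_optional is a substitution module: whether the second convolution
# exists at all depends on the value assigned to D([0, 1])
inputs, outputs = mo.siso_sequential([
    conv2d(h_filters, h_kernel_size),
    mo.siso_optional(lambda: conv2d(h_double_filters, h_kernel_size),
                     D([0, 1]))
])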

Example search space

The following code was pulled from the MobileNet Keras implementation found here.

import keras
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, DepthwiseConv2D
from keras.layers import GlobalAvgPool2D
from keras.layers import BatchNormalization, Dropout, ReLU


def mobilenet_block(x, f, s=1):
    x = DepthwiseConv2D(3, strides=s, padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = Conv2D(f, 1, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    return x


def mobilenet(input_shape, n_classes):
    input = Input(input_shape)

    x = Conv2D(32, 3, strides=2, padding='same')(input)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = mobilenet_block(x, 64)
    x = mobilenet_block(x, 128, 2)
    x = mobilenet_block(x, 128)

    x = mobilenet_block(x, 256, 2)
    x = mobilenet_block(x, 256)

    x = mobilenet_block(x, 512, 2)
    for _ in range(5):
        x = mobilenet_block(x, 512)

    x = mobilenet_block(x, 1024, 2)
    x = mobilenet_block(x, 1024)

    x = GlobalAvgPool2D()(x)

    output = Dense(n_classes, activation='softmax')(x)

    model = Model(input, output)
    return model


input_shape = 224, 224, 3
n_classes = 1000

K.clear_session()
model = mobilenet(input_shape, n_classes)
model.summary()

Unfortunately, this implementation commits to fixed values for its hyperparameters, and we usually do not know good values for them upfront. An arguably better solution is to encode this design ambiguity in a search space.

To create a search space from the Keras implementation of the fixed model, we start by wrapping the operations in the architecture into basic modules. Single-input single-output basic modules that directly correspond to a Keras layer function can be created with siso_keras_module_from_keras_layer_fn (e.g., relu, conv2d, depthwise_conv2d, batch_normalization, dropout, and output_layer). Properties are defined through a dictionary that maps property names to hyperparameters, e.g., in dropout, the property rate gets its value from the hyperparameter h_rate, which is passed as an argument to dropout. We use the prefix h_ for hyperparameters. Arguments without the prefix refer to fixed values (e.g., stride in conv2d is a fixed number rather than a hyperparameter).

import keras
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, DepthwiseConv2D
from keras.layers import GlobalAvgPool2D
from keras.layers import BatchNormalization, Dropout, ReLU
import deep_architect.core as co
import deep_architect.modules as mo
from deep_architect.hyperparameters import Discrete as D
import deep_architect.helpers.keras_support as hke
import deep_architect.searchers.common as seco
import deep_architect.visualization as vi

input_shape = 224, 224, 3
n_classes = 1000


def relu():
    return hke.siso_keras_module_from_keras_layer_fn(ReLU, {})


def conv2d(h_filters, h_kernel_size, stride):
    fn = lambda filters, kernel_size, strides: Conv2D(
        filters, kernel_size, strides=strides, padding='same')
    return hke.siso_keras_module_from_keras_layer_fn(fn, {
        "filters": h_filters,
        "kernel_size": h_kernel_size,
        "strides": D([stride])
    },
                                                     name='Conv2D')


def depthwise_conv2d(h_kernel_size, stride):
    fn = lambda kernel_size, strides: DepthwiseConv2D(
        kernel_size, strides, padding='same')
    return hke.siso_keras_module_from_keras_layer_fn(fn, {
        "kernel_size": h_kernel_size,
        "strides": D([stride])
    },
                                                     name="DepthwiseConv2D")


def batch_normalization():
    return hke.siso_keras_module_from_keras_layer_fn(BatchNormalization, {})


def dropout(h_rate):
    return hke.siso_keras_module_from_keras_layer_fn(Dropout, {"rate": h_rate})


def output_layer():

    def forward_fn(x):
        x = GlobalAvgPool2D()(x)
        x = Dense(n_classes, activation='softmax')(x)
        return x

    return hke.siso_keras_module_from_keras_layer_fn(lambda: forward_fn, {},
                                                     name="OutputLayer")

We now write a search space mobile_block to encode design ambiguity (compare it with the fixed mobilenet_block shown previously). Batch normalization and dropout are optionally used in two places in the block, with their use tied together (e.g., the same h_bn_opt controls both uses of batch normalization). The number of filters and the kernel size are determined by hyperparameters (h_filters and h_kernel_size), and so are whether to include batch normalization (h_bn_opt), whether to include dropout (h_drop_opt), and the dropout rate used if included (h_drop_rate).

def mobile_block(h_filters, h_kernel_size, h_bn_opt, h_drop_opt, h_drop_rate,
                 stride):
    return mo.siso_sequential([
        depthwise_conv2d(h_kernel_size, stride),
        mo.siso_optional(batch_normalization, h_bn_opt),
        mo.siso_optional(lambda: dropout(h_drop_rate), h_drop_opt),
        relu(),
        conv2d(h_filters, D([1]), 1),
        mo.siso_optional(batch_normalization, h_bn_opt),
        mo.siso_optional(lambda: dropout(h_drop_rate), h_drop_opt),
        relu()
    ])

Finally, the complete search space mobile_net is composed of the serial connection of three sub-search spaces:

  • The initial section (shown in the creation of e_io) is similar to the one in the fixed MobileNet, but relaxes the number of filters (h_initial_filters) and the kernel size (h_initial_kernel_size), along with whether to include batch normalization (h_bn_opt) or dropout (h_drop_opt, and h_drop_rate if included). The same choices for dropout and batch normalization are used across the network.
  • A substitution module (block_chain) determines the number of spatial reductions (i.e., convolutions with stride greater than one; captured by h_num_reductions). The number of mobile blocks after each spatial reduction is relaxed to be either 1 or 3 (h_num_repeats). After each spatial reduction, the number of filters (h_filters, which is a dependent hyperparameter after the first reduction) is multiplied by either 1, 2, or 4 (h_filter_multiplier).
  • The output layer (global average pooling followed by dense) is the same as in the fixed MobileNet implementation. There are no hyperparameters in this sub-search space.

def mobile_net():
    h_drop_rate = D([0.5, 0.5**2, 0.5**3, 0.5**4])
    h_drop_opt = D([0, 1])
    h_bn_opt = D([0, 1])
    h_initial_filters = D([16, 32, 64, 128])
    h_initial_kernel_size = D([3, 5, 7])
    h_filter_multiplier = D([1, 2, 4])
    h_num_reductions = D([2, 3])
    h_kernel_size = D([3, 5])

    # main part of the model
    def block_chain(h_num_filters, h_filter_multiplier, h_num_reductions):

        def substitution_fn(dh):
            lst = []
            h_filters = h_num_filters
            for _ in range(dh["num_reductions"]):
                x = mobile_block(h_filters, h_kernel_size, h_bn_opt, h_drop_opt,
                                 h_drop_rate, 2)
                h_num_repeats = D([1, 3])
                # bind the current h_filters through a default argument:
                # siso_repeat calls this lambda lazily, after the loop has
                # already reassigned h_filters
                y = mo.siso_repeat(
                    lambda h_filters=h_filters: mobile_block(
                        h_filters, h_kernel_size, h_bn_opt, h_drop_opt,
                        h_drop_rate, 1), h_num_repeats)
                lst.extend([x, y])

                h_filters = co.DependentHyperparameter(
                    lambda dh: dh["filters"] * dh["multiplier"], {
                        "filters": h_filters,
                        "multiplier": h_filter_multiplier
                    })

            return mo.siso_sequential(lst)

        return mo.substitution_module("BlockChain", substitution_fn,
                                      {"num_reductions": h_num_reductions},
                                      ["in"], ["out"])

    e_io = mo.siso_sequential([
        conv2d(h_initial_filters, h_initial_kernel_size, 2),
        mo.siso_optional(batch_normalization, h_bn_opt),
        mo.siso_optional(lambda: dropout(h_drop_rate), h_drop_opt),
        relu()
    ])

    return mo.siso_sequential([
        e_io,
        block_chain(h_initial_filters, h_filter_multiplier, h_num_reductions),
        output_layer()
    ])

Once a search space is defined, it can be used to define other search spaces (e.g., the use of mobile_block in mobile_net). Only the basic modules are framework-dependent (i.e., relu, conv2d, depthwise_conv2d, batch_normalization, dropout, and output_layer); a large part of the search space (mobile_block and mobile_net) is framework-independent. DeepArchitect easily works with other domains (e.g., data augmentation strategies or other deep learning frameworks) by instantiating basic modules for those domains. Substitution modules and independent and dependent hyperparameters can be used without changes regardless of the domain; the constructs have the same semantics across domains.
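
Because of this composability, mobile_block can be dropped unchanged into new search spaces. As a small sketch (block_stack is a hypothetical name; it only uses constructs already defined above):

def block_stack(h_filters, h_kernel_size, h_bn_opt, h_drop_opt, h_drop_rate):
    # a chain of 2, 4, or 8 stride-1 mobile blocks, reusing mobile_block
    # unchanged from the search space above
    return mo.siso_repeat(
        lambda: mobile_block(h_filters, h_kernel_size, h_bn_opt, h_drop_opt,
                             h_drop_rate, 1), D([2, 4, 8]))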

We now sample a random architecture from this search space. search_space_fn returns the initial state of the search space, with none of its independent hyperparameters assigned values (SearchSpaceFactory does some basic wrapping). random_specify assigns a value chosen uniformly at random to each independent hyperparameter in the search space. draw_graph_evolution draws the sequence of search spaces obtained with each graph transition. Finally, forward creates the Keras computational graph from the terminal search space, which is finalized by the call to Model.

search_space_fn = mo.SearchSpaceFactory(mobile_net).get_search_space
inputs, outputs = search_space_fn()
# pick a random model from the search space
vs = seco.random_specify(outputs)

inputs, outputs = search_space_fn()
print(vs)
vi.draw_graph_evolution(outputs,
                        vs,
                        '.',
                        graph_name="graph",
                        draw_io_labels=False,
                        draw_module_hyperparameter_info=False)

# simpler graph evolution (hide the hyperparameters)
inputs, outputs = search_space_fn()
print(vs)
vi.draw_graph_evolution(outputs,
                        vs,
                        '.',
                        graph_name='simplified_graph',
                        draw_io_labels=False,
                        draw_hyperparameters=False)

K.clear_session()
# compilation of the Keras model
k_in = Input(input_shape)
co.forward({inputs["in"]: k_in})
k_out = outputs["out"].val
model = Model(k_in, k_out)
model.summary()

The result of model.summary() for a simple model from the search space is the following:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 224, 224, 3)       0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 112, 112, 16)      448
_________________________________________________________________
batch_normalization_1 (Batch (None, 112, 112, 16)      64
_________________________________________________________________
re_lu_1 (ReLU)               (None, 112, 112, 16)      0
_________________________________________________________________
depthwise_conv2d_1 (Depthwis (None, 56, 56, 16)        160
_________________________________________________________________
batch_normalization_2 (Batch (None, 56, 56, 16)        64
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 56, 56, 16)        272
_________________________________________________________________
batch_normalization_3 (Batch (None, 56, 56, 16)        64
_________________________________________________________________
re_lu_2 (ReLU)               (None, 56, 56, 16)        0
_________________________________________________________________
depthwise_conv2d_2 (Depthwis (None, 56, 56, 16)        160
_________________________________________________________________
batch_normalization_4 (Batch (None, 56, 56, 16)        64
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 56, 56, 16)        272
_________________________________________________________________
batch_normalization_5 (Batch (None, 56, 56, 16)        64
_________________________________________________________________
re_lu_3 (ReLU)               (None, 56, 56, 16)        0
_________________________________________________________________
depthwise_conv2d_3 (Depthwis (None, 28, 28, 16)        160
_________________________________________________________________
batch_normalization_6 (Batch (None, 28, 28, 16)        64
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 28, 28, 16)        272
_________________________________________________________________
batch_normalization_7 (Batch (None, 28, 28, 16)        64
_________________________________________________________________
re_lu_4 (ReLU)               (None, 28, 28, 16)        0
_________________________________________________________________
depthwise_conv2d_4 (Depthwis (None, 28, 28, 16)        160
_________________________________________________________________
batch_normalization_8 (Batch (None, 28, 28, 16)        64
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 28, 28, 16)        272
_________________________________________________________________
batch_normalization_9 (Batch (None, 28, 28, 16)        64
_________________________________________________________________
re_lu_5 (ReLU)               (None, 28, 28, 16)        0
_________________________________________________________________
global_average_pooling2d_1 ( (None, 16)                0
_________________________________________________________________
dense_1 (Dense)              (None, 1000)              17000
=================================================================
Total params: 19,752
Trainable params: 19,464
Non-trainable params: 288

The sequence of values assigned to independent hyperparameters that led to this model was:

[2, 0, 1, 16, 3, 2, 1, 1, 1, 1, 3, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1]

This and this PDF show the search space transitions. The first shows the hyperparameters along with the module properties that depend on them. The second shows the modules and their properties (along with the values assigned to them, if any). The graph expands with each transition (there are 22 value assignments, leading to 23 frames, i.e., the initial frame plus one frame per transition). The transition process starts from a search space where all independent hyperparameters are unassigned and ends with a terminal search space where all hyperparameters have been assigned values (and which therefore encodes a single architecture that can then be mapped to a Keras implementation).
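
As a sanity check, the recorded assignments can be replayed on a fresh copy of the search space to reconstruct the same terminal architecture. A sketch, assuming seco.specify assigns a list of values to the unassigned independent hyperparameters in traversal order (mirroring random_specify):

# replay the 22 recorded assignments on a fresh search space; the result
# is the same terminal architecture that was sampled above
inputs, outputs = search_space_fn()
seco.specify(outputs, vs)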

Instead of using random_specify to pick a model, search_space_fn can be passed to a search algorithm that will sequentially pick architectures from the search space. These architectures are then passed to an evaluation algorithm which returns a performance metric (e.g., validation accuracy). The architecture and the performance metric are then returned to the search algorithm, which can use this information for sampling the next architecture. This is the basic structure of an architecture search experiment.
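
A minimal sketch of such an experiment, using only the helpers imported above; evaluate is a hypothetical stand-in for user code that compiles and trains the sampled model and returns a validation metric:

def search(search_space_fn, evaluate, num_samples):
    best_val, best_vs = None, None
    for _ in range(num_samples):
        inputs, outputs = search_space_fn()
        # the searcher: here, uniformly random hyperparameter assignments
        vs = seco.random_specify(outputs)
        # the evaluator: train the sampled model and return, e.g.,
        # validation accuracy
        val = evaluate(inputs, outputs)
        if best_val is None or val > best_val:
            best_val, best_vs = val, vs
    return best_val, best_vs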

Framework features

  • Composable search spaces: Once a search space is defined, it can be used to define other search spaces. One of the main insights of DeepArchitect is that composing search spaces does not require knowing their exact internal structure, only their inputs and outputs. Search spaces encoded in DeepArchitect encapsulate choices while giving the expert a handle similar to that of a single computational graph.

  • Complex structural transformations through substitution modules: Substitution modules encode structural transformations of the search space, for example, choosing between different sequences of operations, or choosing whether an operation is used at all.

  • Delayed evaluation as a focal part of search space design: Search spaces rely on delayed evaluation. For example, computation depends on specific inputs and hyperparameters, but we may not yet have values for them. This is encoded in the search space and resolved when the search algorithm assigns values to the hyperparameters. Transformations to the structure of the network are done through substitution modules; the specific structure is generated as a result of the values chosen for the substitution module's hyperparameters.

  • Modular and reusable code: Implementations are made of expressive high-level components that are loosely coupled through their APIs. For example, search algorithms interact with search spaces only through hyperparameter traversal and assignment, which can be provided for any search space (see the sketch after this list). Similarly, search spaces can be constructed by arranging modules and hyperparameters. This allows us to reuse large chunks of code for new use-cases.

  • Framework agnostic: The constructs and their semantics are domain-agnostic. Our implementation provides wrappers for different deep learning frameworks (e.g., Keras, TensorFlow, and PyTorch) and even for non-deep-learning domains (e.g., Scikit-Learn pipelines or data augmentation).

  • Logging and visualization: Architecture search generates a wealth of data that can then be visualized for insights (e.g., what patterns lead to high-performance architectures). We provide basic visualization tools that use logs generated by architecture search experiments.
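
To make the hyperparameter traversal-and-assignment contract concrete, here is a sketch of what seco.random_specify does internally. The iterator name and the .vs and .assign_value accessors are assumptions about the core API; check deep_architect.core in the repo for the actual names:

import random


def my_random_specify(outputs):
    # a searcher only needs to iterate over the unassigned independent
    # hyperparameters and assign each one a value
    vs = []
    for h in co.unassigned_independent_hyperparameter_iterator(outputs):
        v = random.choice(h.vs)  # assumes Discrete exposes its values as .vs
        h.assign_value(v)
        vs.append(v)
    return vs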

Roadmap for DeepArchitect

DeepArchitect enables modular and reusable architecture search code, making improvements in architecture search widely available to the community. Well-designed APIs allow us to tackle a large range of architecture search use-cases with minimal implementation effort, such as:

  • Running multi-objective architecture search.
  • Searching over data augmentation strategies.
  • Constructing a database of architecture search results that can be used for future offline exploration and research.
  • Developing visualization tools to support the expert in the architecture discovery process.
  • Developing reference implementations of search spaces and search algorithms in the literature.
  • Implementing surrogate functions for predicting network performance without training.
  • Supporting distributed computation for architecture search.

Learning more about DeepArchitect

Visit the repo and documentation. The tutorials should give you a good grasp of how to use and extend DeepArchitect. It would be great to have your support in adding search algorithms and search spaces to DeepArchitect, as well as in implementing new use-cases. See the paper for an in-depth description of the ideas behind the framework. Reach out to negrinho@cs.cmu.edu or @rmpnegrinho if you have questions or want to get involved. Post questions to our Google group or to GitHub issues; prefer the Google group for simple usage questions and GitHub issues for bug fixes, feature requests, and discussions that may lead to code changes.